Skip to content

Commit 5ee520d

Browse files
committed
Refactor init_vector to avoid unnecessary repeated checking of function return arguments
1 parent 6047b35 commit 5ee520d

File tree

1 file changed

+17
-21
lines changed

1 file changed

+17
-21
lines changed

fidimag/common/helper.py

Lines changed: 17 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -25,47 +25,43 @@ def normalise(a):
2525

2626

2727
def init_vector(m0, mesh, norm=False, *args):
28-
2928
n = mesh.n
30-
31-
spin = np.zeros((n, 3))
29+
field = np.zeros((n, 3))
3230

3331
if isinstance(m0, list) or isinstance(m0, tuple):
34-
spin[:, :] = m0
35-
spin = np.reshape(spin, 3 * n, order='C')
32+
field[:, :] = m0
33+
field = np.reshape(field, 3 * n, order='C')
3634

3735
elif hasattr(m0, '__call__'):
36+
# Check only once that the function returns appropriately...
37+
v = m0(mesh.coordinates[0], *args)
38+
if len(v) != 3:
39+
raise Exception(
40+
'The length of the value in init_vector method must be 3.')
3841
for i in range(n):
39-
v = m0(mesh.coordinates[i], *args)
40-
if len(v) != 3:
41-
raise Exception(
42-
'The length of the value in init_vector method must be 3.')
43-
spin[i, :] = v[:]
44-
spin = np.reshape(spin, 3 * n, order='C')
42+
field[i, :] = m0(mesh.coordinates[i], *args)
43+
field = np.reshape(field, 3 * n, order='C')
4544

4645
elif isinstance(m0, np.ndarray):
4746
if m0.shape == (3, ):
48-
spin[:] = m0 # broadcasting
47+
field[:] = m0 # broadcasting
4948
else:
50-
spin.shape = (-1)
51-
spin[:] = m0 # overwriting the whole thing
52-
53-
spin.shape = (-1,)
54-
49+
field.shape = (-1)
50+
field[:] = m0 # overwriting the whole thing
51+
field.shape = (-1,)
5552
if norm:
56-
normalise(spin)
53+
normalise(field)
5754

58-
return spin
55+
return field
5956

6057

6158
def init_scalar(value, mesh, *args):
62-
6359
n = mesh.n
64-
6560
mesh_v = np.zeros(n)
6661

6762
if isinstance(value, (int, float)):
6863
mesh_v[:] = value
64+
6965
elif hasattr(value, '__call__'):
7066
for i in range(n):
7167
mesh_v[i] = value(mesh.coordinates[i], *args)

0 commit comments

Comments
 (0)