Skip to content

Commit 802e203

Browse files
Merge pull request #116 from computationalmodelling/helper-cython
Refactor Helper functions and Zeeman for better performance
2 parents 5ee520d + df63f65 commit 802e203

File tree

12 files changed

+229
-96
lines changed

12 files changed

+229
-96
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
*.swp
66
*.pdf
77
*.tmp
8+
*.so
89
/local/
910

1011
# ignore automatically generated cython files

doc/core_eqs.rst

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ described at a semi-classical level.
1818

1919
Interactions between magnetic moments are specified using the Heisenberg
2020
formalism.
21-
|
2221

2322
* Micromagnetics
2423

doc/index.rst

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ The code for Fidimag is available under an open source license on `GitHub
1212
<https://github.com/computationalmodelling/fidimag>`_.
1313

1414
Contents:
15-
1615
.. toctree::
1716
:maxdepth: 2
1817
:caption: Installation Instructions
@@ -31,6 +30,7 @@ Contents:
3130
ipynb/isolated_skyrmion
3231
ipynb/spin-polarised-current-driven-skyrmion
3332
ipynb/spin-waves-in-periodic-system
33+
ipynb/FMR-stdprob
3434

3535
.. toctree::
3636
:maxdepth: 2
@@ -46,4 +46,3 @@ Contents:
4646

4747
nebm
4848

49-
.. _GitHub

fidimag/common/helper.py

Lines changed: 6 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -8,71 +8,7 @@
88
from mpl_toolkits.mplot3d import Axes3D
99
from matplotlib.colors import colorConverter
1010
from matplotlib.collections import PolyCollection, LineCollection
11-
12-
13-
def normalise(a):
14-
"""
15-
normalise the given array a
16-
"""
17-
a.shape = (-1, 3)
18-
b = np.sqrt(a[:, 0] ** 2 + a[:, 1] ** 2 + a[:, 2] ** 2)
19-
ids = (b == 0)
20-
b[ids] = 1.0
21-
a[:, 0] /= b
22-
a[:, 1] /= b
23-
a[:, 2] /= b
24-
a.shape = (-1,)
25-
26-
27-
def init_vector(m0, mesh, norm=False, *args):
28-
n = mesh.n
29-
field = np.zeros((n, 3))
30-
31-
if isinstance(m0, list) or isinstance(m0, tuple):
32-
field[:, :] = m0
33-
field = np.reshape(field, 3 * n, order='C')
34-
35-
elif hasattr(m0, '__call__'):
36-
# Check only once that the function returns appropriately...
37-
v = m0(mesh.coordinates[0], *args)
38-
if len(v) != 3:
39-
raise Exception(
40-
'The length of the value in init_vector method must be 3.')
41-
for i in range(n):
42-
field[i, :] = m0(mesh.coordinates[i], *args)
43-
field = np.reshape(field, 3 * n, order='C')
44-
45-
elif isinstance(m0, np.ndarray):
46-
if m0.shape == (3, ):
47-
field[:] = m0 # broadcasting
48-
else:
49-
field.shape = (-1)
50-
field[:] = m0 # overwriting the whole thing
51-
field.shape = (-1,)
52-
if norm:
53-
normalise(field)
54-
55-
return field
56-
57-
58-
def init_scalar(value, mesh, *args):
59-
n = mesh.n
60-
mesh_v = np.zeros(n)
61-
62-
if isinstance(value, (int, float)):
63-
mesh_v[:] = value
64-
65-
elif hasattr(value, '__call__'):
66-
for i in range(n):
67-
mesh_v[i] = value(mesh.coordinates[i], *args)
68-
69-
elif isinstance(value, np.ndarray):
70-
if value.shape == mesh_v.shape:
71-
mesh_v[:] = value[:]
72-
else:
73-
raise ValueError("Array size must match the mesh size")
74-
75-
return mesh_v
11+
from fidimag.extensions.common_clib import normalise, init_scalar, init_vector, init_vector_func_fast
7612

7713

7814
def extract_data(mesh, npys, pos, comp='x'):
@@ -128,15 +64,15 @@ def plot_m(mesh, npy, comp='x', based=None):
12864
data = data - based
12965

13066
data.shape = (-1, 3)
131-
m = data[:,cmpi]
67+
m = data[:, cmpi]
13268

13369
nx = mesh.nx
13470
ny = mesh.ny
13571
nz = mesh.nz
13672

13773
m.shape = (nz, ny, nx)
13874

139-
m2 = m[0,:,:]
75+
m2 = m[0, :, :]
14076

14177
fig = plt.figure()
14278
# norm=color.Normalize(-1,1)
@@ -207,7 +143,7 @@ def plot_energy_3d(name, key_steps=50, filename=None):
207143
if each_n_step < 1:
208144
each_n_step = 1
209145

210-
cc = lambda arg: colorConverter.to_rgba(arg, alpha=0.6)
146+
def cc(arg): return colorConverter.to_rgba(arg, alpha=0.6)
211147
colors = [cc('r'), cc('g'), cc('b'), cc('y')]
212148
facecolors = []
213149
line_data = []
@@ -241,5 +177,6 @@ def plot_energy_3d(name, key_steps=50, filename=None):
241177

242178

243179
def compute_RxRy(mesh, spin, nx_start=0, nx_stop=-1, ny_start=0, ny_stop=-1):
244-
res = clib.compute_RxRy(spin, mesh.nx, mesh.ny, mesh.nz, nx_start, nx_stop, ny_start, ny_stop)
180+
res = clib.compute_RxRy(spin, mesh.nx, mesh.ny,
181+
mesh.nz, nx_start, nx_stop, ny_start, ny_stop)
245182
return res

fidimag/common/lib/common_clib.pyx

Lines changed: 99 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
cimport numpy as np
2+
import numpy as np
13

24
# -----------------------------------------------------------------------------
35

@@ -33,7 +35,7 @@ cdef extern from "common_clib.h":
3335
double *mxH, double *mxmxH, double *mxmxH_last, double *tau,
3436
int* pins, int n)
3537

36-
void sd_compute_step (double *spin, double *spin_last, double *magnetisation,
38+
void sd_compute_step (double *spin, double *spin_last, double *magnetisation,
3739
double *field,
3840
double *mxH, double *mxmxH, double *mxmxH_last, double *tau,
3941
int *pins, int n, int counter, double tmin, double tmax)
@@ -128,8 +130,103 @@ def compute_sd_step(double [:] spin,
128130
int [:] pins,
129131
n, counter, tmin, tmax):
130132

131-
sd_compute_step(&spin[0], &spin_last[0], &magnetisation[0],
133+
sd_compute_step(&spin[0], &spin_last[0], &magnetisation[0],
132134
&field[0], &mxH[0],
133135
&mxmxH[0], &mxmxH_last[0], &tau[0], &pins[0],
134136
n, counter, tmin, tmax
135137
)
138+
139+
def normalise(a):
140+
"""
141+
normalise the given array a
142+
"""
143+
a.shape = (-1, 3)
144+
b = np.sqrt(a[:, 0] ** 2 + a[:, 1] ** 2 + a[:, 2] ** 2)
145+
ids = (b == 0)
146+
b[ids] = 1.0
147+
a[:, 0] /= b
148+
a[:, 1] /= b
149+
a[:, 2] /= b
150+
a.shape = (-1,)
151+
152+
def init_scalar(value, mesh, *args):
153+
154+
n = mesh.n
155+
156+
mesh_v = np.zeros(n)
157+
158+
if isinstance(value, (int, float)):
159+
mesh_v[:] = value
160+
elif hasattr(value, '__call__'):
161+
for i in range(n):
162+
mesh_v[i] = value(mesh.coordinates[i], *args)
163+
164+
elif isinstance(value, np.ndarray):
165+
if value.shape == mesh_v.shape:
166+
mesh_v[:] = value[:]
167+
else:
168+
raise ValueError("Array size must match the mesh size")
169+
170+
return mesh_v
171+
172+
def init_vector(m0, mesh, norm=False, *args):
173+
174+
n = mesh.n
175+
176+
spin = np.zeros((n, 3))
177+
178+
if isinstance(m0, list) or isinstance(m0, tuple):
179+
spin[:, :] = m0
180+
spin = np.reshape(spin, 3 * n, order='C')
181+
182+
elif hasattr(m0, '__call__'):
183+
v = m0(mesh.coordinates[0], *args)
184+
if len(v) != 3:
185+
raise Exception(
186+
'The length of the value in init_vector method must be 3.')
187+
for i in range(n):
188+
spin[i, :] = m0(mesh.coordinates[i], *args)
189+
spin = np.reshape(spin, 3 * n, order='C')
190+
191+
elif isinstance(m0, np.ndarray):
192+
if m0.shape == (3, ):
193+
spin[:] = m0 # broadcasting
194+
else:
195+
spin.shape = (-1)
196+
spin[:] = m0 # overwriting the whole thing
197+
198+
spin.shape = (-1,)
199+
200+
if norm:
201+
normalise(spin)
202+
203+
return spin
204+
205+
def init_vector_func_fast(m0, mesh, double[:] field, norm=False, *args):
206+
"""
207+
An unsafe method of setting the field. Depends on
208+
the setter code being memory safe.
209+
210+
m0 must be a Python function that takes the mesh
211+
and field as arguments. Within that, the user
212+
must handle evaluating the function at different
213+
coordinate points. It needs to be able to handle
214+
the spatial dependence itself, and write the
215+
field valuse into the field array. This can
216+
be written with Cython which will give much
217+
better performance. For example:
218+
219+
from libc.math cimport sin
220+
221+
def fast_sin_init(mesh, double[:] field, *params):
222+
t, axis, Bmax, fc = params
223+
for i in range(mesh.n):
224+
field[3*i+0] = Bmax * axis[0] * sin(fc*t)
225+
field[3*i+1] = Bmax * axis[1] * sin(fc*t)
226+
field[3*i+2] = Bmax * axis[2] * sin(fc*t)
227+
228+
"""
229+
m0(mesh, field, *args)
230+
if norm:
231+
normalise(field)
232+
return field

fidimag/micro/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,3 @@
77
from .dmi import DMI
88
from .baryakhtar import LLBar, LLBarFull
99
from .simple_demag import SimpleDemag
10-

fidimag/micro/zeeman.py

Lines changed: 4 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import numpy as np
2-
32
from fidimag.common.constant import mu_0
43
import fidimag.common.helper as helper
4+
import inspect
55

66

77
class Zeeman(object):
@@ -115,34 +115,19 @@ class TimeZeeman(Zeeman):
115115

116116
"""
117117
The time dependent external field, also can vary with space
118-
119-
The function time_fun must be a function which takes two arguments:
120-
121-
def time_fun(pos, t):
122-
x, y, z = pos
123-
# compute Bx, By, Bz as a function of x y, z and t.
124-
Bx = ...
125-
By = ...
126-
Bz = ...
127-
return (Bx, By, Bz)
128-
129-
130118
"""
131119

132-
def __init__(self, H0, time_fun, name='TimeZeeman'):
133-
120+
def __init__(self, H0, time_fun, extra_args=[], name='TimeZeeman'):
134121
self.H0 = H0
135122
self.time_fun = time_fun
136123
self.name = name
137124
self.jac = True
125+
self.extra_args = extra_args
138126

139127
def setup(self, mesh, spin, Ms):
140128
super(TimeZeeman, self).setup(mesh, spin, Ms)
141129
self.H_init = self.field.copy()
142130

143131
def compute_field(self, t=0, spin=None):
144-
self.field[:] = helper.init_vector(self.time_fun,
145-
self.mesh,
146-
False,
147-
t)
132+
self.field[:] = self.H_init[:] * self.time_fun(t, *self.extra_args)
148133
return self.field

fidimag/user/README.md

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# User Extensions
2+
3+
The user extensions directory is here to allow for
4+
the user to be able to straightforwardly run compiled
5+
code within Fidimag for performance reasons. We
6+
consider this an advanced feature and do not recommend
7+
trying this unless you have experience writing and
8+
building C/Cython programs.
9+
10+
Some of the energy classes perform callbacks to
11+
user-supplied functions. Performance for this
12+
is generally poor, as there is an overhead to
13+
calling Python functions repeatedly. Hence,
14+
we place this folder here to allow you to expose
15+
functions written in Cython/C conveniently.
16+
17+
An example has been supplied. We suggest copying
18+
the folder and modifying each of the files in
19+
this. Please note that we have automated the
20+
building of the extensions, but you can only
21+
have a single Cython .pyx file per directory,
22+
because a single Cython module is created in
23+
each folder. The module that is created will have
24+
the name of this file.
25+
26+
You do not explicitly need to write an __init__.py file;
27+
your extension will be importable immediately from
28+
fidimag.extensions.user.$FOLDERNAME
29+
30+
The __init__.py file lets you do the slightly shorter:
31+
from fidimag.user.$FOLDERNAME import *
32+

fidimag/user/example/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from fidimag.extensions.user.example import *

fidimag/user/example/example.pyx

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
from libc.math cimport cos, sin
2+
3+
def fast_sin_init(mesh, double[:] field, *params):
4+
t, axis, Bmax, fc = params
5+
for i in range(mesh.n):
6+
field[3*i+0] = Bmax * axis[0] * sin(fc*t)
7+
field[3*i+1] = Bmax * axis[1] * sin(fc*t)
8+
field[3*i+2] = Bmax * axis[2] * sin(fc*t)
9+
10+
def TimeZeemanFast_test_time_fun(mesh, double[:] field, *params):
11+
cdef int i
12+
t, frequency = params
13+
for i in range(mesh.n):
14+
field[3*i+0] = 0
15+
field[3*i+1] = 0
16+
field[3*i+2] = 10 * cos(frequency * t)

0 commit comments

Comments
 (0)