Skip to content

Commit c1a31a9

Browse files
committed
api: dix sympy assumptions for complex valued objects
1 parent 601d99d commit c1a31a9

File tree

6 files changed

+198
-15
lines changed

6 files changed

+198
-15
lines changed

devito/mpi/routines.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,7 @@ def _make_bundles(self, hs):
306306
except ValueError:
307307
for i in candidates:
308308
name = "bag_%s" % i.name
309-
bag = Bag(name=name, components=i)
309+
bag = Bag(name=name, components=(i,))
310310
halo_scheme = halo_scheme.add(bag, hse)
311311

312312
hs = hs._rebuild(halo_scheme=halo_scheme)

devito/passes/iet/misc.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,11 @@ def _(expr, lang):
235235

236236
@_lower_macro_math.register(SafeInv)
237237
def _(expr, lang):
238-
eps = np.finfo(expr.base.dtype).resolution**2
238+
try:
239+
eps = np.finfo(expr.base.dtype).resolution**2
240+
except ValueError:
241+
print(f"Warning: dtype not recognized in SafeInv for {expr.base}")
242+
eps = np.finfo(np.float32).resolution**2
239243
b = Cast('b', dtype=np.float32)
240244
return (('SAFEINV(a, b)',
241245
f'(((a) < {eps}F || ({b}) < {eps}F) ? (0.0F) : ((1.0F) / (a)))'),), {}

devito/types/basic.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -349,8 +349,6 @@ class AbstractSymbol(sympy.Symbol, Basic, Pickable, Evaluable):
349349
is_Symbol = True
350350

351351
# SymPy default assumptions
352-
is_real = True
353-
is_imaginary = False
354352
is_commutative = True
355353

356354
__rkwargs__ = ('name', 'dtype', 'is_const')
@@ -411,6 +409,12 @@ def _hashable_content(self):
411409
def dtype(self):
412410
return self._dtype
413411

412+
def _eval_is_real(self):
413+
return not self.is_imaginary
414+
415+
def _eval_is_imaginary(self):
416+
return np.iscomplexobj(self.dtype(0))
417+
414418
@property
415419
def indices(self):
416420
return ()
@@ -859,7 +863,6 @@ class AbstractFunction(sympy.Function, Basic, Pickable, Evaluable):
859863
is_AbstractFunction = True
860864

861865
# SymPy default assumptions
862-
is_imaginary = False
863866
is_commutative = True
864867

865868
# Devito default assumptions
@@ -888,6 +891,8 @@ def __new__(cls, *args, **kwargs):
888891
# Extract the `indices`, as perhaps they're explicitly provided
889892
dimensions, indices = cls.__indices_setup__(*args, **kwargs)
890893

894+
# Sympy assumptions
895+
891896
# If it's an alias or simply has a different name, ignore `function`.
892897
# These cases imply the construction of a new AbstractFunction off
893898
# an existing one! This relieves the pressure on the caller by not
@@ -955,6 +960,8 @@ def _sympystr(self, printer, **kwargs):
955960
return str(self)
956961

957962
_latex = _sympystr
963+
_eval_is_real = AbstractSymbol._eval_is_real
964+
_eval_is_imaginary = AbstractSymbol._eval_is_imaginary
958965

959966
def _pretty(self, printer, **kwargs):
960967
return printer._print_Function(self, func_name=self.name)
@@ -1315,10 +1322,6 @@ def is_const(self):
13151322
def is_transient(self):
13161323
return self._is_transient
13171324

1318-
@property
1319-
def is_real(self):
1320-
return not np.iscomplex(self.dtype(0))
1321-
13221325
@property
13231326
def is_persistent(self):
13241327
"""

examples/seismic/tutorials/17_fourier_mode.ipynb

Lines changed: 152 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,152 @@
191191
"cell_type": "code",
192192
"execution_count": 9,
193193
"metadata": {},
194+
"outputs": [
195+
{
196+
"name": "stdout",
197+
"output_type": "stream",
198+
"text": [
199+
"#define _POSIX_C_SOURCE 200809L\n",
200+
"#define START(S) struct timeval start_ ## S , end_ ## S ; gettimeofday(&start_ ## S , NULL);\n",
201+
"#define STOP(S,T) gettimeofday(&end_ ## S, NULL); T->S += (double)(end_ ## S .tv_sec-start_ ## S.tv_sec)+(double)(end_ ## S .tv_usec-start_ ## S .tv_usec)/1000000;\n",
202+
"#define MAX(a,b) (((a) > (b)) ? (a) : (b))\n",
203+
"\n",
204+
"#include \"stdlib.h\"\n",
205+
"#include \"math.h\"\n",
206+
"#include \"sys/time.h\"\n",
207+
"#include \"omp.h\"\n",
208+
"#include \"complex.h\"\n",
209+
"\n",
210+
"struct dataobj\n",
211+
"{\n",
212+
" void *restrict data;\n",
213+
" unsigned long * size;\n",
214+
" unsigned long * npsize;\n",
215+
" unsigned long * dsize;\n",
216+
" int * hsize;\n",
217+
" int * hofs;\n",
218+
" int * oofs;\n",
219+
" void * dmap;\n",
220+
"} ;\n",
221+
"\n",
222+
"struct profiler\n",
223+
"{\n",
224+
" double section0;\n",
225+
" double section1;\n",
226+
" double section2;\n",
227+
"} ;\n",
228+
"\n",
229+
"\n",
230+
"int Kernel(struct dataobj *restrict damp_vec, struct dataobj *restrict freq_modes_vec, struct dataobj *restrict rec_vec, struct dataobj *restrict rec_coords_vec, struct dataobj *restrict src_vec, struct dataobj *restrict src_coords_vec, struct dataobj *restrict u_vec, struct dataobj *restrict vp_vec, const int x_M, const int x_m, const int y_M, const int y_m, const float dt, const float o_x, const float o_y, const int p_rec_M, const int p_rec_m, const int p_src_M, const int p_src_m, const int time_M, const int time_m, const int nthreads, const int nthreads_nonaffine, struct profiler * timers)\n",
231+
"{\n",
232+
" float (*restrict damp)[damp_vec->size[1]] __attribute__ ((aligned (64))) = (float (*)[damp_vec->size[1]]) damp_vec->data;\n",
233+
" float _Complex (*restrict freq_modes)[freq_modes_vec->size[1]] __attribute__ ((aligned (64))) = (float _Complex (*)[freq_modes_vec->size[1]]) freq_modes_vec->data;\n",
234+
" float (*restrict rec)[rec_vec->size[1]] __attribute__ ((aligned (64))) = (float (*)[rec_vec->size[1]]) rec_vec->data;\n",
235+
" float (*restrict rec_coords)[rec_coords_vec->size[1]] __attribute__ ((aligned (64))) = (float (*)[rec_coords_vec->size[1]]) rec_coords_vec->data;\n",
236+
" float (*restrict src)[src_vec->size[1]] __attribute__ ((aligned (64))) = (float (*)[src_vec->size[1]]) src_vec->data;\n",
237+
" float (*restrict src_coords)[src_coords_vec->size[1]] __attribute__ ((aligned (64))) = (float (*)[src_coords_vec->size[1]]) src_coords_vec->data;\n",
238+
" float (*restrict u)[u_vec->size[1]][u_vec->size[2]] __attribute__ ((aligned (64))) = (float (*)[u_vec->size[1]][u_vec->size[2]]) u_vec->data;\n",
239+
" float (*restrict vp)[vp_vec->size[1]] __attribute__ ((aligned (64))) = (float (*)[vp_vec->size[1]]) vp_vec->data;\n",
240+
"\n",
241+
" float _Complex r2 = 1.0F/(dt*dt);\n",
242+
" float _Complex r3 = 1.0F/dt;\n",
243+
"\n",
244+
" for (int time = time_m, t0 = (time)%(3), t1 = (time + 2)%(3), t2 = (time + 1)%(3); time <= time_M; time += 1, t0 = (time)%(3), t1 = (time + 2)%(3), t2 = (time + 1)%(3))\n",
245+
" {\n",
246+
" float _Complex r1 = cexpf(6.28318530717959e-2F*time*_Complex_I*dt);\n",
247+
" START(section0)\n",
248+
" #pragma omp parallel num_threads(nthreads)\n",
249+
" {\n",
250+
" #pragma omp for schedule(dynamic,1)\n",
251+
" for (int x = x_m; x <= x_M; x += 1)\n",
252+
" {\n",
253+
" #pragma omp simd aligned(damp,freq_modes,u,vp:16)\n",
254+
" for (int y = y_m; y <= y_M; y += 1)\n",
255+
" {\n",
256+
" float _Complex r4 = 1.0F/(vp[x + 2][y + 2]*vp[x + 2][y + 2]);\n",
257+
" u[t2][x + 2][y + 2] = (-r4*(-2.0F*r2*u[t0][x + 2][y + 2] + r2*u[t1][x + 2][y + 2]) + r3*damp[x + 2][y + 2]*u[t0][x + 2][y + 2] + 1.0e-2F*(u[t0][x + 1][y + 2] + u[t0][x + 2][y + 1] + u[t0][x + 2][y + 3] + u[t0][x + 3][y + 2]) - 3.99999991e-2F*u[t0][x + 2][y + 2])/(r4*r2 + r3*damp[x + 2][y + 2]);\n",
258+
" freq_modes[x][y] += r1*u[t0][x + 2][y + 2];\n",
259+
" }\n",
260+
" }\n",
261+
" }\n",
262+
" STOP(section0,timers)\n",
263+
"\n",
264+
" START(section1)\n",
265+
" #pragma omp parallel num_threads(nthreads_nonaffine)\n",
266+
" {\n",
267+
" int chunk_size = (int)(MAX(1, (int)((1.0/3.0)*(p_src_M - p_src_m + 1)/nthreads_nonaffine)));\n",
268+
" #pragma omp for schedule(dynamic,chunk_size)\n",
269+
" for (int p_src = p_src_m; p_src <= p_src_M; p_src += 1)\n",
270+
" {\n",
271+
" for (int rsrcx = 0; rsrcx <= 1; rsrcx += 1)\n",
272+
" {\n",
273+
" for (int rsrcy = 0; rsrcy <= 1; rsrcy += 1)\n",
274+
" {\n",
275+
" int posx = (int)(floorf(1.0e-1*(-o_x + src_coords[p_src][0])));\n",
276+
" int posy = (int)(floorf(1.0e-1*(-o_y + src_coords[p_src][1])));\n",
277+
" float px = 1.0e-1F*(-o_x + src_coords[p_src][0]) - floorf(1.0e-1F*(-o_x + src_coords[p_src][0]));\n",
278+
" float py = 1.0e-1F*(-o_y + src_coords[p_src][1]) - floorf(1.0e-1F*(-o_y + src_coords[p_src][1]));\n",
279+
" if (rsrcx + posx >= x_m - 1 && rsrcy + posy >= y_m - 1 && rsrcx + posx <= x_M + 1 && rsrcy + posy <= y_M + 1)\n",
280+
" {\n",
281+
" float r0 = 3.06250F*(vp[posx + 2][posy + 2]*vp[posx + 2][posy + 2])*(rsrcx*px + (1 - rsrcx)*(1 - px))*(rsrcy*py + (1 - rsrcy)*(1 - py))*src[time][p_src];\n",
282+
" #pragma omp atomic update\n",
283+
" u[t2][rsrcx + posx + 2][rsrcy + posy + 2] += r0;\n",
284+
" }\n",
285+
" }\n",
286+
" }\n",
287+
" }\n",
288+
" }\n",
289+
" STOP(section1,timers)\n",
290+
"\n",
291+
" START(section2)\n",
292+
" #pragma omp parallel num_threads(nthreads_nonaffine)\n",
293+
" {\n",
294+
" int chunk_size = (int)(MAX(1, (int)((1.0/3.0)*(p_rec_M - p_rec_m + 1)/nthreads_nonaffine)));\n",
295+
" #pragma omp for schedule(dynamic,chunk_size)\n",
296+
" for (int p_rec = p_rec_m; p_rec <= p_rec_M; p_rec += 1)\n",
297+
" {\n",
298+
" float r7 = 1.0e-1F*(-o_x + rec_coords[p_rec][0]);\n",
299+
" float r5 = floorf(r7);\n",
300+
" int posx = (int)r5;\n",
301+
" float r8 = 1.0e-1F*(-o_y + rec_coords[p_rec][1]);\n",
302+
" float r6 = floorf(r8);\n",
303+
" int posy = (int)r6;\n",
304+
" float px = -r5 + r7;\n",
305+
" float py = -r6 + r8;\n",
306+
" float sum = 0.0F;\n",
307+
"\n",
308+
" for (int rrecx = 0; rrecx <= 1; rrecx += 1)\n",
309+
" {\n",
310+
" for (int rrecy = 0; rrecy <= 1; rrecy += 1)\n",
311+
" {\n",
312+
" if (rrecx + posx >= x_m - 1 && rrecy + posy >= y_m - 1 && rrecx + posx <= x_M + 1 && rrecy + posy <= y_M + 1)\n",
313+
" {\n",
314+
" sum += (rrecx*px + (1 - rrecx)*(1 - px))*(rrecy*py + (1 - rrecy)*(1 - py))*u[t2][rrecx + posx + 2][rrecy + posy + 2];\n",
315+
" }\n",
316+
" }\n",
317+
" }\n",
318+
"\n",
319+
" rec[time][p_rec] = sum;\n",
320+
" }\n",
321+
" }\n",
322+
" STOP(section2,timers)\n",
323+
" }\n",
324+
"\n",
325+
" return 0;\n",
326+
"}\n",
327+
"\n"
328+
]
329+
}
330+
],
331+
"source": [
332+
"#NBVAL_IGNORE_OUTPUT\n",
333+
"print(op)"
334+
]
335+
},
336+
{
337+
"cell_type": "code",
338+
"execution_count": 10,
339+
"metadata": {},
194340
"outputs": [
195341
{
196342
"name": "stderr",
@@ -203,14 +349,14 @@
203349
"data": {
204350
"text/plain": [
205351
"PerformanceSummary([(PerfKey(name='section0', rank=None),\n",
206-
" PerfEntry(time=0.01551399999999999, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n",
352+
" PerfEntry(time=0.016613000000000006, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n",
207353
" (PerfKey(name='section1', rank=None),\n",
208-
" PerfEntry(time=0.0023819999999999913, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n",
354+
" PerfEntry(time=0.002539999999999992, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n",
209355
" (PerfKey(name='section2', rank=None),\n",
210-
" PerfEntry(time=0.002333999999999994, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[]))])"
356+
" PerfEntry(time=0.0025529999999999923, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[]))])"
211357
]
212358
},
213-
"execution_count": 9,
359+
"execution_count": 10,
214360
"metadata": {},
215361
"output_type": "execute_result"
216362
}
@@ -222,7 +368,7 @@
222368
},
223369
{
224370
"cell_type": "code",
225-
"execution_count": 10,
371+
"execution_count": 11,
226372
"metadata": {},
227373
"outputs": [
228374
{
@@ -250,7 +396,7 @@
250396
},
251397
{
252398
"cell_type": "code",
253-
"execution_count": 11,
399+
"execution_count": 12,
254400
"metadata": {},
255401
"outputs": [],
256402
"source": [

tests/test_builtins.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -374,6 +374,23 @@ def test_inner_sparse(self):
374374
term2 = np.inner(rec0.data.reshape(-1), rec1.data.reshape(-1))
375375
assert np.isclose(term1/term2 - 1, 0.0, rtol=0.0, atol=1e-5)
376376

377+
@pytest.mark.parametrize('dtype', [np.float32, np.complex64])
378+
def test_norm_dense(self, dtype):
379+
"""
380+
Test that norm produces the correct result against NumPy
381+
"""
382+
grid = Grid((101, 101), extent=(1000., 1000.))
383+
384+
f = Function(name='f', grid=grid, dtype=dtype)
385+
386+
f.data[:] = 1 + np.random.randn(*f.shape).astype(grid.dtype)
387+
if np.iscomplexobj(f.data):
388+
f.data[:] += 1j*np.random.randn(*f.shape).astype(grid.dtype)
389+
term1 = np.linalg.norm(f.data)
390+
term2 = norm(f)
391+
assert np.isreal(term2)
392+
assert np.isclose(term1/term2 - 1, 0.0, rtol=0.0, atol=1e-5)
393+
377394
def test_norm_sparse(self):
378395
"""
379396
Test that norm produces the correct result against NumPy

tests/test_symbolics.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,19 @@ def test_modified_sympy_assumptions():
112112
assert s2 == s1
113113

114114

115+
def test_real():
116+
for dtype in [np.float32, np.complex64]:
117+
c = Constant(name='c', dtype=dtype)
118+
assert c.is_real is not np.iscomplexobj(dtype(0))
119+
assert c.is_imaginary is np.iscomplexobj(dtype(0))
120+
f = Function(name='f', dtype=dtype, grid=Grid((11,)))
121+
assert f.is_real is not np.iscomplexobj(dtype(0))
122+
assert f.is_imaginary is np.iscomplexobj(dtype(0))
123+
s = dSymbol(name='s', dtype=dtype)
124+
assert s.is_real is not np.iscomplexobj(dtype(0))
125+
assert s.is_imaginary is np.iscomplexobj(dtype(0))
126+
127+
115128
def test_constant():
116129
c = Constant(name='c')
117130

0 commit comments

Comments
 (0)