Skip to content

Commit 5b93b20

Browse files
committed
compiler: prevent hosted per-thread arrays are dereferenced within partree at read
1 parent 7c4f88e commit 5b93b20

File tree

8 files changed

+42
-17
lines changed

8 files changed

+42
-17
lines changed

devito/finite_differences/differentiable.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,12 @@ def __rfloordiv__(self, other):
259259
from .elementary import floor
260260
return floor(other / self)
261261

262+
def safe_inv(self, ref, safe=False):
263+
if safe:
264+
return SafeInv(self, ref or self)
265+
else:
266+
return 1 / self
267+
262268
def __mod__(self, other):
263269
return Mod(self, other)
264270

devito/passes/clusters/implicit.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,11 @@ def reduce(m0, m1, edims, prefix):
259259
else:
260260
func = min
261261

262-
key = lambda i: i.indices[d]
262+
def key(i):
263+
try:
264+
return i.indices[d]
265+
except AttributeError:
266+
return i
263267

264268
mapper = {}
265269
for k, e in m1.items():

devito/passes/iet/parpragma.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,9 @@ def _make_parregion(self, partree, parrays):
317317
i = n.write
318318
if not (i.is_Array or i.is_TempFunction):
319319
continue
320+
elif partree.dim in i.dimensions:
321+
# Non-local Array (full iteration space): no need to vector-expand
322+
continue
320323
elif i in parrays:
321324
pi = parrays[i]
322325
else:

devito/symbolics/inspection.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from sympy import (Function, Indexed, Integer, Mul, Number,
55
Pow, S, Symbol, Tuple)
66
from sympy.core.numbers import ImaginaryUnit
7+
from sympy.core.function import Application
78

89
from devito.finite_differences import Derivative
910
from devito.finite_differences.differentiable import IndexDerivative
@@ -116,7 +117,7 @@ def estimate_cost(exprs, estimate=False):
116117
estimate_values = {
117118
'elementary': 100,
118119
'pow': 50,
119-
'SafeInv': 10,
120+
'SafeInv': 50,
120121
'div': 5,
121122
'Abs': 5,
122123
'floor': 1,
@@ -211,6 +212,7 @@ def _(expr, estimate, seen):
211212

212213

213214
@_estimate_cost.register(Function)
215+
@_estimate_cost.register(Application)
214216
def _(expr, estimate, seen):
215217
if q_routine(expr):
216218
flops, _ = zip(*[_estimate_cost(a, estimate, seen) for a in expr.args])
@@ -227,6 +229,7 @@ def _(expr, estimate, seen):
227229
flops += 1
228230
else:
229231
flops = 0
232+
230233
return flops, False
231234

232235

devito/types/basic.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -844,7 +844,7 @@ def __init_finalize__(self, *args, **kwargs):
844844

845845
# Averaging mode for off the grid evaluation
846846
self._avg_mode = kwargs.get('avg_mode', 'arithmetic')
847-
if self._avg_mode not in ['arithmetic', 'harmonic']:
847+
if self._avg_mode not in ['arithmetic', 'harmonic', 'safe_harmonic']:
848848
raise ValueError("Invalid averaging mode_mode %s, accepted values are"
849849
" arithmetic or harmonic" % self._avg_mode)
850850

@@ -878,8 +878,8 @@ def __halo_setup__(self, **kwargs):
878878
halo = tuple(kwargs.get('halo', ((0, 0),)*self.ndim))
879879
return DimensionTuple(*halo, getters=self.dimensions)
880880

881-
def __padding_setup__(self, **kwargs):
882-
padding = tuple(kwargs.get('padding', ((0, 0),)*self.ndim))
881+
def __padding_setup__(self, padding=None, **kwargs):
882+
padding = tuple(padding or ((0, 0),)*self.ndim)
883883
return DimensionTuple(*padding, getters=self.dimensions)
884884

885885
@cached_property
@@ -984,7 +984,7 @@ def c0(self):
984984
def _eval_deriv(self):
985985
return self
986986

987-
@cached_property
987+
@property
988988
def _grid_map(self):
989989
"""
990990
Mapper of off-grid interpolation points indices for each dimension.
@@ -1044,14 +1044,13 @@ def _evaluate(self, **kwargs):
10441044
return self
10451045

10461046
io = self.interp_order
1047-
if self._avg_mode == 'harmonic':
1048-
retval = 1 / self
1049-
else:
1050-
retval = self
1047+
retval = self.subs({i.subs(subs): self.indices_ref[d]
1048+
for d, i in mapper.items()})
1049+
if 'harmonic' in self._avg_mode:
1050+
retval = retval.safe_inv(retval, safe='safe' in self._avg_mode)
10511051

10521052
# Apply interpolation from inner most dim
10531053
for d, i in mapper.items():
1054-
retval = retval._subs(i.subs(subs), self.indices_ref[d])
10551054
retval = retval.diff(d, deriv_order=0, fd_order=io, x0={d: i})
10561055

10571056
# Evaluate. Since we used `self.function` it will be on the grid when
@@ -1060,9 +1059,9 @@ def _evaluate(self, **kwargs):
10601059
retval = retval.subs(subs)
10611060

10621061
# If harmonic averaging, invert at the end
1063-
if self._avg_mode == 'harmonic':
1064-
from devito.finite_differences.differentiable import SafeInv
1065-
retval = SafeInv(retval, self.function.subs(subs))
1062+
if 'harmonic' in self._avg_mode:
1063+
retval = retval.safe_inv(self.function.subs(subs),
1064+
safe='safe' in self._avg_mode)
10661065

10671066
return retval
10681067

devito/types/dense.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1537,6 +1537,15 @@ def _time_buffering(self):
15371537
def _time_buffering_default(self):
15381538
return self._time_buffering and not isinstance(self.save, Buffer)
15391539

1540+
def _evaluate(self, **kwargs):
1541+
retval = super()._evaluate(**kwargs)
1542+
if not self._time_buffering and not retval.is_Function:
1543+
# Saved TimeFunction might need streaming, expand interpolations
1544+
# for easier processing.
1545+
return retval.evaluate
1546+
else:
1547+
return retval
1548+
15401549
def _arg_check(self, args, intervals, **kwargs):
15411550
super()._arg_check(args, intervals, **kwargs)
15421551

examples/seismic/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,7 @@ def _initialize_physics(self, vp, space_order, **kwargs):
308308
vs = kwargs.pop('vs')
309309
self.lam = self._gen_phys_param((vp**2 - 2. * vs**2)/b, 'lam', space_order)
310310
self.mu = self._gen_phys_param(vs**2 / b, 'mu', space_order,
311-
avg_mode='harmonic')
311+
avg_mode='safe_harmonic')
312312
else:
313313
# All other seismic models have at least a velocity
314314
self.vp = self._gen_phys_param(vp, 'vp', space_order)

tests/test_differentiable.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ def test_avg_mode(ndim, io):
119119

120120
a0 = Function(name="a0", grid=grid, **kw)
121121
a = Function(name="a", grid=grid, **kw)
122-
b = Function(name="b", grid=grid, avg_mode='harmonic', **kw)
122+
b = Function(name="b", grid=grid, avg_mode='safe_harmonic', **kw)
123123

124124
a0_avg = a0._eval_at(v)
125125
a_avg = a._eval_at(v).evaluate.simplify()
@@ -141,7 +141,8 @@ def test_avg_mode(ndim, io):
141141
assert sympy.simplify(a_avg - expected) == 0
142142

143143
# Harmonic average, h(a[.5]) = 1/(.5/a[0] + .5/a[1])
144-
expected = (sum(c / b.subs(arg) for c, arg in zip(ndcoeffs.flatten(), args)))
144+
expected = (sum(c * SafeInv(b.subs(arg), b.subs(arg))
145+
for c, arg in zip(ndcoeffs.flatten(), args)))
145146
assert sympy.simplify(b_avg.args[0] - expected) == 0
146147
assert isinstance(b_avg, SafeInv)
147148
assert b_avg.base == b

0 commit comments

Comments
 (0)