Skip to content

Commit 55f84b2

Browse files
committed
compiler: patch multi-subdim isapce handling
1 parent 5b93b20 commit 55f84b2

File tree

6 files changed

+45
-28
lines changed

6 files changed

+45
-28
lines changed

devito/finite_differences/differentiable.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,7 @@ def __rfloordiv__(self, other):
259259
from .elementary import floor
260260
return floor(other / self)
261261

262-
def safe_inv(self, ref, safe=False):
262+
def _inv(self, ref, safe=False):
263263
if safe:
264264
return SafeInv(self, ref or self)
265265
else:

devito/passes/clusters/implicit.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -175,13 +175,17 @@ def callback(self, clusters, prefix):
175175

176176
# Make sure the "implicit expressions" are scheduled in
177177
# the innermost loop such that the thicknesses can be computed
178-
edims = set(retrieve_dimensions(mapper.values(), deep=True))
179-
if dim not in edims or not edims.issubset(prefix.dimensions):
178+
def key(tkn):
179+
edims = set(retrieve_dimensions(tkn, deep=True))
180+
return dim._defines & edims and edims.issubset(prefix.dimensions)
181+
182+
mapper = {k: v for k, v in mapper.items() if key(v)}
183+
if not mapper:
180184
continue
181185

182186
found[d.functions].clusters.append(c)
183187
found[d.functions].mapper = reduce(found[d.functions].mapper,
184-
mapper, edims, prefix)
188+
mapper, {dim}, prefix)
185189

186190
# Turn the reduced mapper into a list of equations
187191
processed = []
@@ -262,7 +266,7 @@ def reduce(m0, m1, edims, prefix):
262266
def key(i):
263267
try:
264268
return i.indices[d]
265-
except AttributeError:
269+
except (KeyError, AttributeError):
266270
return i
267271

268272
mapper = {}

devito/passes/iet/misc.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66

77
from devito.finite_differences import Max, Min
88
from devito.finite_differences.differentiable import SafeInv
9-
from devito.logger import warning
109
from devito.ir import (Any, Forward, DummyExpr, Iteration, EmptyList, Prodder,
1110
FindApplications, FindNodes, FindSymbols, Transformer,
1211
Uxreplace, filter_iterations, retrieve_iteration_tree,
@@ -155,7 +154,7 @@ def _generate_macros(iet, tracker=None, langbb=None, printer=CPrinter, **kwargs)
155154
for define, expr in headers)
156155

157156
# Generate Macros from higher-level SymPy objects
158-
mheaders, includes = _generate_macros_math(iet, langbb=langbb)
157+
mheaders, includes = _generate_macros_math(iet, langbb=langbb, printer=printer)
159158
includes = sorted(includes, key=str)
160159
headers.extend(sorted(mheaders, key=str))
161160

@@ -199,25 +198,25 @@ def _generate_macros_findexeds(iet, sregistry=None, tracker=None, **kwargs):
199198
return iet
200199

201200

202-
def _generate_macros_math(iet, langbb=None):
201+
def _generate_macros_math(iet, langbb=None, printer=CPrinter):
203202
headers = []
204203
includes = []
205204
for i in FindApplications().visit(iet):
206-
header, include = _lower_macro_math(i, langbb)
205+
header, include = _lower_macro_math(i, langbb, printer)
207206
headers.extend(header)
208207
includes.extend(include)
209208

210209
return headers, set(includes) - {None}
211210

212211

213212
@singledispatch
214-
def _lower_macro_math(expr, langbb):
213+
def _lower_macro_math(expr, langbb, printer):
215214
return (), {}
216215

217216

218217
@_lower_macro_math.register(Min)
219218
@_lower_macro_math.register(sympy.Min)
220-
def _(expr, langbb):
219+
def _(expr, langbb, printer):
221220
if has_integer_args(*expr.args):
222221
return (('MIN(a,b)', ('(((a) < (b)) ? (a) : (b))')),), {}
223222
else:
@@ -226,23 +225,28 @@ def _(expr, langbb):
226225

227226
@_lower_macro_math.register(Max)
228227
@_lower_macro_math.register(sympy.Max)
229-
def _(expr, langbb):
228+
def _(expr, langbb, printer):
230229
if has_integer_args(*expr.args):
231230
return (('MAX(a,b)', ('(((a) > (b)) ? (a) : (b))')),), {}
232231
else:
233232
return (), as_tuple(langbb.get('header-math'))
234233

235234

236235
@_lower_macro_math.register(SafeInv)
237-
def _(expr, langbb):
236+
def _(expr, langbb, printer):
238237
try:
239-
eps = np.finfo(expr.base.dtype).resolution**2
240-
except ValueError:
241-
warning(f"dtype not recognized in SafeInv for {expr.base}, assuming float32")
242-
eps = np.finfo(np.float32).resolution**2
243-
b = Cast('b', dtype=np.float32)
238+
dtype = expr.base.dtype
239+
eps = np.finfo(dtype).resolution**2
240+
except (AttributeError, ValueError):
241+
dtype = np.float32
242+
eps = np.finfo(dtype).resolution**2
243+
244+
b = printer()._print(Cast('b', dtype=dtype))
245+
ext = 'F' if dtype is np.float32 else ''
246+
244247
return (('SAFEINV(a, b)',
245-
f'(((a) < {eps}F || ({b}) < {eps}F) ? (0.0F) : ((1.0F) / (a)))'),), {}
248+
f'(((a) < {eps}{ext} || ({b}) < {eps}{ext}) ? '
249+
f'(0.0{ext}) : ((1.0{ext}) / (a)))'),), {}
246250

247251

248252
@iet_pass

devito/symbolics/inspection.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from devito.symbolics.extended_sympy import (CallFromPointer, Cast,
1414
DefFunction, ReservedWord)
1515
from devito.symbolics.queries import q_routine
16-
from devito.tools import as_tuple, prod
16+
from devito.tools import as_tuple, prod, is_integer
1717
from devito.tools.dtypes_lowering import infer_dtype
1818

1919
__all__ = ['compare_ops', 'estimate_cost', 'has_integer_args', 'sympy_dtype']
@@ -287,12 +287,14 @@ def has_integer_args(*args):
287287
try:
288288
return np.issubdtype(args[0].dtype, np.integer)
289289
except AttributeError:
290-
return args[0].is_integer
290+
return is_integer(args[0])
291291

292292
res = True
293293
for a in args:
294294
try:
295-
if isinstance(a, INT):
295+
if isinstance(a, INT) or \
296+
is_integer(a) or \
297+
has_integer_args(a):
296298
res = res and True
297299
elif len(a.args) > 0:
298300
res = res and has_integer_args(*a.args)

devito/types/basic.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1028,6 +1028,14 @@ def _grid_map(self):
10281028

10291029
return mapper
10301030

1031+
@property
1032+
def is_harmonic(self):
1033+
return self.avg_mode == 'harmonic' or self.avg_mode == 'safe_harmonic'
1034+
1035+
@property
1036+
def is_harmonic_safe(self):
1037+
return self.avg_mode == 'safe_harmonic'
1038+
10311039
def _evaluate(self, **kwargs):
10321040
"""
10331041
Evaluate off the grid with 2nd order interpolation.
@@ -1046,8 +1054,8 @@ def _evaluate(self, **kwargs):
10461054
io = self.interp_order
10471055
retval = self.subs({i.subs(subs): self.indices_ref[d]
10481056
for d, i in mapper.items()})
1049-
if 'harmonic' in self._avg_mode:
1050-
retval = retval.safe_inv(retval, safe='safe' in self._avg_mode)
1057+
if self.is_harmonic:
1058+
retval = retval._inv(retval, safe=self.is_harmonic_safe)
10511059

10521060
# Apply interpolation from inner most dim
10531061
for d, i in mapper.items():
@@ -1059,9 +1067,8 @@ def _evaluate(self, **kwargs):
10591067
retval = retval.subs(subs)
10601068

10611069
# If harmonic averaging, invert at the end
1062-
if 'harmonic' in self._avg_mode:
1063-
retval = retval.safe_inv(self.function.subs(subs),
1064-
safe='safe' in self._avg_mode)
1070+
if self.is_harmonic:
1071+
retval = retval._inv(self.function.subs(subs), safe=self.is_harmonic_safe)
10651072

10661073
return retval
10671074

devito/types/dense.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1541,7 +1541,7 @@ def _evaluate(self, **kwargs):
15411541
retval = super()._evaluate(**kwargs)
15421542
if not self._time_buffering and not retval.is_Function:
15431543
# Saved TimeFunction might need streaming, expand interpolations
1544-
# for easier processing.
1544+
# for easier processing
15451545
return retval.evaluate
15461546
else:
15471547
return retval

0 commit comments

Comments
 (0)