Skip to content

Commit a00cca8

Browse files
authored
Merge pull request #2782 from devitocodes/fd-eval-add
api: fix staggered evaluation for add/mul
2 parents d88639b + 26490d4 commit a00cca8

File tree

20 files changed

+280
-168
lines changed

20 files changed

+280
-168
lines changed

devito/arch/archinfo.py

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -619,18 +619,12 @@ def get_m1_llvm_path(language):
619619

620620
@memoized_func
621621
def check_cuda_runtime():
622-
libnames = ('libcudart.so', 'libcudart.dylib', 'cudart.dll')
623-
for libname in libnames:
624-
try:
625-
cuda = ctypes.CDLL(libname)
626-
except OSError:
627-
continue
628-
else:
629-
break
630-
else:
622+
libname = ctypes.util.find_library("cudart")
623+
if not libname:
631624
warning("Unable to check compatibility of NVidia driver and runtime")
632625
return
633626

627+
cuda = ctypes.CDLL(libname)
634628
driver_version = ctypes.c_int()
635629
runtime_version = ctypes.c_int()
636630

@@ -1069,6 +1063,32 @@ def march(self):
10691063
return 'tesla'
10701064
return None
10711065

1066+
@cached_property
1067+
def max_shm_per_block(self):
1068+
"""
1069+
Get the maximum amount of shared memory per thread block
1070+
"""
1071+
# Load libcudart
1072+
libname = ctypes.util.find_library("cudart")
1073+
if not libname:
1074+
return 64 * 1024 # 64 KB default
1075+
lib = ctypes.CDLL(libname)
1076+
1077+
cudaDevAttrMaxSharedMemoryPerBlockOptin = 97
1078+
# get current device
1079+
dev = ctypes.c_int()
1080+
lib.cudaGetDevice(ctypes.byref(dev))
1081+
1082+
# query attribute
1083+
value = ctypes.c_int()
1084+
lib.cudaDeviceGetAttribute(
1085+
ctypes.byref(value),
1086+
ctypes.c_int(cudaDevAttrMaxSharedMemoryPerBlockOptin),
1087+
dev
1088+
)
1089+
1090+
return value.value
1091+
10721092
def supports(self, query, language=None):
10731093
if language != 'cuda':
10741094
return False
@@ -1125,6 +1145,8 @@ class AmdDevice(Device):
11251145

11261146
max_mem_trans_nbytes = 256
11271147

1148+
max_shm_per_block = 64*1024 # 64 KB
1149+
11281150
@cached_property
11291151
def march(cls):
11301152
# TODO: this corresponds to Vega, which acts as the fallback `march`

devito/finite_differences/derivative.py

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -335,20 +335,6 @@ def __call__(self, x0=None, fd_order=None, side=None, method=None, **kwargs):
335335
except AttributeError:
336336
raise TypeError("fd_order incompatible with dimensions") from None
337337

338-
if isinstance(self.expr, Derivative):
339-
# In case this was called on a perfect cross-derivative `u.dxdy`
340-
# we need to propagate the call to the nested derivative
341-
rkwe = dict(rkw)
342-
rkwe.pop('weights', None)
343-
if 'x0' in rkwe:
344-
rkwe['x0'] = self._filter_dims(self.expr._filter_dims(rkw['x0']),
345-
neg=True)
346-
if fd_order is not None:
347-
fdo = self.expr._filter_dims(_fd_order)
348-
if fdo:
349-
rkwe['fd_order'] = fdo
350-
rkw['expr'] = self.expr(**rkwe)
351-
352338
if fd_order is not None:
353339
rkw['fd_order'] = self._filter_dims(_fd_order, as_tuple=True)
354340

@@ -530,9 +516,12 @@ def _eval_at(self, func):
530516
# it into `u(x + h_x/2).dx` and `v(x).dx`, since they require
531517
# different FD indices
532518
mapper = as_mapper(self.expr._args_diff, lambda i: i.staggered)
519+
if len(mapper) == 1:
520+
# All terms have the same staggering, we can use expr as is
521+
return self._rebuild(self.expr, **rkw)
533522
args = [self.expr.func(*v) for v in mapper.values()]
534523
args.extend([a for a in self.expr.args if a not in self.expr._args_diff])
535-
args = [self._rebuild(a, **rkw) for a in args]
524+
args = [self._rebuild(a)._eval_at(func) for a in args]
536525
return self.expr.func(*args)
537526
elif self.expr.is_Mul:
538527
# For Mul, We treat the basic case `u(x + h_x/2) * v(x) which is what appear

devito/finite_differences/differentiable.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,12 @@ def __rfloordiv__(self, other):
259259
from .elementary import floor
260260
return floor(other / self)
261261

262+
def _inv(self, ref, safe=False):
263+
if safe:
264+
return SafeInv(self, ref or self)
265+
else:
266+
return 1 / self
267+
262268
def __mod__(self, other):
263269
return Mod(self, other)
264270

@@ -964,7 +970,11 @@ def _subs(self, old, new, **hints):
964970

965971

966972
class DiffDerivative(IndexDerivative, DifferentiableOp):
967-
pass
973+
974+
def _eval_at(self, func):
975+
# Like EvalDerivative, a DiffDerivative must have already been evaluated
976+
# at a valid x0 and should not be re-evaluated at a different location
977+
return self
968978

969979

970980
# SymPy args ordering is the same for Derivatives and IndexDerivatives
@@ -1012,6 +1022,11 @@ def _new_rawargs(self, *args, **kwargs):
10121022
kwargs.pop('is_commutative', None)
10131023
return self.func(*args, **kwargs)
10141024

1025+
def _eval_at(self, func):
1026+
# An EvalDerivative must have already been evaluated at a valid x0
1027+
# and should not be re-evaluated at a different location
1028+
return self
1029+
10151030

10161031
class diffify:
10171032

@@ -1145,13 +1160,8 @@ def _(expr, x0, **kwargs):
11451160

11461161
@interp_for_fd.register(AbstractFunction)
11471162
def _(expr, x0, **kwargs):
1148-
from devito.finite_differences.derivative import Derivative
1149-
x0_expr = {d: v for d, v in x0.items() if v is not expr.indices_ref[d]}
1150-
if expr.is_parameter:
1151-
return expr
1152-
elif x0_expr:
1153-
dims = tuple((d, 0) for d in x0_expr)
1154-
fd_o = tuple([expr.interp_order]*len(dims))
1155-
return Derivative(expr, *dims, fd_order=fd_o, x0=x0_expr)
1163+
x0_expr = {d: v for d, v in x0.items() if v.has(d)}
1164+
if x0_expr:
1165+
return expr.subs({expr.indices[d]: v for d, v in x0_expr.items()})
11561166
else:
11571167
return expr

devito/passes/clusters/implicit.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -175,13 +175,17 @@ def callback(self, clusters, prefix):
175175

176176
# Make sure the "implicit expressions" are scheduled in
177177
# the innermost loop such that the thicknesses can be computed
178-
edims = set(retrieve_dimensions(mapper.values(), deep=True))
179-
if dim not in edims or not edims.issubset(prefix.dimensions):
178+
def key(tkn):
179+
edims = set(retrieve_dimensions(tkn, deep=True))
180+
return dim._defines & edims and edims.issubset(prefix.dimensions)
181+
182+
mapper = {k: v for k, v in mapper.items() if key(v)}
183+
if not mapper:
180184
continue
181185

182186
found[d.functions].clusters.append(c)
183187
found[d.functions].mapper = reduce(found[d.functions].mapper,
184-
mapper, edims, prefix)
188+
mapper, {dim}, prefix)
185189

186190
# Turn the reduced mapper into a list of equations
187191
processed = []
@@ -259,7 +263,11 @@ def reduce(m0, m1, edims, prefix):
259263
else:
260264
func = min
261265

262-
key = lambda i: i.indices[d]
266+
def key(i):
267+
try:
268+
return i.indices[d]
269+
except (KeyError, AttributeError):
270+
return i
263271

264272
mapper = {}
265273
for k, e in m1.items():

devito/passes/iet/misc.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66

77
from devito.finite_differences import Max, Min
88
from devito.finite_differences.differentiable import SafeInv
9-
from devito.logger import warning
109
from devito.ir import (Any, Forward, DummyExpr, Iteration, EmptyList, Prodder,
1110
FindApplications, FindNodes, FindSymbols, Transformer,
1211
Uxreplace, filter_iterations, retrieve_iteration_tree,
@@ -155,7 +154,7 @@ def _generate_macros(iet, tracker=None, langbb=None, printer=CPrinter, **kwargs)
155154
for define, expr in headers)
156155

157156
# Generate Macros from higher-level SymPy objects
158-
mheaders, includes = _generate_macros_math(iet, langbb=langbb)
157+
mheaders, includes = _generate_macros_math(iet, langbb=langbb, printer=printer)
159158
includes = sorted(includes, key=str)
160159
headers.extend(sorted(mheaders, key=str))
161160

@@ -199,25 +198,25 @@ def _generate_macros_findexeds(iet, sregistry=None, tracker=None, **kwargs):
199198
return iet
200199

201200

202-
def _generate_macros_math(iet, langbb=None):
201+
def _generate_macros_math(iet, langbb=None, printer=CPrinter):
203202
headers = []
204203
includes = []
205204
for i in FindApplications().visit(iet):
206-
header, include = _lower_macro_math(i, langbb)
205+
header, include = _lower_macro_math(i, langbb, printer)
207206
headers.extend(header)
208207
includes.extend(include)
209208

210209
return headers, set(includes) - {None}
211210

212211

213212
@singledispatch
214-
def _lower_macro_math(expr, langbb):
213+
def _lower_macro_math(expr, langbb, printer):
215214
return (), {}
216215

217216

218217
@_lower_macro_math.register(Min)
219218
@_lower_macro_math.register(sympy.Min)
220-
def _(expr, langbb):
219+
def _(expr, langbb, printer):
221220
if has_integer_args(*expr.args):
222221
return (('MIN(a,b)', ('(((a) < (b)) ? (a) : (b))')),), {}
223222
else:
@@ -226,23 +225,28 @@ def _(expr, langbb):
226225

227226
@_lower_macro_math.register(Max)
228227
@_lower_macro_math.register(sympy.Max)
229-
def _(expr, langbb):
228+
def _(expr, langbb, printer):
230229
if has_integer_args(*expr.args):
231230
return (('MAX(a,b)', ('(((a) > (b)) ? (a) : (b))')),), {}
232231
else:
233232
return (), as_tuple(langbb.get('header-math'))
234233

235234

236235
@_lower_macro_math.register(SafeInv)
237-
def _(expr, langbb):
236+
def _(expr, langbb, printer):
238237
try:
239-
eps = np.finfo(expr.base.dtype).resolution**2
240-
except ValueError:
241-
warning(f"dtype not recognized in SafeInv for {expr.base}, assuming float32")
242-
eps = np.finfo(np.float32).resolution**2
243-
b = Cast('b', dtype=np.float32)
238+
dtype = expr.base.dtype
239+
eps = np.finfo(dtype).resolution**2
240+
except (AttributeError, ValueError):
241+
dtype = np.float32
242+
eps = np.finfo(dtype).resolution**2
243+
244+
b = printer()._print(Cast('b', dtype=dtype))
245+
ext = 'F' if dtype is np.float32 else ''
246+
244247
return (('SAFEINV(a, b)',
245-
f'(((a) < {eps}F || ({b}) < {eps}F) ? (0.0F) : ((1.0F) / (a)))'),), {}
248+
f'(((a) < {eps}{ext} || ({b}) < {eps}{ext}) ? '
249+
f'(0.0{ext}) : ((1.0{ext}) / (a)))'),), {}
246250

247251

248252
@iet_pass

devito/passes/iet/parpragma.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,9 @@ def _make_parregion(self, partree, parrays):
317317
i = n.write
318318
if not (i.is_Array or i.is_TempFunction):
319319
continue
320+
elif partree.dim in i.dimensions:
321+
# Non-local Array (full iteration space): no need to vector-expand
322+
continue
320323
elif i in parrays:
321324
pi = parrays[i]
322325
else:

devito/symbolics/inspection.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from sympy import (Function, Indexed, Integer, Mul, Number,
55
Pow, S, Symbol, Tuple)
66
from sympy.core.numbers import ImaginaryUnit
7+
from sympy.core.function import Application
78

89
from devito.finite_differences import Derivative
910
from devito.finite_differences.differentiable import IndexDerivative
@@ -12,7 +13,7 @@
1213
from devito.symbolics.extended_sympy import (CallFromPointer, Cast,
1314
DefFunction, ReservedWord)
1415
from devito.symbolics.queries import q_routine
15-
from devito.tools import as_tuple, prod
16+
from devito.tools import as_tuple, prod, is_integer
1617
from devito.tools.dtypes_lowering import infer_dtype
1718

1819
__all__ = ['compare_ops', 'estimate_cost', 'has_integer_args', 'sympy_dtype']
@@ -116,7 +117,7 @@ def estimate_cost(exprs, estimate=False):
116117
estimate_values = {
117118
'elementary': 100,
118119
'pow': 50,
119-
'SafeInv': 10,
120+
'SafeInv': 50,
120121
'div': 5,
121122
'Abs': 5,
122123
'floor': 1,
@@ -211,6 +212,7 @@ def _(expr, estimate, seen):
211212

212213

213214
@_estimate_cost.register(Function)
215+
@_estimate_cost.register(Application)
214216
def _(expr, estimate, seen):
215217
if q_routine(expr):
216218
flops, _ = zip(*[_estimate_cost(a, estimate, seen) for a in expr.args])
@@ -227,6 +229,7 @@ def _(expr, estimate, seen):
227229
flops += 1
228230
else:
229231
flops = 0
232+
230233
return flops, False
231234

232235

@@ -284,12 +287,14 @@ def has_integer_args(*args):
284287
try:
285288
return np.issubdtype(args[0].dtype, np.integer)
286289
except AttributeError:
287-
return args[0].is_integer
290+
return is_integer(args[0])
288291

289292
res = True
290293
for a in args:
291294
try:
292-
if isinstance(a, INT):
295+
if isinstance(a, INT) or \
296+
is_integer(a) or \
297+
has_integer_args(a):
293298
res = res and True
294299
elif len(a.args) > 0:
295300
res = res and has_integer_args(*a.args)

0 commit comments

Comments
 (0)