Skip to content

Commit f854b0c

Browse files
Merge pull request #2245 from devitocodes/funcs_on_subdims
dsl: Introduce ability to define Functions on Subdomains
2 parents 34dba05 + 3b8ec13 commit f854b0c

26 files changed

+5416
-572
lines changed

devito/data/data.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -238,11 +238,12 @@ def __getitem__(self, glb_idx, comm_type, gather_rank=None):
238238
stop += sendcounts[i]
239239
data_slice = recvbuf[slice(start, stop, step)]
240240
shape = [r.stop-r.start for r in self._distributor.all_ranges[i]]
241-
idx = [slice(r.start, r.stop, r.step)
242-
for r in self._distributor.all_ranges[i]]
243-
for i in range(len(self.shape) - len(self._distributor.glb_shape)):
244-
shape.insert(i, glb_shape[i])
245-
idx.insert(i, slice(0, glb_shape[i]+1, 1))
241+
idx = [slice(r.start - d.glb_min, r.stop - d.glb_min, r.step)
242+
for r, d in zip(self._distributor.all_ranges[i],
243+
self._distributor.decomposition)]
244+
for j in range(len(self.shape) - len(self._distributor.glb_shape)):
245+
shape.insert(j, glb_shape[j])
246+
idx.insert(j, slice(0, glb_shape[j]+1, 1))
246247
retval[tuple(idx)] = data_slice.reshape(tuple(shape))
247248
return retval
248249
else:
@@ -329,6 +330,7 @@ def __getitem__(self, glb_idx, comm_type, gather_rank=None):
329330
@_check_idx
330331
def __setitem__(self, glb_idx, val, comm_type):
331332
loc_idx = self._index_glb_to_loc(glb_idx)
333+
332334
if loc_idx is NONLOCAL:
333335
# no-op
334336
return

devito/data/decomposition.py

Lines changed: 49 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -204,9 +204,26 @@ def index_glb_to_loc(self, *args, rel=True):
204204
>>> d.index_glb_to_loc((1, 6), rel=False)
205205
(5, 6)
206206
"""
207+
# Offset the loc_abs_min, loc_abs_max, glb_min, and glb_max
208+
# In the case of a Function defined on a SubDomain, the global indices
209+
# for accessing the data associated with this Function will run from
210+
# `ltkn` to `x_M-rtkn-1`. However, the indices for accessing this array
211+
# will run from `0` to `x_M-ltkn-rtkn-1`. As such, the global minimum
212+
# (`ltkn`) should be subtracted for the purpose of indexing into the local
213+
# array.
214+
if not self.loc_empty:
215+
loc_abs_min = self.loc_abs_min - self.glb_min
216+
loc_abs_max = self.loc_abs_max - self.glb_min
217+
glb_max = self.glb_max - self.glb_min
218+
else:
219+
loc_abs_min = self.loc_abs_min
220+
loc_abs_max = self.loc_abs_max
221+
glb_max = self.glb_max
222+
223+
glb_min = 0
207224

208-
base = self.loc_abs_min if rel is True else 0
209-
top = self.loc_abs_max
225+
base = loc_abs_min if rel else 0
226+
top = loc_abs_max
210227

211228
if len(args) == 1:
212229
glb_idx = args[0]
@@ -217,11 +234,11 @@ def index_glb_to_loc(self, *args, rel=True):
217234
return None
218235
# -> Handle negative index
219236
if glb_idx < 0:
220-
glb_idx = self.glb_max + glb_idx + 1
237+
glb_idx = glb_max + glb_idx + 1
221238
# -> Do the actual conversion
222-
if self.loc_abs_min <= glb_idx <= self.loc_abs_max:
239+
if loc_abs_min <= glb_idx <= loc_abs_max:
223240
return glb_idx - base
224-
elif self.glb_min <= glb_idx <= self.glb_max:
241+
elif glb_min <= glb_idx <= glb_max:
225242
return None
226243
else:
227244
# This should raise an exception when used to access a numpy.array
@@ -239,30 +256,32 @@ def index_glb_to_loc(self, *args, rel=True):
239256
elif isinstance(glb_idx, slice):
240257
if self.loc_empty:
241258
return slice(-1, -3)
242-
if glb_idx.step >= 0 and glb_idx.stop == self.glb_min:
243-
glb_idx_min = self.glb_min if glb_idx.start is None \
259+
if glb_idx.step >= 0 and glb_idx.stop == glb_min:
260+
glb_idx_min = glb_min if glb_idx.start is None \
244261
else glb_idx.start
245-
glb_idx_max = self.glb_min
262+
glb_idx_max = glb_min
246263
retfunc = lambda a, b: slice(a, b, glb_idx.step)
247264
elif glb_idx.step >= 0:
248-
glb_idx_min = self.glb_min if glb_idx.start is None \
265+
glb_idx_min = glb_min if glb_idx.start is None \
249266
else glb_idx.start
250-
glb_idx_max = self.glb_max if glb_idx.stop is None \
267+
glb_idx_max = glb_max \
268+
if glb_idx.stop is None \
251269
else glb_idx.stop-1
252270
retfunc = lambda a, b: slice(a, b + 1, glb_idx.step)
253271
else:
254-
glb_idx_min = self.glb_min if glb_idx.stop is None \
272+
glb_idx_min = glb_min if glb_idx.stop is None \
255273
else glb_idx.stop+1
256-
glb_idx_max = self.glb_max if glb_idx.start is None \
274+
glb_idx_max = glb_max if glb_idx.start is None \
257275
else glb_idx.start
258276
retfunc = lambda a, b: slice(b, a - 1, glb_idx.step)
259277
else:
260278
raise TypeError("Cannot convert index from `%s`" % type(glb_idx))
261279
# -> Handle negative min/max
262280
if glb_idx_min is not None and glb_idx_min < 0:
263-
glb_idx_min = self.glb_max + glb_idx_min + 1
281+
glb_idx_min = glb_max + glb_idx_min + 1
264282
if glb_idx_max is not None and glb_idx_max < 0:
265-
glb_idx_max = self.glb_max + glb_idx_max + 1
283+
glb_idx_max = glb_max + glb_idx_max + 1
284+
266285
# -> Do the actual conversion
267286
# Compute loc_min. For a slice with step > 0 this will be
268287
# used to produce slice.start and for a slice with step < 0 slice.stop.
@@ -271,19 +290,19 @@ def index_glb_to_loc(self, *args, rel=True):
271290
# coincide with loc_abs_min.
272291
if isinstance(glb_idx, slice) and glb_idx.step is not None \
273292
and glb_idx.step > 1:
274-
if glb_idx_min > self.loc_abs_max:
293+
if glb_idx_min > loc_abs_max:
275294
return retfunc(-1, -3)
276295
elif glb_idx.start is None: # glb start is zero.
277-
loc_min = self.loc_abs_min - base \
296+
loc_min = loc_abs_min - base \
278297
+ np.mod(glb_idx.step - np.mod(base, glb_idx.step),
279298
glb_idx.step)
280299
else: # glb start is given explicitly
281-
loc_min = self.loc_abs_min - base \
300+
loc_min = loc_abs_min - base \
282301
+ np.mod(glb_idx.step - np.mod(base - glb_idx.start,
283302
glb_idx.step), glb_idx.step)
284-
elif glb_idx_min is None or glb_idx_min < self.loc_abs_min:
285-
loc_min = self.loc_abs_min - base
286-
elif glb_idx_min > self.loc_abs_max:
303+
elif glb_idx_min is None or glb_idx_min < loc_abs_min:
304+
loc_min = loc_abs_min - base
305+
elif glb_idx_min > loc_abs_max:
287306
return retfunc(-1, -3)
288307
else:
289308
loc_min = glb_idx_min - base
@@ -294,19 +313,19 @@ def index_glb_to_loc(self, *args, rel=True):
294313
# coincide with loc_abs_max.
295314
if isinstance(glb_idx, slice) and glb_idx.step is not None \
296315
and glb_idx.step < -1:
297-
if glb_idx_max < self.loc_abs_min:
316+
if glb_idx_max < loc_abs_min:
298317
return retfunc(-1, -3)
299318
elif glb_idx.start is None:
300319
loc_max = top - base \
301-
+ np.mod(glb_idx.step - np.mod(top - self.glb_max,
320+
+ np.mod(glb_idx.step - np.mod(top - glb_max,
302321
glb_idx.step), glb_idx.step)
303322
else:
304323
loc_max = top - base \
305324
+ np.mod(glb_idx.step - np.mod(top - glb_idx.start,
306325
glb_idx.step), glb_idx.step)
307-
elif glb_idx_max is None or glb_idx_max > self.loc_abs_max:
308-
loc_max = self.loc_abs_max - base
309-
elif glb_idx_max < self.loc_abs_min:
326+
elif glb_idx_max is None or glb_idx_max > loc_abs_max:
327+
loc_max = loc_abs_max - base
328+
elif glb_idx_max < loc_abs_min:
310329
return retfunc(-1, -3)
311330
else:
312331
loc_max = glb_idx_max - base
@@ -321,19 +340,19 @@ def index_glb_to_loc(self, *args, rel=True):
321340
return None
322341
abs_ofs, side = args
323342
if side == LEFT:
324-
rel_ofs = self.glb_min + abs_ofs - base
343+
rel_ofs = glb_min + abs_ofs - base
325344
if abs_ofs >= base and abs_ofs <= top:
326345
return rel_ofs
327346
elif abs_ofs > top:
328347
return top + 1
329348
else:
330349
return None
331350
else:
332-
rel_ofs = abs_ofs - (self.glb_max - top)
333-
if abs_ofs >= self.glb_max - top and abs_ofs <= self.glb_max - base:
351+
rel_ofs = abs_ofs - (glb_max - top)
352+
if abs_ofs >= glb_max - top and abs_ofs <= glb_max - base:
334353
return rel_ofs
335-
elif abs_ofs > self.glb_max - base:
336-
return self.glb_max - base + 1
354+
elif abs_ofs > glb_max - base:
355+
return glb_max - base + 1
337356
else:
338357
return None
339358
else:

devito/deprecations.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,23 @@ class DevitoDeprecation():
77
@cached_property
88
def coeff_warn(self):
99
warn("The Coefficient API is deprecated and will be removed, coefficients should"
10-
"be passed directly to the derivative object `u.dx(weights=...)",
10+
" be passed directly to the derivative object `u.dx(weights=...)",
1111
DeprecationWarning, stacklevel=2)
1212
return
1313

1414
@cached_property
1515
def symbolic_warn(self):
1616
warn("coefficients='symbolic' is deprecated, coefficients should"
17-
"be passed directly to the derivative object `u.dx(weights=...)",
17+
" be passed directly to the derivative object `u.dx(weights=...)",
18+
DeprecationWarning, stacklevel=2)
19+
return
20+
21+
@cached_property
22+
def subdomain_warn(self):
23+
warn("Passing `SubDomain`s to `Grid` on instantiation using `mygrid ="
24+
" Grid(..., subdomains=(mydomain, ...))` is deprecated. The `Grid`"
25+
" should instead be passed as a kwarg when instantiating a subdomain"
26+
" `mydomain = MyDomain(grid=mygrid)`",
1827
DeprecationWarning, stacklevel=2)
1928
return
2029

devito/finite_differences/differentiable.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ def time_order(self):
6464
@cached_property
6565
def grid(self):
6666
grids = {getattr(i, 'grid', None) for i in self._args_diff} - {None}
67+
grids = {g.root for g in grids}
6768
if len(grids) > 1:
6869
warning("Expression contains multiple grids, returning first found")
6970
try:
@@ -86,6 +87,11 @@ def dimensions(self):
8687
return tuple(filter_ordered(flatten(getattr(i, 'dimensions', ())
8788
for i in self._args_diff)))
8889

90+
@cached_property
91+
def root_dimensions(self):
92+
"""Tuple of root Dimensions of the physical space Dimensions."""
93+
return tuple(d.root for d in self.dimensions if d.is_Space)
94+
8995
@property
9096
def indices_ref(self):
9197
"""The reference indices of the object (indices at first creation)."""
@@ -317,7 +323,7 @@ def laplacian(self, shift=None, order=None, method='FD', **kwargs):
317323
"""
318324
w = kwargs.get('weights', kwargs.get('w'))
319325
order = order or self.space_order
320-
space_dims = [d for d in self.dimensions if d.is_Space]
326+
space_dims = self.root_dimensions
321327
shift_x0 = make_shift_x0(shift, (len(space_dims),))
322328
derivs = tuple('d%s2' % d.name for d in space_dims)
323329
return Add(*[getattr(self, d)(x0=shift_x0(shift, space_dims[i], None, i),
@@ -344,7 +350,7 @@ def div(self, shift=None, order=None, method='FD', **kwargs):
344350
Custom weights for the finite difference coefficients.
345351
"""
346352
w = kwargs.get('weights', kwargs.get('w'))
347-
space_dims = [d for d in self.dimensions if d.is_Space]
353+
space_dims = self.root_dimensions
348354
shift_x0 = make_shift_x0(shift, (len(space_dims),))
349355
order = order or self.space_order
350356
return Add(*[getattr(self, 'd%s' % d.name)(x0=shift_x0(shift, d, None, i),
@@ -371,7 +377,7 @@ def grad(self, shift=None, order=None, method='FD', **kwargs):
371377
Custom weights for the finite
372378
"""
373379
from devito.types.tensor import VectorFunction, VectorTimeFunction
374-
space_dims = [d for d in self.dimensions if d.is_Space]
380+
space_dims = self.root_dimensions
375381
shift_x0 = make_shift_x0(shift, (len(space_dims),))
376382
order = order or self.space_order
377383
w = kwargs.get('weights', kwargs.get('w'))
@@ -387,7 +393,7 @@ def biharmonic(self, weight=1):
387393
Generates a symbolic expression for the weighted biharmonic operator w.r.t.
388394
all spatial Dimensions Laplace(weight * Laplace (self))
389395
"""
390-
space_dims = [d for d in self.dimensions if d.is_Space]
396+
space_dims = self.root_dimensions
391397
derivs = tuple('d%s2' % d.name for d in space_dims)
392398
return Add(*[getattr(self.laplace * weight, d) for d in derivs])
393399

devito/ir/clusters/cluster.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,7 @@ def functions(self):
193193
def grid(self):
194194
grids = set(f.grid for f in self.functions if f.is_AbstractFunction)
195195
grids.discard(None)
196+
grids = {g.root for g in grids}
196197
if len(grids) == 0:
197198
return None
198199
elif len(grids) == 1:

devito/ir/equations/algorithms.py

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33

44
from devito.symbolics import (retrieve_indexed, uxreplace, retrieve_dimensions,
55
retrieve_functions)
6-
from devito.tools import Ordering, as_tuple, flatten, filter_sorted, filter_ordered
6+
from devito.tools import (Ordering, as_tuple, flatten, filter_sorted, filter_ordered,
7+
frozendict)
78
from devito.types import (Dimension, Eq, IgnoreDimSort, SubDimension,
89
ConditionalDimension)
910
from devito.types.array import Array
@@ -112,11 +113,7 @@ def lower_exprs(expressions, subs=None, **kwargs):
112113
def _lower_exprs(expressions, subs):
113114
processed = []
114115
for expr in as_tuple(expressions):
115-
try:
116-
dimension_map = expr.subdomain.dimension_map
117-
except AttributeError:
118-
# Some Relationals may be pure SymPy objects, thus lacking the subdomain
119-
dimension_map = {}
116+
dimension_map = _make_dimension_map(expr)
120117

121118
# Handle Functions (typical case)
122119
mapper = {f: _lower_exprs(f.indexify(subs=dimension_map), subs)
@@ -160,6 +157,30 @@ def _lower_exprs(expressions, subs):
160157
return processed.pop()
161158

162159

160+
def _make_dimension_map(expr):
161+
"""
162+
Make the dimension_map for an expression. In the basic case, this is extracted
163+
directly from the SubDomain attached to the expression.
164+
165+
The indices of a Function defined on a SubDomain will all be the SubDimensions of
166+
that SubDomain. In this case, the dimension_map should be extended with
167+
`{ix_f: ix_i, iy_f: iy_i}` where `ix_f` is the SubDimension on which the Function is
168+
defined, and `ix_i` is the SubDimension to be iterated over.
169+
"""
170+
try:
171+
dimension_map = {**expr.subdomain.dimension_map}
172+
except AttributeError:
173+
# Some Relationals may be pure SymPy objects, thus lacking the SubDomain
174+
dimension_map = {}
175+
else:
176+
functions = [f for f in retrieve_functions(expr) if f._is_on_subdomain]
177+
for f in functions:
178+
dimension_map.update({d: expr.subdomain.dimension_map[d.root]
179+
for d in f.space_dimensions if d.is_Sub})
180+
181+
return frozendict(dimension_map)
182+
183+
163184
def concretize_subdims(exprs, **kwargs):
164185
"""
165186
Given a list of expressions, return a new list where all user-defined
@@ -206,7 +227,14 @@ def _(v, mapper, rebuilt, sregistry):
206227

207228
@_concretize_subdims.register(Eq)
208229
def _(expr, mapper, rebuilt, sregistry):
209-
for d in expr.free_symbols:
230+
# Split and reorder symbols so SubDimensions are processed before lone Thicknesses
231+
# This means that if a Thickness appears both in the expression and attached to
232+
# a SubDimension, it gets concretised with the SubDimension.
233+
thicknesses = {i for i in expr.free_symbols if isinstance(i, Thickness)}
234+
symbols = expr.free_symbols.difference(thicknesses)
235+
236+
# Iterate over all other symbols before iterating over standalone thicknesses
237+
for d in tuple(symbols) + tuple(thicknesses):
210238
_concretize_subdims(d, mapper, rebuilt, sregistry)
211239

212240
# Subdimensions can be hiding in implicit dims

0 commit comments

Comments
 (0)