Skip to content

Commit b3408d8

Browse files
authored
Merge pull request #2629 from devitocodes/fix-mpi-hoisting-2
compiler: Revamp MPI hoisting and merging
2 parents 2a0076f + b3a3e00 commit b3408d8

File tree

7 files changed

+558
-142
lines changed

7 files changed

+558
-142
lines changed

devito/ir/iet/nodes.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1516,8 +1516,8 @@ def arguments(self):
15161516
return self.halo_scheme.arguments
15171517

15181518
@property
1519-
def is_empty(self):
1520-
return len(self.halo_scheme) == 0
1519+
def is_void(self):
    """
    Whether the attached halo scheme reports itself as void.

    This simply forwards the `is_void` attribute of `self.halo_scheme`.
    """
    scheme = self.halo_scheme
    return scheme.is_void
15211521

15221522
@property
15231523
def body(self):

devito/ir/iet/visitors.py

Lines changed: 47 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,10 @@
2828
IndexedData, DeviceMap)
2929

3030

31-
__all__ = ['FindApplications', 'FindNodes', 'FindSections', 'FindSymbols',
32-
'MapExprStmts', 'MapHaloSpots', 'MapNodes', 'IsPerfectIteration',
33-
'printAST', 'CGen', 'CInterface', 'Transformer', 'Uxreplace']
31+
__all__ = ['FindApplications', 'FindNodes', 'FindWithin', 'FindSections',
32+
'FindSymbols', 'MapExprStmts', 'MapHaloSpots', 'MapNodes',
33+
'IsPerfectIteration', 'printAST', 'CGen', 'CInterface', 'Transformer',
34+
'Uxreplace']
3435

3536

3637
class Visitor(GenericVisitor):
@@ -1112,6 +1113,49 @@ def visit_Node(self, o, ret=None):
11121113
return ret
11131114

11141115

1116+
class FindWithin(FindNodes):

    """
    Like FindNodes, but restricted to a window of the traversal: matching
    nodes are collected only from the moment `start` is encountered, and no
    longer collected once `stop` (and its subtree) has been visited.

    Both `start` and `stop` are recognised by identity (`is`). If they
    satisfy the matching rule themselves, they are part of the result.

    Parameters
    ----------
    match
        The matching criterion, forwarded unchanged to `FindNodes.__init__`.
    start : Node
        The node at which collection begins.
    stop : Node, optional
        The node after which collection ends. If None (default), collection
        continues until the traversal is over.
    """

    @classmethod
    def default_retval(cls):
        # A pair `(found, flag)`: the nodes collected so far and whether
        # collection is currently active
        return [], False

    def __init__(self, match, start, stop=None):
        super().__init__(match)
        self.start = start
        self.stop = stop

    def visit(self, o, ret=None):
        # Callers only care about the collected nodes, not the internal flag
        found, _ = self._visit(o, ret=ret)
        return found

    def visit_Node(self, o, ret=None):
        if ret is None:
            ret = self.default_retval()
        found, flag = ret

        # Collection is switched on before `o` is tested, so `start` itself
        # may be collected
        if o is self.start:
            flag = True

        if flag and self.rule(self.match, o):
            found.append(o)
        for i in o.children:
            found, newflag = self._visit(i, ret=(found, flag))
            if flag and not newflag:
                # Collection was switched off while visiting `i`, that is
                # `stop` was found underneath; skip the remaining siblings
                return found, newflag
            flag = newflag

        # Collection is switched off only after `o` and its children have
        # been processed, so `stop` itself may be collected
        if o is self.stop:
            flag = False

        return found, flag
1157+
1158+
11151159
class FindApplications(Visitor):
11161160

11171161
"""

devito/ir/support/basic.py

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -382,32 +382,36 @@ def distance(self, other):
382382
if disjoint_test(self[n], other[n], sai, sit):
383383
return Vector(S.ImaginaryUnit)
384384

385+
# Compute the distance along the current IterationInterval
385386
if self.function._mem_shared:
386387
# Special case: the distance between two regular, thread-shared
387-
# objects fallbacks to zero, as any other value would be nonsensical
388+
# objects falls back to zero, as any other value would be
389+
# nonsensical
390+
ret.append(S.Zero)
391+
elif degenerating_dimensions(sai, oai):
392+
# Special case: `sai` and `oai` may be different symbolic objects
393+
# but they can be proved to systematically generate the same value
388394
ret.append(S.Zero)
389-
390395
elif sai and oai and sai._defines & sit.dim._defines:
391-
# E.g., `self=R<f,[t + 1, x]>`, `self.itintervals=(time, x)`, `ai=t`
396+
# E.g., `self=R<f,[t + 1, x]>`, `self.itintervals=(time, x)`,
397+
# and `ai=t`
392398
if sit.direction is Backward:
393399
ret.append(other[n] - self[n])
394400
else:
395401
ret.append(self[n] - other[n])
396-
397402
elif not sai and not oai:
398403
# E.g., `self=R<a,[3]>` and `other=W<a,[4]>`
399404
if self[n] - other[n] == 0:
400405
ret.append(S.Zero)
401406
else:
402407
break
403-
404408
elif sai in self.ispace and oai in other.ispace:
405409
# E.g., `self=R<f,[x, y]>`, `sai=time`,
406410
# `self.itintervals=(time, x, y)`, `n=0`
407411
continue
408-
409412
else:
410-
# E.g., `self=R<u,[t+1, ii_src_0+1, ii_src_1+2]>`, `fi=p_src`, `n=1`
413+
# E.g., `self=R<u,[t+1, ii_src_0+1, ii_src_1+2]>`, `fi=p_src`,
414+
# and `n=1`
411415
return vinf(ret)
412416

413417
n = len(ret)
@@ -1408,3 +1412,19 @@ def disjoint_test(e0, e1, d, it):
14081412
i1 = sympy.Interval(min(p10, p11), max(p10, p11))
14091413

14101414
return not bool(i0.intersect(i1))
1415+
1416+
1417+
def degenerating_dimensions(d0, d1):
    """
    Tell whether `d0` and `d1`, which may be distinct symbolic objects, can
    be proved to always take the same value.

    Objects that do not expose the attributes inspected below (for example
    None) are never considered degenerating, so False is returned for them.
    """
    try:
        # Case 1: both are ModuloDimensions...
        if not (d0.is_Modulo and d1.is_Modulo):
            return False
        # ... and both have size 1
        return bool(d0.modulo == d1.modulo == 1)
    except AttributeError:
        # Not Dimension-like, hence nothing can be proved
        return False

devito/mpi/halo_scheme.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,10 @@ def __hash__(self):
4343
return hash((self.loc_indices, self.loc_dirs, self.halos, self.dims,
4444
self.bundle))
4545

46+
@cached_property
def loc_values(self):
    """
    The values of `self.loc_indices`, gathered into a frozenset.
    """
    values = self.loc_indices.values()
    return frozenset(values)
49+
4650
def union(self, other):
4751
"""
4852
Return a new HaloSchemeEntry that is the union of this and `other`.
@@ -384,6 +388,10 @@ def owned_size(self):
384388
mapper[d] = (max(maxl, s.left), max(maxr, s.right))
385389
return mapper
386390

391+
@cached_property
def functions(self):
    """
    The keys of `self.fmapper`, gathered into a frozenset.
    """
    tracked = self.fmapper
    return frozenset(tracked)
394+
387395
@cached_property
388396
def dimensions(self):
389397
retval = set()
@@ -413,6 +421,38 @@ def loc_values(self):
413421
def arguments(self):
414422
return self.dimensions | set(flatten(self.honored.values()))
415423

424+
def issubset(self, other):
    """
    Check if `self` is a subset of `other`.

    False is returned if `other` is not a HaloScheme, or if any of the
    per-Function checks performed below fails; True otherwise.
    """
    if not isinstance(other, HaloScheme):
        return False

    # Every Function tracked by `self` must be tracked by `other` too
    if any(f not in other.fmapper for f in self.fmapper):
        return False

    for f, mine in self.fmapper.items():
        theirs = other.fmapper[f]

        # The halos required by `self` must be among those of `other`...
        if not mine.halos.issubset(theirs.halos):
            return False

        # ... and both entries must refer to the very same bundle
        if mine.bundle is not theirs.bundle:
            return False

        # To be a subset, `self` must also be expecting such halos at a
        # time index that is less than or equal to that of `other`; this
        # requires the two entries to agree on the directions first
        directions = mine.loc_dirs
        if directions != theirs.loc_dirs:
            return False

        # Pair up, Dimension by Dimension, the indices of the two entries
        paired = {d: (index, theirs.loc_indices[d])
                  for d, index in mine.loc_indices.items()}

        # NOTE(review): presumably `process_loc_indices` picks, for each
        # pair, the index to be retained according to the direction --
        # confirm against its definition
        projected, _ = process_loc_indices(paired, directions)
        if projected != theirs.loc_indices:
            return False

    return True
455+
416456
def project(self, functions):
417457
"""
418458
Create a new HaloScheme that only retains the HaloSchemeEntries corresponding

0 commit comments

Comments
 (0)