Skip to content

Commit 6d494db

Browse files
Merge pull request #2587 from devitocodes/hotfix-parlang-lowering-3
compiler: Improve quality of generated code with Bundles and MPI
2 parents fefcfba + fb0d903 commit 6d494db

34 files changed

+595
-209
lines changed

devito/ir/iet/nodes.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,10 +102,16 @@ def ccode(self):
102102

103103
@property
104104
def view(self):
105-
"""A representation of the IET rooted in ``self``."""
105+
"""A high-level representation of the IET rooted in `self`."""
106106
from devito.ir.iet.visitors import printAST
107107
return printAST(self)
108108

109+
@property
110+
def view_cir(self):
111+
from devito.ir.iet.visitors import CGen
112+
from devito.passes.iet.languages.CIR import CIRPrinter
113+
return str(CGen(printer=CIRPrinter).visit(self))
114+
109115
@property
110116
def children(self):
111117
"""Return the traversable children."""
@@ -148,7 +154,7 @@ def writes(self):
148154
return ()
149155

150156
def _signature_items(self):
151-
return (str(self),)
157+
return (self.view_cir,)
152158

153159

154160
class ExprStmt:

devito/ir/support/basic.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
from devito.ir.support.utils import AccessMode, extrema
99
from devito.ir.support.vector import LabeledVector, Vector
1010
from devito.symbolics import (compare_ops, retrieve_indexed, retrieve_terminals,
11-
q_constant, q_affine, q_routine, search, uxreplace)
11+
q_constant, q_comp_acc, q_affine, q_routine, search,
12+
uxreplace)
1213
from devito.tools import (Tag, as_mapper, as_tuple, is_integer, filter_sorted,
1314
flatten, memoized_meth, memoized_generator)
1415
from devito.types import (ComponentAccess, Dimension, DimensionTuple, Fence,
@@ -529,9 +530,16 @@ def __hash__(self):
529530
(self.source, self.sink, self.source.timestamp == self.sink.timestamp)
530531
)
531532

532-
@property
533+
@cached_property
533534
def function(self):
534-
return self.source.function
535+
if q_comp_acc(self.source.access) and not q_comp_acc(self.sink.access):
536+
# E.g., `source=ab[x].x` and `sink=ab[x]` -> `a(x)`
537+
return self.source.access.function_access
538+
elif not q_comp_acc(self.source.access) and q_comp_acc(self.sink.access):
539+
# E.g., `source=ab[x]` and `sink=ab[x].y` -> `b(x)`
540+
return self.sink.access.function_access
541+
else:
542+
return self.source.function
535543

536544
@property
537545
def findices(self):
@@ -955,7 +963,7 @@ def reads_gen(self):
955963
@memoized_generator
956964
def reads_smart_gen(self, f):
957965
"""
958-
Generate all read access to a given function.
966+
Generate all read accesses to a given function.
959967
960968
StencilDimensions, if any, are replaced with their extrema.
961969

devito/ir/support/properties.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,13 @@ def __init__(self, name, val=None):
8585
is used for iteration spaces that are larger than the data space.
8686
"""
8787

88+
INBOUND_IF_RELAXED = Property('inbound-if-relaxed')
89+
"""
90+
Similar to INBOUND, but devised for IncrDimensions whose iteration space is
91+
_not_ larger than the data space and, as such, still require full lowering
92+
(see the `relax_incr_dimensions` pass for more info).
93+
"""
94+
8895
PREFETCHABLE = Property('prefetchable')
8996
"""
9097
A Dimension along which prefetching is feasible and beneficial.
@@ -295,6 +302,9 @@ def block(self, dims, kind='default'):
295302
def inbound(self, dims):
296303
return self.add(dims, INBOUND)
297304

305+
def inbound_if_relaxed(self, dims):
306+
return self.add(dims, INBOUND_IF_RELAXED)
307+
298308
def init_core_shm(self, dims):
299309
properties = self.add(dims, INIT_CORE_SHM)
300310
properties = properties.drop(properties={INIT_HALO_LEFT_SHM,
@@ -327,7 +337,8 @@ def is_affine(self, dims):
327337
return any(AFFINE in self.get(d, ()) for d in as_tuple(dims))
328338

329339
def is_inbound(self, dims):
330-
return any(INBOUND in self.get(d, ()) for d in as_tuple(dims))
340+
return any({INBOUND, INBOUND_IF_RELAXED}.intersection(self.get(d, set()))
341+
for d in as_tuple(dims))
331342

332343
def is_sequential(self, dims):
333344
return any(SEQUENTIAL in self.get(d, ()) for d in as_tuple(dims))

devito/mpi/halo_scheme.py

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -30,15 +30,18 @@ class HaloLabel(Tag):
3030
class HaloSchemeEntry(EnrichedTuple):
3131

3232
__rargs__ = ('loc_indices', 'loc_dirs', 'halos', 'dims')
33+
__rkwargs__ = ('bundle',)
3334

34-
def __new__(cls, loc_indices, loc_dirs, halos, dims, getters=None):
35+
def __new__(cls, loc_indices, loc_dirs, halos, dims, bundle=None, getters=None):
36+
getters = cls.__rargs__ + cls.__rkwargs__
3537
items = [frozendict(loc_indices), frozendict(loc_dirs),
36-
frozenset(halos), frozenset(dims)]
37-
kwargs = dict(zip(cls.__rargs__, items))
38-
return super().__new__(cls, *items, getters=cls.__rargs__, **kwargs)
38+
frozenset(halos), frozenset(dims), bundle]
39+
kwargs = dict(zip(getters, items))
40+
return super().__new__(cls, *items, getters=getters, **kwargs)
3941

4042
def __hash__(self):
41-
return hash((self.loc_indices, self.loc_dirs, self.halos, self.dims))
43+
return hash((self.loc_indices, self.loc_dirs, self.halos, self.dims,
44+
self.bundle))
4245

4346
def union(self, other):
4447
"""
@@ -47,7 +50,8 @@ def union(self, other):
4750
exception is raised.
4851
"""
4952
if self.loc_indices != other.loc_indices or \
50-
self.loc_dirs != other.loc_dirs:
53+
self.loc_dirs != other.loc_dirs or \
54+
self.bundle is not other.bundle:
5155
raise HaloSchemeException(
5256
"Inconsistency found while building a HaloScheme"
5357
)
@@ -56,7 +60,7 @@ def union(self, other):
5660
dims = self.dims | other.dims
5761

5862
return HaloSchemeEntry(self.loc_indices, self.loc_dirs, halos, dims,
59-
getters=self.getters)
63+
bundle=self.bundle, getters=self.getters)
6064

6165

6266
Halo = namedtuple('Halo', 'dim side')
@@ -168,7 +172,7 @@ def union(self, halo_schemes):
168172
elif not v.loc_indices or hse.loc_indices == v.loc_indices:
169173
loc_indices, loc_dirs = hse.loc_indices, hse.loc_dirs
170174
else:
171-
# The `loc_dirs` must match otherwise it'd be a symptom there's
175+
# These must match otherwise it'd be a symptom there's
172176
# something horribly broken elsewhere!
173177
assert hse.loc_dirs == v.loc_dirs
174178
assert list(hse.loc_indices) == list(v.loc_indices)
@@ -185,7 +189,11 @@ def union(self, halo_schemes):
185189
halos = hse.halos | v.halos
186190
dims = hse.dims | v.dims
187191

188-
fmapper[k] = HaloSchemeEntry(loc_indices, loc_dirs, halos, dims)
192+
assert hse.bundle is v.bundle
193+
194+
fmapper[k] = HaloSchemeEntry(
195+
loc_indices, loc_dirs, halos, dims, bundle=hse.bundle
196+
)
189197

190198
# Compute the `honored` union
191199
for d, v in i.honored.items():
@@ -641,8 +649,12 @@ def _uxreplace_dispatch_haloscheme(hs0, rule):
641649
for i, v in rule.items():
642650
if i is f:
643651
# Yes!
644-
g = v
645-
hse = hse0
652+
if v.is_Bundle:
653+
g = f
654+
hse = hse0._rebuild(bundle=v)
655+
else:
656+
g = v
657+
hse = hse0
646658

647659
elif i.is_Indexed and i.function is f and v.is_Indexed:
648660
# Yes, but through an Indexed, hence the `loc_indices` may now

devito/mpi/routines.py

Lines changed: 48 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@
1717
from devito.symbolics import (Byref, CondNe, FieldFromPointer, FieldFromComposite,
1818
IndexedPointer, Macro, cast, subs_op_args)
1919
from devito.tools import (as_mapper, dtype_to_mpitype, dtype_len, infer_datasize,
20-
flatten, generator, is_integer, split)
21-
from devito.types import (Array, Bag, Dimension, Eq, Symbol, LocalObject,
22-
CompositeObject, CustomDimension)
20+
flatten, generator, is_integer)
21+
from devito.types import (Array, Bag, BundleView, Dimension, Eq, Symbol,
22+
LocalObject, CompositeObject, CustomDimension)
2323

2424
__all__ = ['HaloExchangeBuilder', 'ReductionBuilder', 'mpi_registry']
2525

@@ -292,19 +292,28 @@ def _make_bundles(self, hs):
292292

293293
mapper = as_mapper(halo_scheme.fmapper, lambda i: halo_scheme.fmapper[i])
294294
for hse, components in mapper.items():
295-
# We recast everything as Bags for simplicity -- worst case scenario
296-
# all Bags only have one component. Existing Bundles are preserved
297295
halo_scheme = halo_scheme.drop(components)
298-
bundles, candidates = split(tuple(components), lambda i: i.is_Bundle)
299-
for b in bundles:
300-
halo_scheme = halo_scheme.add(b, hse)
301296

297+
# Existing Bundles are preserved
298+
if hse.bundle:
299+
if set(components) == set(hse.bundle.components):
300+
halo_scheme = halo_scheme.add(hse.bundle, hse)
301+
else:
302+
name = f'bundleview_{hse.bundle.name}'
303+
bundle_view = BundleView(
304+
name=name, components=components, parent=hse.bundle
305+
)
306+
halo_scheme = halo_scheme.add(bundle_view, hse)
307+
continue
308+
309+
# We recast everything else as Bags for simplicity -- worst case
310+
# scenario all Bags only have one component.
302311
try:
303-
name = "bag_%s" % "".join(f.name for f in candidates)
304-
bag = Bag(name=name, components=candidates)
312+
name = "bag_%s" % "".join(f.name for f in components)
313+
bag = Bag(name=name, components=components)
305314
halo_scheme = halo_scheme.add(bag, hse)
306315
except ValueError:
307-
for i in candidates:
316+
for i in components:
308317
name = "bag_%s" % i.name
309318
bag = Bag(name=name, components=i)
310319
halo_scheme = halo_scheme.add(bag, hse)
@@ -363,10 +372,17 @@ def _make_copy(self, f, hse, key, swap=False):
363372
else:
364373
swap = lambda i, j: (j, i)
365374
name = 'scatter%s' % key
375+
366376
if isinstance(f, Bag):
367377
for i, c in enumerate(f.components):
368378
eqns.append(Eq(*swap(buf[[i] + bdims], c[findices])))
379+
elif isinstance(f, BundleView):
380+
assert f.parent is hse.bundle
381+
for i, c in enumerate(f.components):
382+
indices = [f.parent.components.index(c), *findices]
383+
eqns.append(Eq(*swap(buf[[i] + bdims], f.parent[indices])))
369384
else:
385+
assert f.is_Bundle
370386
for i in range(f.ncomp):
371387
eqns.append(Eq(*swap(buf[[i] + bdims], f[[i] + findices])))
372388

@@ -724,7 +740,7 @@ def _make_halowait(self, f, hse, key, wait, msg=None):
724740

725741
parameters = list(f.handles) + list(fixed.values()) + [nb, msg]
726742

727-
return Callable('halowait%d' % key, iet, 'void', parameters, ('static',))
743+
return HaloWait(f'halowait{key}', iet, parameters)
728744

729745
def _call_halowait(self, name, f, hse, msg):
730746
nb = f.grid.distributor._obj_neighborhood
@@ -763,7 +779,7 @@ def _make_region(self, hs, key):
763779
def _make_msg(self, f, hse, key):
764780
# Only retain the halos required by the Diag scheme
765781
halos = sorted(i for i in hse.halos if isinstance(i.dim, tuple))
766-
return MPIMsgEnriched('msg%d' % key, f, halos)
782+
return MPIMsgEnriched(f'msg{key}', f, halos)
767783

768784
def _make_sendrecv(self, *args, **kwargs):
769785
return
@@ -852,7 +868,7 @@ def _make_halowait(self, f, hse, key, *args, msg=None):
852868
ncomms = Symbol(name='ncomms')
853869
iet = Iteration([waitsend, waitrecv, scatter], dim, ncomms - 1)
854870
parameters = f.handles + tuple(fixed.values()) + (msg, ncomms)
855-
return Callable('halowait%d' % key, iet, 'void', parameters, ('static',))
871+
return HaloWait(f'halowait{key}', iet, parameters)
856872

857873
def _call_halowait(self, name, f, hse, msg):
858874
args = f.handles + tuple(hse.loc_indices.values()) + (msg, msg.npeers)
@@ -1034,9 +1050,11 @@ def __init__(self, name, body, parameters, bufg, bufs):
10341050

10351051

10361052
class HaloUpdate(MPICallable):
1053+
pass
10371054

1038-
def __init__(self, name, body, parameters):
1039-
super().__init__(name, body, parameters)
1055+
1056+
class HaloWait(MPICallable):
1057+
pass
10401058

10411059

10421060
class Remainder(ElementalFunction):
@@ -1238,12 +1256,14 @@ class MPIMsgEnriched(MPIMsg):
12381256
_C_field_ofsg = 'ofsg'
12391257
_C_field_from = 'fromrank'
12401258
_C_field_to = 'torank'
1259+
_C_field_components = 'components'
12411260

12421261
fields = MPIMsg.fields + [
12431262
(_C_field_ofss, POINTER(c_int)),
12441263
(_C_field_ofsg, POINTER(c_int)),
12451264
(_C_field_from, c_int),
1246-
(_C_field_to, c_int)
1265+
(_C_field_to, c_int),
1266+
(_C_field_components, POINTER(c_int)),
12471267
]
12481268

12491269
def _arg_defaults(self, allocator, alias=None, args=None):
@@ -1282,6 +1302,17 @@ def _arg_defaults(self, allocator, alias=None, args=None):
12821302
ofss.append(f._offset_owned[dim].left)
12831303
entry.ofss = (c_int*len(ofss))(*ofss)
12841304

1305+
# Track the component accesses for packing/unpacking as numbers
1306+
# representing the field being accessed (that is: .x -> 0, .y -> 1,
1307+
# .z -> 2, .w -> 3), if any
1308+
if isinstance(self.target, BundleView):
1309+
ncomp = self.target.ncomp
1310+
component_indices = self.target.component_indices
1311+
entry.components = (c_int*ncomp)(*component_indices)
1312+
elif self.target.is_Bundle:
1313+
ncomp = self.target.ncomp
1314+
entry.components = (c_int*ncomp)(*range(ncomp))
1315+
12851316
return {self.name: self.value}
12861317

12871318

devito/operator/operator.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,12 @@
1313
from devito.data import default_allocator
1414
from devito.exceptions import (CompilationError, ExecutionError, InvalidArgument,
1515
InvalidOperator)
16-
from devito.logger import debug, info, perf, warning, is_log_enabled_for, switch_log_level
16+
from devito.logger import (debug, info, perf, warning, is_log_enabled_for,
17+
switch_log_level)
1718
from devito.ir.equations import LoweredEq, lower_exprs, concretize_subdims
1819
from devito.ir.clusters import ClusterGroup, clusterize
19-
from devito.ir.iet import (Callable, CInterface, EntryFunction, FindSymbols, MetaCall,
20-
derive_parameters, iet_build)
20+
from devito.ir.iet import (Callable, CInterface, EntryFunction, FindSymbols,
21+
MetaCall, derive_parameters, iet_build)
2122
from devito.ir.support import AccessMode, SymbolRegistry
2223
from devito.ir.stree import stree_build
2324
from devito.operator.profiling import create_profile
@@ -26,8 +27,7 @@
2627
from devito.parameters import configuration
2728
from devito.passes import (Graph, lower_index_derivatives, generate_implicit,
2829
generate_macros, minimize_symbols, unevaluate,
29-
error_mapper, is_on_device)
30-
from devito.passes.iet.dtypes import lower_dtypes
30+
error_mapper, is_on_device, lower_dtypes)
3131
from devito.symbolics import estimate_cost, subs_op_args
3232
from devito.tools import (DAG, OrderedSet, Signer, ReducerMap, as_mapper, as_tuple,
3333
flatten, filter_sorted, frozendict, is_integer,
@@ -488,7 +488,7 @@ def _lower_iet(cls, uiet, profiler=None, **kwargs):
488488
# Extract the necessary macros from the symbolic objects
489489
generate_macros(graph, **kwargs)
490490

491-
# Add type specific metadata
491+
# Target-specific lowering
492492
lower_dtypes(graph, **kwargs)
493493

494494
# Target-independent optimizations

devito/passes/clusters/buffering.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -123,9 +123,7 @@ class InjectBuffers(Queue):
123123
def __init__(self, mapper, sregistry, options):
124124
super().__init__()
125125

126-
# Sort the mapper so that we always process the same Function in the
127-
# same order, hence we get deterministic code generation
128-
self.mapper = {i: mapper[i] for i in sorted(mapper, key=lambda i: i.name)}
126+
self.mapper = mapper
129127

130128
self.sregistry = sregistry
131129
self.options = options
@@ -302,6 +300,9 @@ def generate_buffers(clusters, key, sregistry, options, **kwargs):
302300
# {candidate buffered Function -> [Clusters that access it]}
303301
bfmap = map_buffered_functions(clusters, key)
304302

303+
# Sort for deterministic code generation
304+
bfmap = {i: bfmap[i] for i in sorted(bfmap, key=lambda i: i.name)}
305+
305306
# {buffered Function -> Buffer}
306307
xds = {}
307308
mapper = {}
@@ -718,7 +719,7 @@ def offset_from_centre(d, indices):
718719
# `time/factor` -- the starting pointing at time_m or time_M
719720
v = indices[0]
720721
try:
721-
p = sum(v.args[1:])
722+
p = v.func(*[i for i in v.args if not is_integer(i)])
722723
if not ((p - v).is_Integer or (p - v).is_Symbol):
723724
raise ValueError
724725
except (IndexError, ValueError):

0 commit comments

Comments
 (0)