|
17 | 17 | from devito.symbolics import (Byref, CondNe, FieldFromPointer, FieldFromComposite, |
18 | 18 | IndexedPointer, Macro, cast, subs_op_args) |
19 | 19 | from devito.tools import (as_mapper, dtype_to_mpitype, dtype_len, infer_datasize, |
20 | | - flatten, generator, is_integer, split) |
21 | | -from devito.types import (Array, Bag, Dimension, Eq, Symbol, LocalObject, |
22 | | - CompositeObject, CustomDimension) |
| 20 | + flatten, generator, is_integer) |
| 21 | +from devito.types import (Array, Bag, BundleView, Dimension, Eq, Symbol, |
| 22 | + LocalObject, CompositeObject, CustomDimension) |
23 | 23 |
|
24 | 24 | __all__ = ['HaloExchangeBuilder', 'ReductionBuilder', 'mpi_registry'] |
25 | 25 |
|
@@ -292,19 +292,28 @@ def _make_bundles(self, hs): |
292 | 292 |
|
293 | 293 | mapper = as_mapper(halo_scheme.fmapper, lambda i: halo_scheme.fmapper[i]) |
294 | 294 | for hse, components in mapper.items(): |
295 | | - # We recast everything as Bags for simplicity -- worst case scenario |
296 | | - # all Bags only have one component. Existing Bundles are preserved |
297 | 295 | halo_scheme = halo_scheme.drop(components) |
298 | | - bundles, candidates = split(tuple(components), lambda i: i.is_Bundle) |
299 | | - for b in bundles: |
300 | | - halo_scheme = halo_scheme.add(b, hse) |
301 | 296 |
|
| 297 | + # Existing Bundles are preserved |
| 298 | + if hse.bundle: |
| 299 | + if set(components) == set(hse.bundle.components): |
| 300 | + halo_scheme = halo_scheme.add(hse.bundle, hse) |
| 301 | + else: |
| 302 | + name = f'bundleview_{hse.bundle.name}' |
| 303 | + bundle_view = BundleView( |
| 304 | + name=name, components=components, parent=hse.bundle |
| 305 | + ) |
| 306 | + halo_scheme = halo_scheme.add(bundle_view, hse) |
| 307 | + continue |
| 308 | + |
| 309 | + # We recast everything else as Bags for simplicity -- worst case |
| 310 | + # scenario all Bags only have one component. |
302 | 311 | try: |
303 | | - name = "bag_%s" % "".join(f.name for f in candidates) |
304 | | - bag = Bag(name=name, components=candidates) |
| 312 | + name = "bag_%s" % "".join(f.name for f in components) |
| 313 | + bag = Bag(name=name, components=components) |
305 | 314 | halo_scheme = halo_scheme.add(bag, hse) |
306 | 315 | except ValueError: |
307 | | - for i in candidates: |
| 316 | + for i in components: |
308 | 317 | name = "bag_%s" % i.name |
309 | 318 | bag = Bag(name=name, components=i) |
310 | 319 | halo_scheme = halo_scheme.add(bag, hse) |
@@ -363,10 +372,17 @@ def _make_copy(self, f, hse, key, swap=False): |
363 | 372 | else: |
364 | 373 | swap = lambda i, j: (j, i) |
365 | 374 | name = 'scatter%s' % key |
| 375 | + |
366 | 376 | if isinstance(f, Bag): |
367 | 377 | for i, c in enumerate(f.components): |
368 | 378 | eqns.append(Eq(*swap(buf[[i] + bdims], c[findices]))) |
| 379 | + elif isinstance(f, BundleView): |
| 380 | + assert f.parent is hse.bundle |
| 381 | + for i, c in enumerate(f.components): |
| 382 | + indices = [f.parent.components.index(c), *findices] |
| 383 | + eqns.append(Eq(*swap(buf[[i] + bdims], f.parent[indices]))) |
369 | 384 | else: |
| 385 | + assert f.is_Bundle |
370 | 386 | for i in range(f.ncomp): |
371 | 387 | eqns.append(Eq(*swap(buf[[i] + bdims], f[[i] + findices]))) |
372 | 388 |
|
@@ -724,7 +740,7 @@ def _make_halowait(self, f, hse, key, wait, msg=None): |
724 | 740 |
|
725 | 741 | parameters = list(f.handles) + list(fixed.values()) + [nb, msg] |
726 | 742 |
|
727 | | - return Callable('halowait%d' % key, iet, 'void', parameters, ('static',)) |
| 743 | + return HaloWait(f'halowait{key}', iet, parameters) |
728 | 744 |
|
729 | 745 | def _call_halowait(self, name, f, hse, msg): |
730 | 746 | nb = f.grid.distributor._obj_neighborhood |
@@ -763,7 +779,7 @@ def _make_region(self, hs, key): |
763 | 779 | def _make_msg(self, f, hse, key): |
764 | 780 | # Only retain the halos required by the Diag scheme |
765 | 781 | halos = sorted(i for i in hse.halos if isinstance(i.dim, tuple)) |
766 | | - return MPIMsgEnriched('msg%d' % key, f, halos) |
| 782 | + return MPIMsgEnriched(f'msg{key}', f, halos) |
767 | 783 |
|
768 | 784 | def _make_sendrecv(self, *args, **kwargs): |
769 | 785 | return |
@@ -852,7 +868,7 @@ def _make_halowait(self, f, hse, key, *args, msg=None): |
852 | 868 | ncomms = Symbol(name='ncomms') |
853 | 869 | iet = Iteration([waitsend, waitrecv, scatter], dim, ncomms - 1) |
854 | 870 | parameters = f.handles + tuple(fixed.values()) + (msg, ncomms) |
855 | | - return Callable('halowait%d' % key, iet, 'void', parameters, ('static',)) |
| 871 | + return HaloWait(f'halowait{key}', iet, parameters) |
856 | 872 |
|
857 | 873 | def _call_halowait(self, name, f, hse, msg): |
858 | 874 | args = f.handles + tuple(hse.loc_indices.values()) + (msg, msg.npeers) |
@@ -1034,9 +1050,11 @@ def __init__(self, name, body, parameters, bufg, bufs): |
1034 | 1050 |
|
1035 | 1051 |
|
1036 | 1052 | class HaloUpdate(MPICallable): |
| 1053 | + pass |
1037 | 1054 |
|
1038 | | - def __init__(self, name, body, parameters): |
1039 | | - super().__init__(name, body, parameters) |
| 1055 | + |
| 1056 | +class HaloWait(MPICallable): |
| 1057 | + pass |
1040 | 1058 |
|
1041 | 1059 |
|
1042 | 1060 | class Remainder(ElementalFunction): |
@@ -1238,12 +1256,14 @@ class MPIMsgEnriched(MPIMsg): |
1238 | 1256 | _C_field_ofsg = 'ofsg' |
1239 | 1257 | _C_field_from = 'fromrank' |
1240 | 1258 | _C_field_to = 'torank' |
| 1259 | + _C_field_components = 'components' |
1241 | 1260 |
|
1242 | 1261 | fields = MPIMsg.fields + [ |
1243 | 1262 | (_C_field_ofss, POINTER(c_int)), |
1244 | 1263 | (_C_field_ofsg, POINTER(c_int)), |
1245 | 1264 | (_C_field_from, c_int), |
1246 | | - (_C_field_to, c_int) |
| 1265 | + (_C_field_to, c_int), |
| 1266 | + (_C_field_components, POINTER(c_int)), |
1247 | 1267 | ] |
1248 | 1268 |
|
1249 | 1269 | def _arg_defaults(self, allocator, alias=None, args=None): |
@@ -1282,6 +1302,17 @@ def _arg_defaults(self, allocator, alias=None, args=None): |
1282 | 1302 | ofss.append(f._offset_owned[dim].left) |
1283 | 1303 | entry.ofss = (c_int*len(ofss))(*ofss) |
1284 | 1304 |
|
| 1305 | + # Track the component accesses for packing/unpacking as numbers |
| 1306 | + # representing the field being accessed (that is: .x -> 0, .y -> 1, |
| 1307 | + # .z -> 2, .w -> 3), if any |
| 1308 | + if isinstance(self.target, BundleView): |
| 1309 | + ncomp = self.target.ncomp |
| 1310 | + component_indices = self.target.component_indices |
| 1311 | + entry.components = (c_int*ncomp)(*component_indices) |
| 1312 | + elif self.target.is_Bundle: |
| 1313 | + ncomp = self.target.ncomp |
| 1314 | + entry.components = (c_int*ncomp)(*range(ncomp)) |
| 1315 | + |
1285 | 1316 | return {self.name: self.value} |
1286 | 1317 |
|
1287 | 1318 |
|
|
0 commit comments