Skip to content

Commit 8f9bc6f

Browse files
authored
Merge pull request #2673 from devitocodes/allreduce-v4
mpi: fix multi conditional for allreduce
2 parents 1cfe91c + a9276c1 commit 8f9bc6f

File tree

2 files changed

+68
-8
lines changed

2 files changed

+68
-8
lines changed

devito/ir/clusters/algorithms.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -453,13 +453,18 @@ def callback(self, clusters, prefix, seen=None):
453453
def reduction_comms(clusters):
454454
processed = []
455455
fifo = []
456+
457+
def _update(reductions):
458+
for _, reds in groupby(reductions, key=lambda r: r.ispace):
459+
reds = list(reds)
460+
exprs = flatten([dr.exprs for dr in reds])
461+
processed.append(reds[0].rebuild(exprs=exprs))
462+
456463
for c in clusters:
457464
# Schedule the global distributed reductions encountered before `c`,
458465
# if `c`'s IterationSpace is such that the reduction can be carried out
459466
found, fifo = split(fifo, lambda dr: dr.ispace.is_subset(c.ispace))
460-
for ispace, reds in groupby(found, key=lambda r: r.ispace):
461-
exprs = flatten([dr.exprs for dr in reds])
462-
processed.append(c.rebuild(exprs=exprs, ispace=ispace))
467+
_update(found)
463468

464469
# Detect the global distributed reductions in `c`
465470
for e in c.exprs:
@@ -494,10 +499,7 @@ def reduction_comms(clusters):
494499
processed.append(c)
495500

496501
# Leftover reductions are placed at the very end
497-
for ispace, reds in groupby(fifo, key=lambda r: r.ispace):
498-
reds = list(reds)
499-
exprs = flatten([dr.exprs for dr in reds])
500-
processed.append(reds[0].rebuild(exprs=exprs, ispace=ispace))
502+
_update(fifo)
501503

502504
return processed
503505

tests/test_mpi.py

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2068,7 +2068,7 @@ def test_multi_allreduce_time(self, mode):
20682068
assert np.isclose(np.max(g.data), 4356.0)
20692069
assert np.isclose(np.max(h.data), 4356.0)
20702070

2071-
@pytest.mark.parallel(mode=1)
2071+
@pytest.mark.parallel(mode=2)
20722072
def test_multi_allreduce_time_cond(self, mode):
20732073
space_order = 8
20742074
nx, ny = 11, 11
@@ -2101,6 +2101,64 @@ def test_multi_allreduce_time_cond(self, mode):
21012101
assert np.allclose(g.data, expected)
21022102
assert np.allclose(h.data, expected)
21032103

2104+
@pytest.mark.parallel(mode=2)
2105+
def test_allreduce_multicond(self, mode):
2106+
space_order = 8
2107+
nx, ny = 11, 11
2108+
2109+
grid = Grid(shape=(nx, ny))
2110+
tt = grid.time_dim
2111+
nt = 20
2112+
ct = ConditionalDimension(name="ct", parent=tt, factor=2)
2113+
ct2 = ConditionalDimension(name="ct2", parent=tt, factor=8)
2114+
2115+
ux = TimeFunction(name="ux", grid=grid, time_order=1, space_order=space_order)
2116+
uy = TimeFunction(name="uy", grid=grid, time_order=1, space_order=space_order)
2117+
g = TimeFunction(name="g", grid=grid, dimensions=(ct, ), shape=(int(nt/2),),
2118+
time_dim=ct)
2119+
h = TimeFunction(name="h", grid=grid, dimensions=(ct, ), shape=(int(nt/2),),
2120+
time_dim=ct)
2121+
2122+
op = Operator([Eq(g, 0), Eq(ux.forward, tt), Inc(g, ux), Inc(h, ux),
2123+
Eq(uy, tt, implicit_dims=ct2)],
2124+
name="Op")
2125+
assert_structure(op, ['t', 't,x,y', 't,x,y', 't,x,y'], 'txyxyxy')
2126+
2127+
# Make sure the two allreduce calls are in the time the loop
2128+
iters = FindNodes(Iteration).visit(op)
2129+
for i in iters:
2130+
if i.dim.is_Time:
2131+
assert len(FindNodes(Call).visit(i)) == 2 # Two allreduce
2132+
else:
2133+
assert len(FindNodes(Call).visit(i)) == 0
2134+
2135+
# Check conditionals
2136+
conds = FindNodes(Conditional).visit(op)
2137+
assert len(conds) == 3
2138+
# First one is just g initialization
2139+
sym0 = FindSymbols().visit(conds[0])
2140+
assert set(sym0) == {ct.symbolic_factor, tt, g}
2141+
assert grid.distributor._obj_comm not in sym0
2142+
# Second one is g and h and allreduce
2143+
sym1 = FindSymbols().visit(conds[1])
2144+
assert g in sym1
2145+
assert h in sym1
2146+
assert ux in sym1
2147+
# The allreduce
2148+
assert grid.distributor._obj_comm in sym1
2149+
# Last one is only uy
2150+
sym2 = FindSymbols().visit(conds[-1])
2151+
assert g not in sym2
2152+
assert h not in sym2
2153+
assert uy in sym2
2154+
assert grid.distributor._obj_comm not in sym2
2155+
2156+
op.apply(time_m=0, time_M=nt-1)
2157+
2158+
expected = [nx * ny * max(t-1, 0) for t in range(0, nt, 2)]
2159+
assert np.allclose(g.data, expected)
2160+
assert np.allclose(h.data, expected)
2161+
21042162

21052163
class TestOperatorAdvanced:
21062164

0 commit comments

Comments
 (0)