@@ -2068,7 +2068,7 @@ def test_multi_allreduce_time(self, mode):
20682068 assert np .isclose (np .max (g .data ), 4356.0 )
20692069 assert np .isclose (np .max (h .data ), 4356.0 )
20702070
2071- @pytest .mark .parallel (mode = 1 )
2071+ @pytest .mark .parallel (mode = 2 )
20722072 def test_multi_allreduce_time_cond (self , mode ):
20732073 space_order = 8
20742074 nx , ny = 11 , 11
@@ -2101,6 +2101,64 @@ def test_multi_allreduce_time_cond(self, mode):
21012101 assert np .allclose (g .data , expected )
21022102 assert np .allclose (h .data , expected )
21032103
2104+ @pytest .mark .parallel (mode = 2 )
2105+ def test_allreduce_multicond (self , mode ):
2106+ space_order = 8
2107+ nx , ny = 11 , 11
2108+
2109+ grid = Grid (shape = (nx , ny ))
2110+ tt = grid .time_dim
2111+ nt = 20
2112+ ct = ConditionalDimension (name = "ct" , parent = tt , factor = 2 )
2113+ ct2 = ConditionalDimension (name = "ct2" , parent = tt , factor = 8 )
2114+
2115+ ux = TimeFunction (name = "ux" , grid = grid , time_order = 1 , space_order = space_order )
2116+ uy = TimeFunction (name = "uy" , grid = grid , time_order = 1 , space_order = space_order )
2117+ g = TimeFunction (name = "g" , grid = grid , dimensions = (ct , ), shape = (int (nt / 2 ),),
2118+ time_dim = ct )
2119+ h = TimeFunction (name = "h" , grid = grid , dimensions = (ct , ), shape = (int (nt / 2 ),),
2120+ time_dim = ct )
2121+
2122+ op = Operator ([Eq (g , 0 ), Eq (ux .forward , tt ), Inc (g , ux ), Inc (h , ux ),
2123+ Eq (uy , tt , implicit_dims = ct2 )],
2124+ name = "Op" )
2125+ assert_structure (op , ['t' , 't,x,y' , 't,x,y' , 't,x,y' ], 'txyxyxy' )
2126+
2127+ # Make sure the two allreduce calls are in the time the loop
2128+ iters = FindNodes (Iteration ).visit (op )
2129+ for i in iters :
2130+ if i .dim .is_Time :
2131+ assert len (FindNodes (Call ).visit (i )) == 2 # Two allreduce
2132+ else :
2133+ assert len (FindNodes (Call ).visit (i )) == 0
2134+
2135+ # Check conditionals
2136+ conds = FindNodes (Conditional ).visit (op )
2137+ assert len (conds ) == 3
2138+ # First one is just g initialization
2139+ sym0 = FindSymbols ().visit (conds [0 ])
2140+ assert set (sym0 ) == {ct .symbolic_factor , tt , g }
2141+ assert grid .distributor ._obj_comm not in sym0
2142+ # Second one is g and h and allreduce
2143+ sym1 = FindSymbols ().visit (conds [1 ])
2144+ assert g in sym1
2145+ assert h in sym1
2146+ assert ux in sym1
2147+ # The allreduce
2148+ assert grid .distributor ._obj_comm in sym1
2149+ # Last one is only uy
2150+ sym2 = FindSymbols ().visit (conds [- 1 ])
2151+ assert g not in sym2
2152+ assert h not in sym2
2153+ assert uy in sym2
2154+ assert grid .distributor ._obj_comm not in sym2
2155+
2156+ op .apply (time_m = 0 , time_M = nt - 1 )
2157+
2158+ expected = [nx * ny * max (t - 1 , 0 ) for t in range (0 , nt , 2 )]
2159+ assert np .allclose (g .data , expected )
2160+ assert np .allclose (h .data , expected )
2161+
21042162
21052163class TestOperatorAdvanced :
21062164
0 commit comments