@@ -194,12 +194,48 @@ def test_distributed_norm(par):
194194 assert_allclose (arr .norm (), np .linalg .norm (par ['x' ].flatten ()), rtol = 1e-13 )
195195
196196
197+ @pytest .mark .mpi (min_size = 2 )
198+ @pytest .mark .parametrize ("par" , [(par6 ), (par8 )])
199+ def test_distributed_masked (par ):
200+ """Test Asarray with masked array"""
201+ # Number of subcommunicators
202+ if MPI .COMM_WORLD .Get_size () % 2 == 0 :
203+ nsub = 2
204+ elif MPI .COMM_WORLD .Get_size () % 3 == 0 :
205+ nsub = 3
206+ else :
207+ pass
208+ subsize = max (1 , MPI .COMM_WORLD .Get_size () // nsub )
209+ mask = np .repeat (np .arange (nsub ), subsize )
210+
211+ # Replicate x as required in masked arrays
212+ x = par ['x' ]
213+ if par ['axis' ] != 0 :
214+ x = np .swapaxes (x , par ['axis' ], 0 )
215+ for isub in range (1 , nsub ):
216+ x [(x .shape [0 ] // nsub ) * isub :(x .shape [0 ] // nsub ) * (isub + 1 )] = x [:x .shape [0 ] // nsub ]
217+ if par ['axis' ] != 0 :
218+ x = np .swapaxes (x , 0 , par ['axis' ])
219+
220+ arr = DistributedArray .to_dist (x = x , partition = par ['partition' ], mask = mask , axis = par ['axis' ])
221+
222+ # Global view
223+ xloc = arr .asarray ()
224+ assert xloc .shape == x .shape
225+
226+ # Global masked view
227+ xmaskedloc = arr .asarray (masked = True )
228+ xmasked_shape = list (x .shape )
229+ xmasked_shape [par ['axis' ]] = int (xmasked_shape [par ['axis' ]] // nsub )
230+ assert xmaskedloc .shape == tuple (xmasked_shape )
231+
232+
197233@pytest .mark .mpi (min_size = 2 )
198234@pytest .mark .parametrize ("par1, par2" , [(par6 , par7 ), (par6b , par7b ),
199235 (par8 , par9 ), (par8b , par9b )])
200236def test_distributed_maskeddot (par1 , par2 ):
201237 """Test Distributed Dot product with masked array"""
202- # number of subcommunicators
238+ # Number of subcommunicators
203239 if MPI .COMM_WORLD .Get_size () % 2 == 0 :
204240 nsub = 2
205241 elif MPI .COMM_WORLD .Get_size () % 3 == 0 :
@@ -208,7 +244,7 @@ def test_distributed_maskeddot(par1, par2):
208244 pass
209245 subsize = max (1 , MPI .COMM_WORLD .Get_size () // nsub )
210246 mask = np .repeat (np .arange (nsub ), subsize )
211- print ( 'subsize, mask' , subsize , mask )
247+
212248 # Replicate x1 and x2 as required in masked arrays
213249 x1 , x2 = par1 ['x' ], par2 ['x' ]
214250 if par1 ['axis' ] != 0 :
@@ -234,7 +270,7 @@ def test_distributed_maskeddot(par1, par2):
234270 (par8 ), (par8b ), (par9 ), (par9b )])
235271def test_distributed_maskednorm (par ):
236272 """Test Distributed numpy.linalg.norm method with masked array"""
237- # number of subcommunicators
273+ # Number of subcommunicators
238274 if MPI .COMM_WORLD .Get_size () % 2 == 0 :
239275 nsub = 2
240276 elif MPI .COMM_WORLD .Get_size () % 3 == 0 :
0 commit comments