@@ -194,12 +194,48 @@ def test_distributed_norm(par):
     assert_allclose(arr.norm(), np.linalg.norm(par['x'].flatten()), rtol=1e-13)
 
 
+@pytest.mark.mpi(min_size=2)
+@pytest.mark.parametrize("par", [(par6), (par8)])
+def test_distributed_masked(par):
+    """Test Asarray with masked array"""
+    # Number of subcommunicators
+    if MPI.COMM_WORLD.Get_size() % 2 == 0:
+        nsub = 2
+    elif MPI.COMM_WORLD.Get_size() % 3 == 0:
+        nsub = 3
+    else:
+        pass
+    subsize = max(1, MPI.COMM_WORLD.Get_size() // nsub)
+    mask = np.repeat(np.arange(nsub), subsize)
+
+    # Replicate x as required in masked arrays
+    x = par['x']
+    if par['axis'] != 0:
+        x = np.swapaxes(x, par['axis'], 0)
+    for isub in range(1, nsub):
+        x[(x.shape[0] // nsub) * isub:(x.shape[0] // nsub) * (isub + 1)] = x[:x.shape[0] // nsub]
+    if par['axis'] != 0:
+        x = np.swapaxes(x, 0, par['axis'])
+
+    arr = DistributedArray.to_dist(x=x, partition=par['partition'], mask=mask, axis=par['axis'])
+
+    # Global view
+    xloc = arr.asarray()
+    assert xloc.shape == x.shape
+
+    # Global masked view
+    xmaskedloc = arr.asarray(masked=True)
+    xmasked_shape = list(x.shape)
+    xmasked_shape[par['axis']] = int(xmasked_shape[par['axis']] // nsub)
+    assert xmaskedloc.shape == tuple(xmasked_shape)
+
+
 @pytest.mark.mpi(min_size=2)
 @pytest.mark.parametrize("par1, par2", [(par6, par7), (par6b, par7b),
                                         (par8, par9), (par8b, par9b)])
 def test_distributed_maskeddot(par1, par2):
     """Test Distributed Dot product with masked array"""
-    # number of subcommunicators
+    # Number of subcommunicators
     if MPI.COMM_WORLD.Get_size() % 2 == 0:
         nsub = 2
     elif MPI.COMM_WORLD.Get_size() % 3 == 0:
@@ -208,7 +244,7 @@ def test_distributed_maskeddot(par1, par2):
         pass
     subsize = max(1, MPI.COMM_WORLD.Get_size() // nsub)
     mask = np.repeat(np.arange(nsub), subsize)
-    print('subsize, mask', subsize, mask)
+
     # Replicate x1 and x2 as required in masked arrays
     x1, x2 = par1['x'], par2['x']
     if par1['axis'] != 0:
@@ -234,7 +270,7 @@ def test_distributed_maskeddot(par1, par2):
                                  (par8), (par8b), (par9), (par9b)])
 def test_distributed_maskednorm(par):
     """Test Distributed numpy.linalg.norm method with masked array"""
-    # number of subcommunicators
+    # Number of subcommunicators
     if MPI.COMM_WORLD.Get_size() % 2 == 0:
         nsub = 2
     elif MPI.COMM_WORLD.Get_size() % 3 == 0:
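For context, the mask used in these tests assigns each rank of MPI.COMM_WORLD to a subcommunicator, and ranks that share a mask value are expected to hold a replica of the same portion of the global array along the partition axis (which is why x is replicated before calling DistributedArray.to_dist). A minimal standalone sketch of the mask layout, assuming a hypothetical 4-rank run with nsub=2 subcommunicators (not part of the diff above):

import numpy as np

# Assumed setup: 4 MPI ranks split into nsub=2 subcommunicators of size 2
nranks = 4                        # hypothetical MPI.COMM_WORLD.Get_size()
nsub = 2                          # number of subcommunicators
subsize = max(1, nranks // nsub)
mask = np.repeat(np.arange(nsub), subsize)
print(mask)                       # [0 0 1 1]: ranks 0-1 form group 0, ranks 2-3 form group 1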