@@ -307,6 +307,42 @@ def test_distributed_norm_nccl(par):
     assert_allclose(arr.norm().get(), np.linalg.norm(par["x"].flatten()), rtol=1e-13)


+@pytest.mark.mpi(min_size=2)
+@pytest.mark.parametrize("par", [(par6), (par8)])
+def test_distributed_masked_nccl(par):
+    """Test Asarray with masked array"""
+    # Number of subcommunicators
+    if MPI.COMM_WORLD.Get_size() % 2 == 0:
+        nsub = 2
+    elif MPI.COMM_WORLD.Get_size() % 3 == 0:
+        nsub = 3
+    else:
+        # Skip when the communicator size cannot be split into 2 or 3 subcommunicators
+        return
+    subsize = max(1, MPI.COMM_WORLD.Get_size() // nsub)
+    mask = np.repeat(np.arange(nsub), subsize)
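+    # mask[i] is the index of the subcommunicator that rank i belongs to (consecutive ranks are grouped)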
+
+    # Replicate x as required in masked arrays
+    x_gpu = cp.asarray(par['x'])
+    if par['axis'] != 0:
+        x_gpu = cp.swapaxes(x_gpu, par['axis'], 0)
+    for isub in range(1, nsub):
+        x_gpu[(x_gpu.shape[0] // nsub) * isub:(x_gpu.shape[0] // nsub) * (isub + 1)] = x_gpu[:x_gpu.shape[0] // nsub]
+    if par['axis'] != 0:
+        x_gpu = cp.swapaxes(x_gpu, 0, par['axis'])
+
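+    # Create a masked DistributedArray over the NCCL communicator; ranks sharing a mask value form a subcommunicator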
+    arr = DistributedArray.to_dist(x=x_gpu, base_comm_nccl=nccl_comm,
+                                   partition=par['partition'], mask=mask,
+                                   axis=par['axis'])
+
+    # Global view
+    xloc = arr.asarray()
+    assert xloc.shape == x_gpu.shape
+
+    # Global masked view
+    xmaskedloc = arr.asarray(masked=True)
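+    # Each subcommunicator holds a replicated copy along `axis`, so the masked view spans only 1/nsub of the global size on that axis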
+    xmasked_shape = list(x_gpu.shape)
+    xmasked_shape[par['axis']] = int(xmasked_shape[par['axis']] // nsub)
+    assert xmaskedloc.shape == tuple(xmasked_shape)
+
+
 @pytest.mark.mpi(min_size=2)
 @pytest.mark.parametrize(
     "par1, par2", [(par6, par7), (par6b, par7b), (par8, par9), (par8b, par9b)]