Skip to content

Commit 45bee36

Browse files
authored
Merge pull request #139 from mrava87/feat-asarraymasked
feat: added masked option to asarray
2 parents 9e25d8e + 5133c6b commit 45bee36

File tree

3 files changed

+67
-7
lines changed

3 files changed

+67
-7
lines changed

examples/plot_distributed_array.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,11 @@
140140
mask=mask)
141141
x[:] = (MPI.COMM_WORLD.Get_rank() % subsize + 1.) * np.ones(local_shape)
142142
xloc = x.asarray()
143+
xloc1 = x.asarray(masked=True)
144+
145+
if rank == 0:
146+
print('xloc (with repeated portions)', xloc)
147+
print('xloc (only effective portions):', xloc1)
143148

144149
# Dot product
145150
dot = x.dot(x)

pylops_mpi/DistributedArray.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -359,10 +359,15 @@ def sub_comm(self):
359359
"""
360360
return self._sub_comm
361361

362-
def asarray(self):
362+
def asarray(self, masked: bool = False):
363363
"""Global view of the array
364364
365-
Gather all the local arrays
365+
Gather all the local arrays from base communicator or subcommunicator
366+
367+
Parameters
368+
----------
369+
masked : :obj:`bool`
370+
Return local arrays of the subcommunicator (`True`) or base communicator (`False`).
366371
367372
Returns
368373
-------
@@ -375,10 +380,14 @@ def asarray(self):
375380
return self.local_array
376381

377382
if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
378-
return nccl_asarray(self.base_comm_nccl, self.local_array, self.local_shapes, self.axis)
383+
return nccl_asarray(self.sub_comm if masked else self.base_comm,
384+
self.local_array, self.local_shapes, self.axis)
379385
else:
380386
# Gather all the local arrays and apply concatenation.
381-
final_array = self._allgather(self.local_array)
387+
if masked:
388+
final_array = self._allgather_subcomm(self.local_array)
389+
else:
390+
final_array = self._allgather(self.local_array)
382391
return np.concatenate(final_array, axis=self.axis)
383392

384393
@classmethod
@@ -503,6 +512,16 @@ def _allgather(self, send_buf, recv_buf=None):
503512
self.base_comm.Allgather(send_buf, recv_buf)
504513
return recv_buf
505514

515+
def _allgather_subcomm(self, send_buf, recv_buf=None):
    """Allgather operation over the subcommunicator.

    Gathers ``send_buf`` from every rank of ``self.sub_comm`` (the
    mask-defined subcommunicator) instead of the base communicator.

    Parameters
    ----------
    send_buf : :obj:`numpy.ndarray`
        Local buffer to gather from each rank of the subcommunicator.
    recv_buf : :obj:`numpy.ndarray`, optional
        Pre-allocated receive buffer. If ``None``, the gathered pieces
        are returned as a list (lowercase ``allgather``).

    Returns
    -------
    recv_buf : :obj:`list` or :obj:`numpy.ndarray`
        Gathered data from all ranks of the subcommunicator.
    """
    if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
        return nccl_allgather(self.sub_comm, send_buf, recv_buf)
    else:
        if recv_buf is None:
            return self.sub_comm.allgather(send_buf)
        self.sub_comm.Allgather(send_buf, recv_buf)
        # Bugfix: return the filled buffer; previously this path fell
        # through and returned None, unlike the sibling _allgather.
        return recv_buf
506525
def _send(self, send_buf, dest, count=None, tag=None):
507526
""" Send operation
508527
"""

tests/test_distributedarray.py

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -194,12 +194,48 @@ def test_distributed_norm(par):
194194
assert_allclose(arr.norm(), np.linalg.norm(par['x'].flatten()), rtol=1e-13)
195195

196196

197+
@pytest.mark.mpi(min_size=2)
@pytest.mark.parametrize("par", [(par6), (par8)])
def test_distributed_masked(par):
    """Test asarray with masked array: the masked view must have the
    masked axis shrunk by the number of subcommunicators."""
    # Number of subcommunicators
    if MPI.COMM_WORLD.Get_size() % 2 == 0:
        nsub = 2
    elif MPI.COMM_WORLD.Get_size() % 3 == 0:
        nsub = 3
    else:
        # Bugfix: a bare `pass` left nsub unbound, raising NameError
        # below instead of cleanly skipping unsupported sizes.
        pytest.skip("MPI size must be divisible by 2 or 3")
    subsize = max(1, MPI.COMM_WORLD.Get_size() // nsub)
    mask = np.repeat(np.arange(nsub), subsize)

    # Replicate x as required in masked arrays.
    # Bugfix: copy so the in-place slice assignment below (through a
    # swapaxes view) does not mutate the shared `par` fixture array.
    x = par['x'].copy()
    if par['axis'] != 0:
        x = np.swapaxes(x, par['axis'], 0)
    for isub in range(1, nsub):
        x[(x.shape[0] // nsub) * isub:(x.shape[0] // nsub) * (isub + 1)] = x[:x.shape[0] // nsub]
    if par['axis'] != 0:
        x = np.swapaxes(x, 0, par['axis'])

    arr = DistributedArray.to_dist(x=x, partition=par['partition'], mask=mask, axis=par['axis'])

    # Global view: gathered over the base communicator, so repeated
    # portions are included and the full shape is recovered.
    xloc = arr.asarray()
    assert xloc.shape == x.shape

    # Global masked view: gathered over the subcommunicator only, so
    # the masked axis holds just one of the nsub repeated portions.
    xmaskedloc = arr.asarray(masked=True)
    xmasked_shape = list(x.shape)
    xmasked_shape[par['axis']] = int(xmasked_shape[par['axis']] // nsub)
    assert xmaskedloc.shape == tuple(xmasked_shape)
197233
@pytest.mark.mpi(min_size=2)
198234
@pytest.mark.parametrize("par1, par2", [(par6, par7), (par6b, par7b),
199235
(par8, par9), (par8b, par9b)])
200236
def test_distributed_maskeddot(par1, par2):
201237
"""Test Distributed Dot product with masked array"""
202-
# number of subcommunicators
238+
# Number of subcommunicators
203239
if MPI.COMM_WORLD.Get_size() % 2 == 0:
204240
nsub = 2
205241
elif MPI.COMM_WORLD.Get_size() % 3 == 0:
@@ -208,7 +244,7 @@ def test_distributed_maskeddot(par1, par2):
208244
pass
209245
subsize = max(1, MPI.COMM_WORLD.Get_size() // nsub)
210246
mask = np.repeat(np.arange(nsub), subsize)
211-
print('subsize, mask', subsize, mask)
247+
212248
# Replicate x1 and x2 as required in masked arrays
213249
x1, x2 = par1['x'], par2['x']
214250
if par1['axis'] != 0:
@@ -234,7 +270,7 @@ def test_distributed_maskeddot(par1, par2):
234270
(par8), (par8b), (par9), (par9b)])
235271
def test_distributed_maskednorm(par):
236272
"""Test Distributed numpy.linalg.norm method with masked array"""
237-
# number of subcommunicators
273+
# Number of subcommunicators
238274
if MPI.COMM_WORLD.Get_size() % 2 == 0:
239275
nsub = 2
240276
elif MPI.COMM_WORLD.Get_size() % 3 == 0:

0 commit comments

Comments
 (0)