
Commit 05ff24e

feat: added proposal for UNSAFE_BROADCAST partition
1 parent f6dd7a9 commit 05ff24e

File tree

1 file changed (+22 -17)


pylops_mpi/DistributedArray.py

Lines changed: 22 additions & 17 deletions
@@ -14,10 +14,14 @@ class Partition(Enum):
 
     Distributing data among different processes.
 
-    - ``BROADCAST``: Distributes data to all processes.
+    - ``BROADCAST``: Distributes data to all processes
+      (ensuring that data is kept consistent across processes)
+    - ``UNSAFE_BROADCAST``: Distributes data to all processes
+      (without ensuring that data is kept consistent across processes)
     - ``SCATTER``: Distributes unique portions to each process.
     """
     BROADCAST = "Broadcast"
+    UNSAFE_BROADCAST = "UnsafeBroadcast"
     SCATTER = "Scatter"
 
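A minimal sketch, assuming this proposed enum member is available, of selecting each partition when constructing a DistributedArray (run under mpiexec):

from pylops_mpi.DistributedArray import DistributedArray, Partition

# Replicated on every rank; assignments are broadcast from rank 0
x_safe = DistributedArray(global_shape=10, partition=Partition.BROADCAST)
# Replicated on every rank; consistency is left to the user (the proposal here)
x_unsafe = DistributedArray(global_shape=10, partition=Partition.UNSAFE_BROADCAST)
# Each rank owns a unique portion along axis 0 (the default)
x_scatter = DistributedArray(global_shape=10, partition=Partition.SCATTER)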

@@ -41,7 +45,7 @@ def local_split(global_shape: Tuple, base_comm: MPI.Comm,
     local_shape : :obj:`tuple`
         Shape of the local array.
     """
-    if partition == Partition.BROADCAST:
+    if partition in [Partition.BROADCAST, Partition.UNSAFE_BROADCAST]:
         local_shape = global_shape
     # Split the array
     else:
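
A sketch of what this branch means for local_split (the ``partition`` and ``axis`` keyword names are assumed from the surrounding code): both broadcast-style partitions keep the full global shape on every rank, while SCATTER splits the chosen axis.

from mpi4py import MPI
from pylops_mpi.DistributedArray import Partition, local_split

comm = MPI.COMM_WORLD
# Full global shape replicated on every rank
print(local_split(global_shape=(6, 4), base_comm=comm,
                  partition=Partition.UNSAFE_BROADCAST, axis=0))  # (6, 4)
# Unique chunk per rank, e.g. (3, 4) on each of two ranks
print(local_split(global_shape=(6, 4), base_comm=comm,
                  partition=Partition.SCATTER, axis=0))
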
@@ -75,7 +79,7 @@ class DistributedArray:
         MPI Communicator over which array is distributed.
         Defaults to ``mpi4py.MPI.COMM_WORLD``.
     partition : :obj:`Partition`, optional
-        Broadcast or Scatter the array. Defaults to ``Partition.SCATTER``.
+        Broadcast, UnsafeBroadcast, or Scatter the array. Defaults to ``Partition.SCATTER``.
     axis : :obj:`int`, optional
         Axis along which distribution occurs. Defaults to ``0``.
     local_shapes : :obj:`list`, optional
@@ -102,8 +106,8 @@ def __init__(self, global_shape: Union[Tuple, Integral],
             raise IndexError(f"Axis {axis} out of range for DistributedArray "
                              f"of shape {global_shape}")
         if partition not in Partition:
-            raise ValueError(f"Should be either {Partition.BROADCAST} "
-                             f"or {Partition.SCATTER}")
+            raise ValueError(f"Should be either {Partition.BROADCAST}, "
+                             f"{Partition.UNSAFE_BROADCAST} or {Partition.SCATTER}")
         self.dtype = dtype
         self._global_shape = _value_or_sized_to_tuple(global_shape)
         self._base_comm = base_comm
@@ -128,6 +132,9 @@ def __setitem__(self, index, value):
         `Partition.SCATTER` - Local Arrays are assigned their
         unique values.
 
+        `Partition.UNSAFE_BROADCAST` - Local Arrays are assigned their
+        unique values.
+
         `Partition.BROADCAST` - The value at rank-0 is broadcasted
         and is assigned to all the ranks.
 
@@ -139,12 +146,10 @@ def __setitem__(self, index, value):
             Represents the value that will be assigned to the local array at
             the specified index positions.
         """
-        # if self.partition is Partition.BROADCAST:
-        #     self.local_array[index] = self.base_comm.bcast(value)
-        # else:
-        #     self.local_array[index] = value
-        # testing this... avoid broadcasting and just let the user store the same value in each rank
-        self.local_array[index] = value
+        if self.partition is Partition.BROADCAST:
+            self.local_array[index] = self.base_comm.bcast(value)
+        else:
+            self.local_array[index] = value
 
     @property
     def global_shape(self):
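
A sketch of the assignment semantics restored here, assuming this proposal is applied and the script runs under mpiexec: with BROADCAST the rank-0 value wins everywhere, with UNSAFE_BROADCAST each rank keeps whatever it writes.

from mpi4py import MPI
from pylops_mpi.DistributedArray import DistributedArray, Partition

rank = MPI.COMM_WORLD.Get_rank()

safe = DistributedArray(global_shape=4, partition=Partition.BROADCAST)
safe[:] = rank      # bcast from rank 0: every rank ends up holding zeros

unsafe = DistributedArray(global_shape=4, partition=Partition.UNSAFE_BROADCAST)
unsafe[:] = rank    # no bcast: rank r holds an array filled with r
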
@@ -288,7 +293,7 @@ def asarray(self):
             Global Array gathered at all ranks
         """
         # Since the global array was replicated at all ranks
-        if self.partition == Partition.BROADCAST:
+        if self.partition in [Partition.BROADCAST, Partition.UNSAFE_BROADCAST]:
             # Get only self.local_array.
             return self.local_array
         # Gather all the local arrays and apply concatenation.
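
A sketch of what the new branch means for asarray (assumes this proposal): replicated partitions need no gather, since the local copy already is the global array.

import numpy as np
from pylops_mpi.DistributedArray import DistributedArray, Partition

x = DistributedArray.to_dist(x=np.arange(6.0), partition=Partition.UNSAFE_BROADCAST)
assert np.array_equal(x.asarray(), x.local_array)  # returned without any MPI gather
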
@@ -333,7 +338,7 @@ def to_dist(cls, x: NDArray,
                                       mask=mask,
                                       engine=get_module_name(get_array_module(x)),
                                       dtype=x.dtype)
-        if partition == Partition.BROADCAST:
+        if partition in [Partition.BROADCAST, Partition.UNSAFE_BROADCAST]:
             dist_array[:] = x
         else:
             slices = [slice(None)] * x.ndim
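
A sketch of the to_dist path above (assumes this proposal): both broadcast-style partitions take the `dist_array[:] = x` branch, so every rank ends up with a full copy of x.

import numpy as np
from pylops_mpi.DistributedArray import DistributedArray, Partition

x = np.arange(12, dtype=np.float64)
x_bcast = DistributedArray.to_dist(x=x, partition=Partition.BROADCAST)
x_unsafe = DistributedArray.to_dist(x=x, partition=Partition.UNSAFE_BROADCAST)
assert x_unsafe.local_array.shape == x.shape   # full copy on every rank
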
@@ -352,7 +357,7 @@ def _check_local_shapes(self, local_shapes):
             raise ValueError(f"Length of local shapes is not equal to number of processes; "
                              f"{len(local_shapes)} != {self.size}")
         # Check if local shape == global shape
-        if self.partition is Partition.BROADCAST and local_shapes[self.rank] != self.global_shape:
+        if self.partition in [Partition.BROADCAST, Partition.UNSAFE_BROADCAST] and local_shapes[self.rank] != self.global_shape:
             raise ValueError(f"Local shape is not equal to global shape at rank = {self.rank};"
                              f"{local_shapes[self.rank]} != {self.global_shape}")
         elif self.partition is Partition.SCATTER:
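
A sketch of the shape check above (assumes this proposal): with a broadcast-style partition, any user-supplied local_shapes entry must equal the global shape, otherwise the ValueError in this hunk is raised.

from mpi4py import MPI
from pylops_mpi.DistributedArray import DistributedArray, Partition

nproc = MPI.COMM_WORLD.Get_size()
x = DistributedArray(global_shape=(6, 4),
                     partition=Partition.UNSAFE_BROADCAST,
                     local_shapes=[(6, 4)] * nproc)  # replicas match the global shape
# local_shapes=[(3, 4)] * nproc would raise the ValueError shown above
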
@@ -481,9 +486,9 @@ def dot(self, dist_array):
 
         # Convert to Partition.SCATTER if Partition.BROADCAST
         x = DistributedArray.to_dist(x=self.local_array) \
-            if self.partition is Partition.BROADCAST else self
+            if self.partition in [Partition.BROADCAST, Partition.UNSAFE_BROADCAST] else self
         y = DistributedArray.to_dist(x=dist_array.local_array) \
-            if self.partition is Partition.BROADCAST else dist_array
+            if self.partition in [Partition.BROADCAST, Partition.UNSAFE_BROADCAST] else dist_array
         # Flatten the local arrays and calculate dot product
         return self._allreduce_subcomm(np.dot(x.local_array.flatten(), y.local_array.flatten()))
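
A sketch of the dot-product path (assumes this proposal): broadcast-style arrays are re-scattered through to_dist first, so the allreduce does not double-count the replicated entries.

import numpy as np
from pylops_mpi.DistributedArray import DistributedArray, Partition

a = DistributedArray.to_dist(x=np.ones(8), partition=Partition.UNSAFE_BROADCAST)
b = DistributedArray.to_dist(x=np.full(8, 2.0), partition=Partition.UNSAFE_BROADCAST)
print(a.dot(b))  # 16.0 on every rank, independent of the number of processes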

@@ -555,7 +560,7 @@ def norm(self, ord: Optional[int] = None,
         """
         # Convert to Partition.SCATTER if Partition.BROADCAST
         x = DistributedArray.to_dist(x=self.local_array) \
-            if self.partition is Partition.BROADCAST else self
+            if self.partition in [Partition.BROADCAST, Partition.UNSAFE_BROADCAST] else self
         if axis == -1:
             # Flatten the local arrays and calculate norm
             return x._compute_vector_norm(x.local_array.flatten(), axis=0, ord=ord)
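
The same conversion applies to norm (a sketch, assumes this proposal): replication is dropped before the reduction, so the result matches the serial 2-norm.

import numpy as np
from pylops_mpi.DistributedArray import DistributedArray, Partition

v = DistributedArray.to_dist(x=np.arange(4.0), partition=Partition.UNSAFE_BROADCAST)
print(v.norm(ord=2))  # sqrt(0 + 1 + 4 + 9) ≈ 3.742 on every rank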

0 commit comments
