@@ -14,10 +14,14 @@ class Partition(Enum):
1414
1515 Distributing data among different processes.
1616
17- - ``BROADCAST``: Distributes data to all processes.
17+ - ``BROADCAST``: Distributes data to all processes
18+ (ensuring that data is kept consistent across processes)
19+ - ``UNSAFE_BROADCAST``: Distributes data to all processes
20+ (without ensuring that data is kept consistent across processes)
1821 - ``SCATTER``: Distributes unique portions to each process.
1922 """
2023 BROADCAST = "Broadcast"
24+ UNSAFE_BROADCAST = "UnsafeBroadcast"
2125 SCATTER = "Scatter"
2226
2327
@@ -41,7 +45,7 @@ def local_split(global_shape: Tuple, base_comm: MPI.Comm,
4145 local_shape : :obj:`tuple`
4246 Shape of the local array.
4347 """
44- if partition == Partition .BROADCAST :
48+ if partition in [ Partition .BROADCAST , Partition . UNSAFE_BROADCAST ] :
4549 local_shape = global_shape
4650 # Split the array
4751 else :
@@ -75,7 +79,7 @@ class DistributedArray:
7579 MPI Communicator over which array is distributed.
7680 Defaults to ``mpi4py.MPI.COMM_WORLD``.
7781 partition : :obj:`Partition`, optional
78- Broadcast or Scatter the array. Defaults to ``Partition.SCATTER``.
82+ Broadcast, UnsafeBroadcast, or Scatter the array. Defaults to ``Partition.SCATTER``.
7983 axis : :obj:`int`, optional
8084 Axis along which distribution occurs. Defaults to ``0``.
8185 local_shapes : :obj:`list`, optional
@@ -102,8 +106,8 @@ def __init__(self, global_shape: Union[Tuple, Integral],
102106 raise IndexError (f"Axis { axis } out of range for DistributedArray "
103107 f"of shape { global_shape } " )
104108 if partition not in Partition :
105- raise ValueError (f"Should be either { Partition .BROADCAST } "
106- f"or { Partition .SCATTER } " )
109+ raise ValueError (f"Should be either { Partition .BROADCAST } , "
110+ f"{ Partition . UNSAFE_BROADCAST } or { Partition .SCATTER } " )
107111 self .dtype = dtype
108112 self ._global_shape = _value_or_sized_to_tuple (global_shape )
109113 self ._base_comm = base_comm
@@ -128,6 +132,9 @@ def __setitem__(self, index, value):
128132 `Partition.SCATTER` - Local Arrays are assigned their
129133 unique values.
130134
135+ `Partition.UNSAFE_SCATTER` - Local Arrays are assigned their
136+ unique values.
137+
131138 `Partition.BROADCAST` - The value at rank-0 is broadcasted
132139 and is assigned to all the ranks.
133140
@@ -139,12 +146,10 @@ def __setitem__(self, index, value):
139146 Represents the value that will be assigned to the local array at
140147 the specified index positions.
141148 """
142- # if self.partition is Partition.BROADCAST:
143- # self.local_array[index] = self.base_comm.bcast(value)
144- # else:
145- # self.local_array[index] = value
146- # testing this... avoid broadcasting and just let the user store the same value in each rank
147- self .local_array [index ] = value
149+ if self .partition is Partition .BROADCAST :
150+ self .local_array [index ] = self .base_comm .bcast (value )
151+ else :
152+ self .local_array [index ] = value
148153
149154 @property
150155 def global_shape (self ):
@@ -288,7 +293,7 @@ def asarray(self):
288293 Global Array gathered at all ranks
289294 """
290295 # Since the global array was replicated at all ranks
291- if self .partition == Partition .BROADCAST :
296+ if self .partition in [ Partition .BROADCAST , Partition . UNSAFE_BROADCAST ] :
292297 # Get only self.local_array.
293298 return self .local_array
294299 # Gather all the local arrays and apply concatenation.
@@ -333,7 +338,7 @@ def to_dist(cls, x: NDArray,
333338 mask = mask ,
334339 engine = get_module_name (get_array_module (x )),
335340 dtype = x .dtype )
336- if partition == Partition .BROADCAST :
341+ if partition in [ Partition .BROADCAST , Partition . UNSAFE_BROADCAST ] :
337342 dist_array [:] = x
338343 else :
339344 slices = [slice (None )] * x .ndim
@@ -352,7 +357,7 @@ def _check_local_shapes(self, local_shapes):
352357 raise ValueError (f"Length of local shapes is not equal to number of processes; "
353358 f"{ len (local_shapes )} != { self .size } " )
354359 # Check if local shape == global shape
355- if self .partition is Partition .BROADCAST and local_shapes [self .rank ] != self .global_shape :
360+ if self .partition in [ Partition .BROADCAST , Partition . UNSAFE_BROADCAST ] and local_shapes [self .rank ] != self .global_shape :
356361 raise ValueError (f"Local shape is not equal to global shape at rank = { self .rank } ;"
357362 f"{ local_shapes [self .rank ]} != { self .global_shape } " )
358363 elif self .partition is Partition .SCATTER :
@@ -481,9 +486,9 @@ def dot(self, dist_array):
481486
482487 # Convert to Partition.SCATTER if Partition.BROADCAST
483488 x = DistributedArray .to_dist (x = self .local_array ) \
484- if self .partition is Partition .BROADCAST else self
489+ if self .partition in [ Partition .BROADCAST , Partition . UNSAFE_BROADCAST ] else self
485490 y = DistributedArray .to_dist (x = dist_array .local_array ) \
486- if self .partition is Partition .BROADCAST else dist_array
491+ if self .partition in [ Partition .BROADCAST , Partition . UNSAFE_BROADCAST ] else dist_array
487492 # Flatten the local arrays and calculate dot product
488493 return self ._allreduce_subcomm (np .dot (x .local_array .flatten (), y .local_array .flatten ()))
489494
@@ -555,7 +560,7 @@ def norm(self, ord: Optional[int] = None,
555560 """
556561 # Convert to Partition.SCATTER if Partition.BROADCAST
557562 x = DistributedArray .to_dist (x = self .local_array ) \
558- if self .partition is Partition .BROADCAST else self
563+ if self .partition in [ Partition .BROADCAST , Partition . UNSAFE_BROADCAST ] else self
559564 if axis == - 1 :
560565 # Flatten the local arrays and calculate norm
561566 return x ._compute_vector_norm (x .local_array .flatten (), axis = 0 , ord = ord )
0 commit comments