@@ -139,10 +139,12 @@ def __setitem__(self, index, value):
             Represents the value that will be assigned to the local array at
             the specified index positions.
         """
-        if self.partition is Partition.BROADCAST:
-            self.local_array[index] = self.base_comm.bcast(value)
-        else:
-            self.local_array[index] = value
+        # if self.partition is Partition.BROADCAST:
+        #     self.local_array[index] = self.base_comm.bcast(value)
+        # else:
+        #     self.local_array[index] = value
+        # testing this... avoid broadcasting and just let the user store the same value in each rank
+        self.local_array[index] = value

     @property
     def global_shape(self):
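With broadcasting removed from __setitem__, an assignment no longer triggers any MPI communication: every rank simply writes the value into its own local_array. A minimal usage sketch of what this implies (the constructor call below is an assumption based on the signature visible in the zeros_like hunk further down; run under mpirun):

import numpy as np
from pylops_mpi import DistributedArray

# Each rank owns a slice of the 128-element global array.
arr = DistributedArray(global_shape=128, dtype=np.float64)

# After this change, the assignment runs independently on every rank;
# if all ranks are meant to hold the same value, each rank must execute
# the assignment itself (no bcast is issued on their behalf).
arr[:] = 1.0
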
@@ -527,6 +529,19 @@ def _compute_vector_norm(self, local_array: NDArray,
             recv_buf = np.power(recv_buf, 1. / ord)
         return recv_buf

+    def zeros_like(self):
+        """Creates a copy of the DistributedArray filled with zeros
+        """
+        arr = DistributedArray(global_shape=self.global_shape,
+                               base_comm=self.base_comm,
+                               partition=self.partition,
+                               axis=self.axis,
+                               local_shapes=self.local_shapes,
+                               engine=self.engine,
+                               dtype=self.dtype)
+        arr[:] = 0.
+        return arr
+
     def norm(self, ord: Optional[int] = None,
              axis: int = -1):
         """Distributed numpy.linalg.norm method