@@ -129,19 +129,19 @@ def _register_multiplications(
129129    def  _matvec (self , x : DistributedArray ) ->  DistributedArray :
130130        # If Partition.BROADCAST, then convert to Partition.SCATTER 
131131        if  x .partition  is  Partition .BROADCAST :
132-             x  =  DistributedArray .to_dist (x = x .local_array , base_comm_nccl = x .base_comm_nccl )
132+             x  =  DistributedArray .to_dist (x = x .local_array , base_comm = x . base_comm ,  base_comm_nccl = x .base_comm_nccl )
133133        return  self ._hmatvec (x )
134134
135135    def  _rmatvec (self , x : DistributedArray ) ->  DistributedArray :
136136        # If Partition.BROADCAST, then convert to Partition.SCATTER 
137137        if  x .partition  is  Partition .BROADCAST :
138-             x  =  DistributedArray .to_dist (x = x .local_array , base_comm_nccl = x .base_comm_nccl )
138+             x  =  DistributedArray .to_dist (x = x .local_array , base_comm = x . base_comm ,  base_comm_nccl = x .base_comm_nccl )
139139        return  self ._hrmatvec (x )
140140
141141    @reshaped  
142142    def  _matvec_forward (self , x : DistributedArray ) ->  DistributedArray :
143143        ncp  =  get_module (x .engine )
144-         y  =  DistributedArray (global_shape = x .global_shape , base_comm_nccl = x .base_comm_nccl , local_shapes = x .local_shapes ,
144+         y  =  DistributedArray (global_shape = x .global_shape , base_comm = x . base_comm ,  base_comm_nccl = x .base_comm_nccl , local_shapes = x .local_shapes ,
145145                             axis = x .axis , engine = x .engine , dtype = self .dtype )
146146        ghosted_x  =  x .add_ghost_cells (cells_back = 1 )
147147        y_forward  =  ghosted_x [1 :] -  ghosted_x [:- 1 ]
@@ -153,7 +153,7 @@ def _matvec_forward(self, x: DistributedArray) -> DistributedArray:
153153    @reshaped  
154154    def  _rmatvec_forward (self , x : DistributedArray ) ->  DistributedArray :
155155        ncp  =  get_module (x .engine )
156-         y  =  DistributedArray (global_shape = x .global_shape , base_comm_nccl = x .base_comm_nccl , local_shapes = x .local_shapes ,
156+         y  =  DistributedArray (global_shape = x .global_shape , base_comm = x . base_comm ,  base_comm_nccl = x .base_comm_nccl , local_shapes = x .local_shapes ,
157157                             axis = x .axis , engine = x .engine , dtype = self .dtype )
158158        y [:] =  0 
159159        if  self .rank  ==  self .size  -  1 :
@@ -171,7 +171,7 @@ def _rmatvec_forward(self, x: DistributedArray) -> DistributedArray:
171171    @reshaped  
172172    def  _matvec_backward (self , x : DistributedArray ) ->  DistributedArray :
173173        ncp  =  get_module (x .engine )
174-         y  =  DistributedArray (global_shape = x .global_shape , base_comm_nccl = x .base_comm_nccl , local_shapes = x .local_shapes ,
174+         y  =  DistributedArray (global_shape = x .global_shape , base_comm = x . base_comm ,  base_comm_nccl = x .base_comm_nccl , local_shapes = x .local_shapes ,
175175                             axis = x .axis , engine = x .engine , dtype = self .dtype )
176176        ghosted_x  =  x .add_ghost_cells (cells_front = 1 )
177177        y_backward  =  ghosted_x [1 :] -  ghosted_x [:- 1 ]
@@ -183,7 +183,7 @@ def _matvec_backward(self, x: DistributedArray) -> DistributedArray:
183183    @reshaped  
184184    def  _rmatvec_backward (self , x : DistributedArray ) ->  DistributedArray :
185185        ncp  =  get_module (x .engine )
186-         y  =  DistributedArray (global_shape = x .global_shape , base_comm_nccl = x .base_comm_nccl , local_shapes = x .local_shapes ,
186+         y  =  DistributedArray (global_shape = x .global_shape , base_comm = x . base_comm ,  base_comm_nccl = x .base_comm_nccl , local_shapes = x .local_shapes ,
187187                             axis = x .axis , engine = x .engine , dtype = self .dtype )
188188        y [:] =  0 
189189        ghosted_x  =  x .add_ghost_cells (cells_back = 1 )
@@ -201,7 +201,7 @@ def _rmatvec_backward(self, x: DistributedArray) -> DistributedArray:
201201    @reshaped  
202202    def  _matvec_centered3 (self , x : DistributedArray ) ->  DistributedArray :
203203        ncp  =  get_module (x .engine )
204-         y  =  DistributedArray (global_shape = x .global_shape , base_comm_nccl = x .base_comm_nccl , local_shapes = x .local_shapes ,
204+         y  =  DistributedArray (global_shape = x .global_shape , base_comm = x . base_comm ,  base_comm_nccl = x .base_comm_nccl , local_shapes = x .local_shapes ,
205205                             axis = x .axis , engine = x .engine , dtype = self .dtype )
206206        ghosted_x  =  x .add_ghost_cells (cells_front = 1 , cells_back = 1 )
207207        y_centered  =  0.5  *  (ghosted_x [2 :] -  ghosted_x [:- 2 ])
@@ -221,7 +221,7 @@ def _matvec_centered3(self, x: DistributedArray) -> DistributedArray:
221221    @reshaped  
222222    def  _rmatvec_centered3 (self , x : DistributedArray ) ->  DistributedArray :
223223        ncp  =  get_module (x .engine )
224-         y  =  DistributedArray (global_shape = x .global_shape , base_comm_nccl = x .base_comm_nccl , local_shapes = x .local_shapes ,
224+         y  =  DistributedArray (global_shape = x .global_shape , base_comm = x . base_comm ,  base_comm_nccl = x .base_comm_nccl , local_shapes = x .local_shapes ,
225225                             axis = x .axis , engine = x .engine , dtype = self .dtype )
226226        y [:] =  0 
227227
@@ -249,7 +249,7 @@ def _rmatvec_centered3(self, x: DistributedArray) -> DistributedArray:
249249    @reshaped  
250250    def  _matvec_centered5 (self , x : DistributedArray ) ->  DistributedArray :
251251        ncp  =  get_module (x .engine )
252-         y  =  DistributedArray (global_shape = x .global_shape , base_comm_nccl = x .base_comm_nccl , local_shapes = x .local_shapes ,
252+         y  =  DistributedArray (global_shape = x .global_shape , base_comm = x . base_comm ,  base_comm_nccl = x .base_comm_nccl , local_shapes = x .local_shapes ,
253253                             axis = x .axis , engine = x .engine , dtype = self .dtype )
254254        ghosted_x  =  x .add_ghost_cells (cells_front = 2 , cells_back = 2 )
255255        y_centered  =  (
@@ -276,7 +276,7 @@ def _matvec_centered5(self, x: DistributedArray) -> DistributedArray:
276276    @reshaped  
277277    def  _rmatvec_centered5 (self , x : DistributedArray ) ->  DistributedArray :
278278        ncp  =  get_module (x .engine )
279-         y  =  DistributedArray (global_shape = x .global_shape , base_comm_nccl = x .base_comm_nccl , local_shapes = x .local_shapes ,
279+         y  =  DistributedArray (global_shape = x .global_shape , base_comm = x . base_comm ,  base_comm_nccl = x .base_comm_nccl , local_shapes = x .local_shapes ,
280280                             axis = x .axis , engine = x .engine , dtype = self .dtype )
281281        y [:] =  0 
282282        ghosted_x  =  x .add_ghost_cells (cells_back = 4 )
0 commit comments