@@ -112,20 +112,20 @@ def _register_multiplications(
112112 def _matvec (self , x : DistributedArray ) -> DistributedArray :
113113 # If Partition.BROADCAST, then convert to Partition.SCATTER
114114 if x .partition is Partition .BROADCAST :
115- x = DistributedArray .to_dist (x = x .local_array )
115+ x = DistributedArray .to_dist (x = x .local_array , base_comm_nccl = x . base_comm_nccl )
116116 return self ._hmatvec (x )
117117
118118 def _rmatvec (self , x : DistributedArray ) -> DistributedArray :
119119 # If Partition.BROADCAST, then convert to Partition.SCATTER
120120 if x .partition is Partition .BROADCAST :
121- x = DistributedArray .to_dist (x = x .local_array )
121+ x = DistributedArray .to_dist (x = x .local_array , base_comm_nccl = x . base_comm_nccl )
122122 return self ._hrmatvec (x )
123123
124124 @reshaped
125125 def _matvec_forward (self , x : DistributedArray ) -> DistributedArray :
126126 ncp = get_module (x .engine )
127- y = DistributedArray (global_shape = x .global_shape , local_shapes = x .local_shapes ,
128- axis = x .axis , engine = x .engine , dtype = self .dtype )
127+ y = DistributedArray (global_shape = x .global_shape , base_comm_nccl = x .base_comm_nccl ,
128+ local_shapes = x . local_shapes , axis = x .axis , engine = x .engine , dtype = self .dtype )
129129 ghosted_x = x .add_ghost_cells (cells_back = 2 )
130130 y_forward = ghosted_x [2 :] - 2 * ghosted_x [1 :- 1 ] + ghosted_x [:- 2 ]
131131 if self .rank == self .size - 1 :
@@ -136,7 +136,8 @@ def _matvec_forward(self, x: DistributedArray) -> DistributedArray:
136136 @reshaped
137137 def _rmatvec_forward (self , x : DistributedArray ) -> DistributedArray :
138138 ncp = get_module (x .engine )
139- y = DistributedArray (global_shape = x .global_shape , local_shapes = x .local_shapes , axis = x .axis , dtype = self .dtype )
139+ y = DistributedArray (global_shape = x .global_shape , base_comm_nccl = x .base_comm_nccl ,
140+ local_shapes = x .local_shapes , axis = x .axis , engine = x .engine , dtype = self .dtype )
140141 y [:] = 0
141142 if self .rank == self .size - 1 :
142143 y [:- 2 ] += x [:- 2 ]
@@ -162,8 +163,8 @@ def _rmatvec_forward(self, x: DistributedArray) -> DistributedArray:
162163 @reshaped
163164 def _matvec_backward (self , x : DistributedArray ) -> DistributedArray :
164165 ncp = get_module (x .engine )
165- y = DistributedArray (global_shape = x .global_shape , local_shapes = x .local_shapes ,
166- axis = x .axis , engine = x .engine , dtype = self .dtype )
166+ y = DistributedArray (global_shape = x .global_shape , base_comm_nccl = x .base_comm_nccl ,
167+ local_shapes = x . local_shapes , axis = x .axis , engine = x .engine , dtype = self .dtype )
167168 ghosted_x = x .add_ghost_cells (cells_front = 2 )
168169 y_backward = ghosted_x [2 :] - 2 * ghosted_x [1 :- 1 ] + ghosted_x [:- 2 ]
169170 if self .rank == 0 :
@@ -174,8 +175,8 @@ def _matvec_backward(self, x: DistributedArray) -> DistributedArray:
174175 @reshaped
175176 def _rmatvec_backward (self , x : DistributedArray ) -> DistributedArray :
176177 ncp = get_module (x .engine )
177- y = DistributedArray (global_shape = x .global_shape , local_shapes = x .local_shapes ,
178- axis = x .axis , engine = x .engine , dtype = self .dtype )
178+ y = DistributedArray (global_shape = x .global_shape , base_comm_nccl = x .base_comm_nccl ,
179+ local_shapes = x . local_shapes , axis = x .axis , engine = x .engine , dtype = self .dtype )
179180 y [:] = 0
180181 ghosted_x = x .add_ghost_cells (cells_back = 2 )
181182 y_backward = ghosted_x [2 :]
@@ -201,8 +202,8 @@ def _rmatvec_backward(self, x: DistributedArray) -> DistributedArray:
201202 @reshaped
202203 def _matvec_centered (self , x : DistributedArray ) -> DistributedArray :
203204 ncp = get_module (x .engine )
204- y = DistributedArray (global_shape = x .global_shape , local_shapes = x .local_shapes ,
205- axis = x .axis , engine = x .engine , dtype = self .dtype )
205+ y = DistributedArray (global_shape = x .global_shape , base_comm_nccl = x .base_comm_nccl ,
206+ local_shapes = x . local_shapes , axis = x .axis , engine = x .engine , dtype = self .dtype )
206207 ghosted_x = x .add_ghost_cells (cells_front = 1 , cells_back = 1 )
207208 y_centered = ghosted_x [2 :] - 2 * ghosted_x [1 :- 1 ] + ghosted_x [:- 2 ]
208209 if self .rank == 0 :
@@ -221,8 +222,8 @@ def _matvec_centered(self, x: DistributedArray) -> DistributedArray:
221222 @reshaped
222223 def _rmatvec_centered (self , x : DistributedArray ) -> DistributedArray :
223224 ncp = get_module (x .engine )
224- y = DistributedArray (global_shape = x .global_shape , local_shapes = x .local_shapes ,
225- axis = x .axis , engine = x .engine , dtype = self .dtype )
225+ y = DistributedArray (global_shape = x .global_shape , base_comm_nccl = x .base_comm_nccl ,
226+ local_shapes = x . local_shapes , axis = x .axis , engine = x .engine , dtype = self .dtype )
226227 y [:] = 0
227228 ghosted_x = x .add_ghost_cells (cells_back = 2 )
228229 y_centered = ghosted_x [1 :- 1 ]
0 commit comments