@@ -111,7 +111,7 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
111111 if x .partition not in [Partition .BROADCAST , Partition .UNSAFE_BROADCAST ]:
112112 raise ValueError (f"x should have partition={ Partition .BROADCAST } ,{ Partition .UNSAFE_BROADCAST } "
113113 f"Got { x .partition } instead..." )
114- y = DistributedArray (global_shape = self .shape [0 ],
114+ y = DistributedArray (global_shape = self .shape [0 ],
115115 base_comm = x .base_comm ,
116116 base_comm_nccl = x .base_comm_nccl ,
117117 partition = x .partition ,
@@ -128,11 +128,6 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
128128 for isl in range (self .nsls [self .rank ]):
129129 y1 [isl ] = ncp .dot (self .G [isl ], x [isl ])
130130 # gather results
131- # TODO: _allgather is supposed to be private to DistributedArray
132- # but so far, we do not take base_comm_nccl as an argument to Op.
133- # For consistency, y._allgather has to be called here.
134- # Alternatively, we can also do if-else checking x.base_comm_nccl, but that means
135- # we have to call function from _nccl.py
136131 y [:] = ncp .vstack (y ._allgather (y1 )).ravel ()
137132 return y
138133
@@ -141,7 +136,7 @@ def _rmatvec(self, x: NDArray) -> NDArray:
141136 if x .partition not in [Partition .BROADCAST , Partition .UNSAFE_BROADCAST ]:
142137 raise ValueError (f"x should have partition={ Partition .BROADCAST } ,{ Partition .UNSAFE_BROADCAST } "
143138 f"Got { x .partition } instead..." )
144- y = DistributedArray (global_shape = self .shape [1 ],
139+ y = DistributedArray (global_shape = self .shape [1 ],
145140 base_comm = x .base_comm ,
146141 base_comm_nccl = x .base_comm_nccl ,
147142 partition = x .partition ,
@@ -176,8 +171,8 @@ def _rmatvec(self, x: NDArray) -> NDArray:
176171 if self .usematmul and isinstance (recv , ncp .ndarray ) :
177172 # unrolling
178173 chunk_size = self .ny * self .nz
179- num_partition = (len (recv )+ chunk_size - 1 ) // chunk_size
180- recv = ncp .vstack ([recv [i * chunk_size : (i + 1 ) * chunk_size ].reshape (self .nz , self .ny ).T for i in range (num_partition )])
174+ num_partition = (len (recv ) + chunk_size - 1 ) // chunk_size
175+ recv = ncp .vstack ([recv [i * chunk_size : (i + 1 ) * chunk_size ].reshape (self .nz , self .ny ).T for i in range (num_partition )])
181176 else :
182177 recv = ncp .vstack (recv )
183178 y [:] = recv .ravel ()
0 commit comments