@@ -111,7 +111,7 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
111111 if x .partition not in [Partition .BROADCAST , Partition .UNSAFE_BROADCAST ]:
112112 raise ValueError (f"x should have partition={ Partition .BROADCAST } ,{ Partition .UNSAFE_BROADCAST } "
113113 f"Got { x .partition } instead..." )
114- y = DistributedArray (global_shape = self .shape [0 ], partition = Partition . BROADCAST ,
114+ y = DistributedArray (global_shape = self .shape [0 ], partition = x . partition ,
115115 engine = x .engine , dtype = self .dtype )
116116 x = x .local_array .reshape (self .dims ).squeeze ()
117117 x = x [self .islstart [self .rank ]:self .islend [self .rank ]]
@@ -133,7 +133,7 @@ def _rmatvec(self, x: NDArray) -> NDArray:
133133 if x .partition not in [Partition .BROADCAST , Partition .UNSAFE_BROADCAST ]:
134134 raise ValueError (f"x should have partition={ Partition .BROADCAST } ,{ Partition .UNSAFE_BROADCAST } "
135135 f"Got { x .partition } instead..." )
136- y = DistributedArray (global_shape = self .shape [1 ], partition = Partition . BROADCAST ,
136+ y = DistributedArray (global_shape = self .shape [1 ], partition = x . partition ,
137137 engine = x .engine , dtype = self .dtype )
138138 x = x .local_array .reshape (self .dimsd ).squeeze ()
139139 x = x [self .islstart [self .rank ]:self .islend [self .rank ]]
0 commit comments