@@ -108,8 +108,9 @@ def __init__(
108108
109109 def _matvec (self , x : DistributedArray ) -> DistributedArray :
110110 ncp = get_module (x .engine )
111- if x .partition is not Partition .BROADCAST :
112- raise ValueError (f"x should have partition={ Partition .BROADCAST } , { x .partition } != { Partition .BROADCAST } " )
111+ if x .partition not in [Partition .BROADCAST , Partition .UNSAFE_BROADCAST ]:
112+ raise ValueError (f"x should have partition={ Partition .BROADCAST } ,{ Partition .UNSAFE_BROADCAST } "
113+ f"Got { x .partition } instead..." )
113114 y = DistributedArray (global_shape = self .shape [0 ], partition = Partition .BROADCAST ,
114115 engine = x .engine , dtype = self .dtype )
115116 x = x .local_array .reshape (self .dims ).squeeze ()
@@ -129,8 +130,9 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
129130
130131 def _rmatvec (self , x : NDArray ) -> NDArray :
131132 ncp = get_module (x .engine )
132- if x .partition is not Partition .BROADCAST :
133- raise ValueError (f"x should have partition={ Partition .BROADCAST } , { x .partition } != { Partition .BROADCAST } " )
133+ if x .partition not in [Partition .BROADCAST , Partition .UNSAFE_BROADCAST ]:
134+ raise ValueError (f"x should have partition={ Partition .BROADCAST } ,{ Partition .UNSAFE_BROADCAST } "
135+ f"Got { x .partition } instead..." )
134136 y = DistributedArray (global_shape = self .shape [1 ], partition = Partition .BROADCAST ,
135137 engine = x .engine , dtype = self .dtype )
136138 x = x .local_array .reshape (self .dimsd ).squeeze ()
0 commit comments