@@ -156,16 +156,40 @@ def test_getitem_mask(shape, data):
156
156
)
157
157
158
158
159
- @given (hh .shapes (min_side = 1 ), st .data ())
159
+ @given (hh .shapes (), st .data ())
160
160
def test_setitem_mask (shape , data ):
161
161
x = data .draw (xps .arrays (xps .scalar_dtypes (), shape = shape ), label = "x" )
162
162
key = data .draw (xps .arrays (dtype = xp .bool , shape = shape ), label = "key" )
163
- value = data .draw (xps .from_dtype (x .dtype ), label = "value" ) # TODO: more values
163
+ value = data .draw (
164
+ xps .from_dtype (x .dtype ) | xps .arrays (dtype = x .dtype , shape = ()), label = "value"
165
+ )
164
166
165
167
res = xp .asarray (x , copy = True )
166
168
res [key ] = value
167
169
168
- # TODO
170
+ ph .assert_dtype ("__setitem__" , x .dtype , res .dtype , repr_name = "x.dtype" )
171
+ ph .assert_shape ("__setitem__" , res .shape , x .shape , repr_name = "x.dtype" )
172
+
173
+ scalar_type = dh .get_scalar_type (x .dtype )
174
+ for idx in sh .ndindex (x .shape ):
175
+ if key [idx ]:
176
+ if isinstance (value , Scalar ):
177
+ ph .assert_scalar_equals (
178
+ "__setitem__" ,
179
+ scalar_type ,
180
+ idx ,
181
+ scalar_type (res [idx ]),
182
+ value ,
183
+ repr_name = "modified x" ,
184
+ )
185
+ else :
186
+ ph .assert_0d_equals (
187
+ "__setitem__" , "value" , value , f"modified x[{ idx } ]" , res [idx ]
188
+ )
189
+ else :
190
+ ph .assert_0d_equals (
191
+ "__setitem__" , f"old x[{ idx } ]" , x [idx ], f"modified x[{ idx } ]" , res [idx ]
192
+ )
169
193
170
194
171
195
def make_param (method_name : str , dtype : DataType , stype : ScalarType ) -> Param :
0 commit comments