@@ -72,9 +72,15 @@ def assert_copy(array: Array, copy: bool | None) -> Generator[None, None, None]:
7272 array_orig = xp .asarray (array , copy = True )
7373 yield
7474
75- if copy is None :
76- copy = not is_writeable_array (array )
77- xp_assert_equal (xp .all (array == array_orig ), xp .asarray (copy ))
75+ if copy is True :
76+ # Original has not been modified
77+ xp_assert_equal (array , array_orig )
78+ elif copy is False :
79+ # Original has been modified
80+ with pytest .raises (AssertionError ):
81+ xp_assert_equal (array , array_orig )
82+ # Test nothing for copy=None. Dask changes behaviour depending on
83+ # whether it's a special case of a bool mask with scalar RHS or not.
7884
7985
8086@pytest .mark .parametrize (
@@ -89,7 +95,7 @@ def assert_copy(array: Array, copy: bool | None) -> Generator[None, None, None]:
8995 ],
9096)
9197@pytest .mark .parametrize (
92- ("op" , "y" , "expect " ),
98+ ("op" , "y" , "expect_list " ),
9399 [
94100 (_AtOp .SET , 40.0 , [10.0 , 40.0 , 40.0 ]),
95101 (_AtOp .ADD , 40.0 , [10.0 , 60.0 , 70.0 ]),
@@ -102,14 +108,13 @@ def assert_copy(array: Array, copy: bool | None) -> Generator[None, None, None]:
102108 ],
103109)
104110@pytest .mark .parametrize (
105- ("bool_mask" , "shaped_y " ),
111+ ("bool_mask" , "x_ndim" , "y_ndim " ),
106112 [
107- (False , False ),
108- (False , True ),
109- (True , False ), # Uses xp.where(idx, y, x) on JAX and Dask
113+ (False , 1 , 0 ),
114+ (False , 1 , 1 ),
115+ (True , 1 , 0 ), # Uses xp.where(idx, y, x) on JAX and Dask
110116 pytest .param (
111- True ,
112- True ,
117+ * (True , 1 , 1 ),
113118 marks = (
114119 pytest .mark .skip_xp_backend ( # test passes when copy=False
115120 Backend .JAX , reason = "bool mask update with shaped rhs"
@@ -119,6 +124,8 @@ def assert_copy(array: Array, copy: bool | None) -> Generator[None, None, None]:
119124 ),
120125 ),
121126 ),
127+ (False , 0 , 0 ),
128+ (True , 0 , 0 ),
122129 ],
123130)
124131def test_update_ops (
@@ -127,13 +134,26 @@ def test_update_ops(
127134 expect_copy : bool | None ,
128135 op : _AtOp ,
129136 y : float ,
130- expect : list [float ],
137+ expect_list : list [float ],
131138 bool_mask : bool ,
132- shaped_y : bool ,
139+ x_ndim : int ,
140+ y_ndim : int ,
133141):
134- x = xp .asarray ([10.0 , 20.0 , 30.0 ])
135- idx = xp .asarray ([False , True , True ]) if bool_mask else slice (1 , None )
136- if shaped_y :
142+ if x_ndim == 1 :
143+ x = xp .asarray ([10.0 , 20.0 , 30.0 ])
144+ idx = xp .asarray ([False , True , True ]) if bool_mask else slice (1 , None )
145+ expect : list [float ] | float = expect_list
146+ else :
147+ idx = xp .asarray (True ) if bool_mask else ()
148+ # Pick an element that does change with the operation
149+ if op is _AtOp .MIN :
150+ x = xp .asarray (30.0 )
151+ expect = expect_list [2 ]
152+ else :
153+ x = xp .asarray (20.0 )
154+ expect = expect_list [1 ]
155+
156+ if y_ndim == 1 :
137157 y = xp .asarray ([y , y ])
138158
139159 with assert_copy (x , expect_copy ):
@@ -259,3 +279,56 @@ def test_no_inf_warnings(xp: ModuleType, bool_mask: bool):
259279 # inf - inf -> nan with a warning
260280 z = at_op (x , idx , _AtOp .SUBTRACT , math .inf )
261281 xp_assert_equal (z , xp .asarray ([math .inf , - math .inf , - math .inf ]))
282+
283+
284+ @pytest .mark .parametrize (
285+ "copy" ,
286+ [
287+ None ,
288+ pytest .param (
289+ False ,
290+ marks = [
291+ pytest .mark .skip_xp_backend (
292+ Backend .NUMPY , reason = "np.generic is read-only"
293+ ),
294+ pytest .mark .skip_xp_backend (
295+ Backend .NUMPY_READONLY , reason = "read-only backend"
296+ ),
297+ pytest .mark .skip_xp_backend (Backend .JAX , reason = "read-only backend" ),
298+ pytest .mark .skip_xp_backend (Backend .SPARSE , reason = "read-only backend" ),
299+ pytest .mark .xfail_xp_backend (Backend .DASK , reason = "dask/dask#11722" ),
300+ ],
301+ ),
302+ ],
303+ )
304+ @pytest .mark .parametrize (
305+ "bool_mask" ,
306+ [
307+ pytest .param (
308+ False ,
309+ marks = pytest .mark .xfail_xp_backend (Backend .DASK , reason = "dask/dask#11722" ),
310+ ),
311+ True ,
312+ ],
313+ )
314+ def test_gh134 (xp : ModuleType , bool_mask : bool , copy : bool | None ):
315+ """
316+ Test that xpx.at doesn't encroach in a bug of dask.array.Array.__setitem__, which
317+ blindly assumes that chunk contents are writeable np.ndarray objects:
318+
319+ https://github.com/dask/dask/issues/11722
320+
321+ In other words: when special-casing bool masks for Dask, unless the user explicitly
322+ asks for copy=False, do not needlessly write back to the input.
323+ """
324+ x = xp .zeros (1 )
325+
326+ # In numpy, we have a writeable np.ndarray in input and a read-only np.generic in
327+ # output. As both are Arrays, this behaviour is Array API compliant.
328+ # In Dask, we have a writeable da.Array on both sides, and if you call __setitem__
329+ # on it all seems fine, but when you compute() your graph is corrupted.
330+ y = x [0 ]
331+
332+ idx = xp .asarray (True ) if bool_mask else ()
333+ z = at_op (y , idx , _AtOp .SET , 1 , copy = copy )
334+ xp_assert_equal (z , xp .asarray (1 , dtype = x .dtype ))
0 commit comments