33from collections .abc import Callable , Generator
44from contextlib import contextmanager
55from types import ModuleType
6- from typing import Any , cast
6+ from typing import cast
77
88import numpy as np
99import pytest
2323]
2424
2525
26- def at_op ( # type: ignore[no-any-explicit]
26+ def at_op (
2727 x : Array ,
2828 idx : Index ,
2929 op : _AtOp ,
3030 y : Array | object ,
31- ** kwargs : Any , # Test the default copy=None
31+ copy : bool | None = None ,
32+ xp : ModuleType | None = None ,
3233) -> Array :
3334 """
3435 Wrapper around at(x, idx).op(y, copy=copy, xp=xp).
@@ -39,30 +40,33 @@ def at_op( # type: ignore[no-any-explicit]
3940 which is not a common use case.
4041 """
4142 if isinstance (idx , (slice | tuple )):
42- return _at_op (x , None , pickle .dumps (idx ), op , y , ** kwargs )
43- return _at_op (x , idx , None , op , y , ** kwargs )
43+ return _at_op (x , None , pickle .dumps (idx ), op , y , copy = copy , xp = xp )
44+ return _at_op (x , idx , None , op , y , copy = copy , xp = xp )
4445
4546
46- def _at_op ( # type: ignore[no-any-explicit]
47+ def _at_op (
4748 x : Array ,
4849 idx : Index | None ,
4950 idx_pickle : bytes | None ,
5051 op : _AtOp ,
5152 y : Array | object ,
52- ** kwargs : Any ,
53+ copy : bool | None ,
54+ xp : ModuleType | None = None ,
5355) -> Array :
5456 """jitted helper of at_op"""
5557 if idx_pickle :
5658 idx = pickle .loads (idx_pickle )
5759 meth = cast (Callable [..., Array ], getattr (at (x , idx ), op .value )) # type: ignore[no-any-explicit]
58- return meth (y , ** kwargs )
60+ return meth (y , copy = copy , xp = xp )
5961
6062
6163lazy_xp_function (_at_op , static_argnames = ("op" , "idx_pickle" , "copy" , "xp" ))
6264
6365
6466@contextmanager
65- def assert_copy (array : Array , copy : bool | None ) -> Generator [None , None , None ]:
67+ def assert_copy (
68+ array : Array , copy : bool | None , expect_copy : bool | None = None
69+ ) -> Generator [None , None , None ]:
6670 if copy is False and not is_writeable_array (array ):
6771 with pytest .raises ((TypeError , ValueError )):
6872 yield
@@ -72,24 +76,23 @@ def assert_copy(array: Array, copy: bool | None) -> Generator[None, None, None]:
7276 array_orig = xp .asarray (array , copy = True )
7377 yield
7478
75- if copy is None :
76- copy = not is_writeable_array (array )
77- xp_assert_equal (xp .all (array == array_orig ), xp .asarray (copy ))
79+ if expect_copy is None :
80+ expect_copy = copy
7881
82+ if expect_copy :
83+ # Original has not been modified
84+ xp_assert_equal (array , array_orig )
85+ elif expect_copy is False :
86+ # Original has been modified
87+ with pytest .raises (AssertionError ):
88+ xp_assert_equal (array , array_orig )
89+ # Test nothing for copy=None. Dask changes behaviour depending on
90+ # whether it's a special case of a bool mask with scalar RHS or not.
7991
92+
93+ @pytest .mark .parametrize ("copy" , [False , True , None ])
8094@pytest .mark .parametrize (
81- ("kwargs" , "expect_copy" ),
82- [
83- pytest .param ({"copy" : True }, True , id = "copy=True" ),
84- pytest .param ({"copy" : False }, False , id = "copy=False" ),
85- # Behavior is backend-specific
86- pytest .param ({"copy" : None }, None , id = "copy=None" ),
87- # Test that the copy parameter defaults to None
88- pytest .param ({}, None , id = "no copy kwarg" ),
89- ],
90- )
91- @pytest .mark .parametrize (
92- ("op" , "y" , "expect" ),
95+ ("op" , "y" , "expect_list" ),
9396 [
9497 (_AtOp .SET , 40.0 , [10.0 , 40.0 , 40.0 ]),
9598 (_AtOp .ADD , 40.0 , [10.0 , 60.0 , 70.0 ]),
@@ -102,14 +105,13 @@ def assert_copy(array: Array, copy: bool | None) -> Generator[None, None, None]:
102105 ],
103106)
104107@pytest .mark .parametrize (
105- ("bool_mask" , "shaped_y " ),
108+ ("bool_mask" , "x_ndim" , "y_ndim " ),
106109 [
107- (False , False ),
108- (False , True ),
109- (True , False ), # Uses xp.where(idx, y, x) on JAX and Dask
110+ (False , 1 , 0 ),
111+ (False , 1 , 1 ),
112+ (True , 1 , 0 ), # Uses xp.where(idx, y, x) on JAX and Dask
110113 pytest .param (
111- True ,
112- True ,
114+ * (True , 1 , 1 ),
113115 marks = (
114116 pytest .mark .skip_xp_backend ( # test passes when copy=False
115117 Backend .JAX , reason = "bool mask update with shaped rhs"
@@ -119,29 +121,65 @@ def assert_copy(array: Array, copy: bool | None) -> Generator[None, None, None]:
119121 ),
120122 ),
121123 ),
124+ (False , 0 , 0 ),
125+ (True , 0 , 0 ),
122126 ],
123127)
124128def test_update_ops (
125129 xp : ModuleType ,
126- kwargs : dict [str , bool | None ],
127- expect_copy : bool | None ,
130+ copy : bool | None ,
128131 op : _AtOp ,
129132 y : float ,
130- expect : list [float ],
133+ expect_list : list [float ],
131134 bool_mask : bool ,
132- shaped_y : bool ,
135+ x_ndim : int ,
136+ y_ndim : int ,
133137):
134- x = xp .asarray ([10.0 , 20.0 , 30.0 ])
135- idx = xp .asarray ([False , True , True ]) if bool_mask else slice (1 , None )
136- if shaped_y :
138+ if x_ndim == 1 :
139+ x = xp .asarray ([10.0 , 20.0 , 30.0 ])
140+ idx = xp .asarray ([False , True , True ]) if bool_mask else slice (1 , None )
141+ expect : list [float ] | float = expect_list
142+ else :
143+ idx = xp .asarray (True ) if bool_mask else ()
144+ # Pick an element that does change with the operation
145+ if op is _AtOp .MIN :
146+ x = xp .asarray (30.0 )
147+ expect = expect_list [2 ]
148+ else :
149+ x = xp .asarray (20.0 )
150+ expect = expect_list [1 ]
151+
152+ if y_ndim == 1 :
137153 y = xp .asarray ([y , y ])
138154
139- with assert_copy (x , expect_copy ):
140- z = at_op (x , idx , op , y , ** kwargs )
155+ with assert_copy (x , copy ):
156+ z = at_op (x , idx , op , y , copy = copy )
141157 assert isinstance (z , type (x ))
142158 xp_assert_equal (z , xp .asarray (expect ))
143159
144160
161+ @pytest .mark .parametrize ("op" , list (_AtOp ))
162+ def test_copy_default (xp : ModuleType , library : Backend , op : _AtOp ):
163+ """
164+ Test that the default copy behaviour is False for writeable arrays
165+ and True for read-only ones.
166+ """
167+ x = xp .asarray ([1.0 , 10.0 , 20.0 ])
168+ expect_copy = not is_writeable_array (x )
169+ meth = cast (Callable [..., Array ], getattr (at (x )[:2 ], op .value )) # type: ignore[no-any-explicit]
170+ with assert_copy (x , None , expect_copy ):
171+ _ = meth (2.0 )
172+
173+ x = xp .asarray ([1.0 , 10.0 , 20.0 ])
174+ # Dask's default copy value is True for bool masks,
175+ # even if the arrays are writeable.
176+ expect_copy = not is_writeable_array (x ) or library is Backend .DASK
177+ idx = xp .asarray ([True , True , False ])
178+ meth = cast (Callable [..., Array ], getattr (at (x , idx ), op .value )) # type: ignore[no-any-explicit]
179+ with assert_copy (x , None , expect_copy ):
180+ _ = meth (2.0 )
181+
182+
145183def test_copy_invalid ():
146184 a = np .asarray ([1 , 2 , 3 ])
147185 with pytest .raises (ValueError , match = "copy" ):
@@ -259,3 +297,46 @@ def test_no_inf_warnings(xp: ModuleType, bool_mask: bool):
259297 # inf - inf -> nan with a warning
260298 z = at_op (x , idx , _AtOp .SUBTRACT , math .inf )
261299 xp_assert_equal (z , xp .asarray ([math .inf , - math .inf , - math .inf ]))
300+
301+
302+ @pytest .mark .parametrize (
303+ "copy" ,
304+ [
305+ None ,
306+ pytest .param (
307+ False ,
308+ marks = [
309+ pytest .mark .skip_xp_backend (
310+ Backend .NUMPY , reason = "np.generic is read-only"
311+ ),
312+ pytest .mark .skip_xp_backend (
313+ Backend .NUMPY_READONLY , reason = "read-only backend"
314+ ),
315+ pytest .mark .skip_xp_backend (Backend .JAX , reason = "read-only backend" ),
316+ pytest .mark .skip_xp_backend (Backend .SPARSE , reason = "read-only backend" ),
317+ ],
318+ ),
319+ ],
320+ )
321+ @pytest .mark .parametrize ("bool_mask" , [False , True ])
322+ def test_gh134 (xp : ModuleType , bool_mask : bool , copy : bool | None ):
323+ """
324+ Test that xpx.at doesn't encroach in a bug of dask.array.Array.__setitem__, which
325+ blindly assumes that chunk contents are writeable np.ndarray objects:
326+
327+ https://github.com/dask/dask/issues/11722
328+
329+ In other words: when special-casing bool masks for Dask, unless the user explicitly
330+ asks for copy=False, do not needlessly write back to the input.
331+ """
332+ x = xp .zeros (1 )
333+
334+ # In numpy, we have a writeable np.ndarray in input and a read-only np.generic in
335+ # output. As both are Arrays, this behaviour is Array API compliant.
336+ # In Dask, we have a writeable da.Array on both sides, and if you call __setitem__
337+ # on it all seems fine, but when you compute() your graph is corrupted.
338+ y = x [0 ]
339+
340+ idx = xp .asarray (True ) if bool_mask else ()
341+ z = at_op (y , idx , _AtOp .SET , 1 , copy = copy )
342+ xp_assert_equal (z , xp .asarray (1 , dtype = x .dtype ))
0 commit comments