33from collections .abc import Callable , Generator
44from contextlib import contextmanager
55from types import ModuleType
6- from typing import Any , cast
6+ from typing import cast
77
88import numpy as np
99import pytest
2323]
2424
2525
26- def at_op ( # type: ignore[no-any-explicit]
26+ def at_op (
2727 x : Array ,
2828 idx : Index ,
2929 op : _AtOp ,
3030 y : Array | object ,
31- ** kwargs : Any , # Test the default copy=None
31+ copy : bool | None = None ,
32+ xp : ModuleType | None = None ,
3233) -> Array :
3334 """
3435 Wrapper around at(x, idx).op(y, copy=copy, xp=xp).
@@ -39,30 +40,33 @@ def at_op( # type: ignore[no-any-explicit]
3940 which is not a common use case.
4041 """
4142 if isinstance (idx , (slice | tuple )):
42- return _at_op (x , None , pickle .dumps (idx ), op , y , ** kwargs )
43- return _at_op (x , idx , None , op , y , ** kwargs )
43+ return _at_op (x , None , pickle .dumps (idx ), op , y , copy = copy , xp = xp )
44+ return _at_op (x , idx , None , op , y , copy = copy , xp = xp )
4445
4546
46- def _at_op ( # type: ignore[no-any-explicit]
47+ def _at_op (
4748 x : Array ,
4849 idx : Index | None ,
4950 idx_pickle : bytes | None ,
5051 op : _AtOp ,
5152 y : Array | object ,
52- ** kwargs : Any ,
53+ copy : bool | None ,
54+ xp : ModuleType | None = None ,
5355) -> Array :
5456 """jitted helper of at_op"""
5557 if idx_pickle :
5658 idx = pickle .loads (idx_pickle )
5759 meth = cast (Callable [..., Array ], getattr (at (x , idx ), op .value )) # type: ignore[no-any-explicit]
58- return meth (y , ** kwargs )
60+ return meth (y , copy = copy , xp = xp )
5961
6062
6163lazy_xp_function (_at_op , static_argnames = ("op" , "idx_pickle" , "copy" , "xp" ))
6264
6365
6466@contextmanager
65- def assert_copy (array : Array , copy : bool | None ) -> Generator [None , None , None ]:
67+ def assert_copy (
68+ array : Array , copy : bool | None , expect_copy : bool | None = None
69+ ) -> Generator [None , None , None ]:
6670 if copy is False and not is_writeable_array (array ):
6771 with pytest .raises ((TypeError , ValueError )):
6872 yield
@@ -72,28 +76,21 @@ def assert_copy(array: Array, copy: bool | None) -> Generator[None, None, None]:
7276 array_orig = xp .asarray (array , copy = True )
7377 yield
7478
75- if copy is True :
79+ if expect_copy is None :
80+ expect_copy = copy
81+
82+ if expect_copy :
7683 # Original has not been modified
7784 xp_assert_equal (array , array_orig )
78- elif copy is False :
85+ elif expect_copy is False :
7986 # Original has been modified
8087 with pytest .raises (AssertionError ):
8188 xp_assert_equal (array , array_orig )
8289 # Test nothing for copy=None. Dask changes behaviour depending on
8390 # whether it's a special case of a bool mask with scalar RHS or not.
8491
8592
86- @pytest .mark .parametrize (
87- ("kwargs" , "expect_copy" ),
88- [
89- pytest .param ({"copy" : True }, True , id = "copy=True" ),
90- pytest .param ({"copy" : False }, False , id = "copy=False" ),
91- # Behavior is backend-specific
92- pytest .param ({"copy" : None }, None , id = "copy=None" ),
93- # Test that the copy parameter defaults to None
94- pytest .param ({}, None , id = "no copy kwarg" ),
95- ],
96- )
93+ @pytest .mark .parametrize ("copy" , [False , True , None ])
9794@pytest .mark .parametrize (
9895 ("op" , "y" , "expect_list" ),
9996 [
@@ -130,8 +127,7 @@ def assert_copy(array: Array, copy: bool | None) -> Generator[None, None, None]:
130127)
131128def test_update_ops (
132129 xp : ModuleType ,
133- kwargs : dict [str , bool | None ],
134- expect_copy : bool | None ,
130+ copy : bool | None ,
135131 op : _AtOp ,
136132 y : float ,
137133 expect_list : list [float ],
@@ -156,12 +152,34 @@ def test_update_ops(
156152 if y_ndim == 1 :
157153 y = xp .asarray ([y , y ])
158154
159- with assert_copy (x , expect_copy ):
160- z = at_op (x , idx , op , y , ** kwargs )
155+ with assert_copy (x , copy ):
156+ z = at_op (x , idx , op , y , copy = copy )
161157 assert isinstance (z , type (x ))
162158 xp_assert_equal (z , xp .asarray (expect ))
163159
164160
161+ @pytest .mark .parametrize ("op" , list (_AtOp ))
162+ def test_copy_default (xp : ModuleType , library : Backend , op : _AtOp ):
163+ """
164+ Test that the default copy behaviour is False for writeable arrays
165+ and True for read-only ones.
166+ """
167+ x = xp .asarray ([1.0 , 10.0 , 20.0 ])
168+ expect_copy = not is_writeable_array (x )
169+ meth = cast (Callable [..., Array ], getattr (at (x )[:2 ], op .value )) # type: ignore[no-any-explicit]
170+ with assert_copy (x , None , expect_copy ):
171+ _ = meth (2.0 )
172+
173+ x = xp .asarray ([1.0 , 10.0 , 20.0 ])
174+ # Dask's default copy value is True for bool masks,
175+ # even if the arrays are writeable.
176+ expect_copy = not is_writeable_array (x ) or library is Backend .DASK
177+ idx = xp .asarray ([True , True , False ])
178+ meth = cast (Callable [..., Array ], getattr (at (x , idx ), op .value )) # type: ignore[no-any-explicit]
179+ with assert_copy (x , None , expect_copy ):
180+ _ = meth (2.0 )
181+
182+
165183def test_copy_invalid ():
166184 a = np .asarray ([1 , 2 , 3 ])
167185 with pytest .raises (ValueError , match = "copy" ):
0 commit comments