1+ import pickle
12from collections .abc import Callable , Generator
23from contextlib import contextmanager
4+ from numbers import Number
35from types import ModuleType
46from typing import cast
57
1113from array_api_extra ._lib ._at import _AtOp
1214from array_api_extra ._lib ._testing import xp_assert_equal
1315from array_api_extra ._lib ._utils ._compat import array_namespace , is_writeable_array
14- from array_api_extra ._lib ._utils ._typing import Array
16+ from array_api_extra ._lib ._utils ._typing import Array , Index
17+ from array_api_extra .testing import lazy_xp_function
18+
19+
20+ def at_op (
21+ x : Array ,
22+ idx : Index ,
23+ op : _AtOp ,
24+ y : Array | Number ,
25+ copy : bool | None = None ,
26+ xp : ModuleType | None = None ,
27+ ) -> Array :
28+ """
29+ Wrapper around at(x, idx).op(y, copy=copy, xp=xp).
30+
31+ This is a hack to allow wrapping `at()` with `lazy_xp_function`.
32+ """
33+ if isinstance (idx , (slice | tuple )):
34+ return _at_op (x , None , pickle .dumps (idx ), op , y , copy = copy , xp = xp )
35+ return _at_op (x , idx , None , op , y , copy = copy , xp = xp )
36+
37+
38+ def _at_op (
39+ x : Array ,
40+ idx : Index | None ,
41+ idx_pickle : bytes | None ,
42+ op : _AtOp ,
43+ y : Array | Number ,
44+ copy : bool | None = None ,
45+ xp : ModuleType | None = None ,
46+ ) -> Array :
47+ """jitted helper of at_op"""
48+ if idx_pickle :
49+ idx = pickle .loads (idx_pickle )
50+ meth = cast (Callable [..., Array ], getattr (at (x , idx ), op .value )) # type: ignore[no-any-explicit]
51+ return meth (y , copy = copy , xp = xp )
52+
53+
54+ lazy_xp_function (_at_op , static_argnames = ("op" , "idx_pickle" , "copy" , "xp" ))
1555
1656
1757@contextmanager
@@ -43,7 +83,7 @@ def assert_copy(array: Array, copy: bool | None) -> Generator[None, None, None]:
4383 ],
4484)
4585@pytest .mark .parametrize (
46- ("op" , "arg " , "expect" ),
86+ ("op" , "y " , "expect" ),
4787 [
4888 (_AtOp .SET , 40.0 , [10.0 , 40.0 , 40.0 ]),
4989 (_AtOp .ADD , 40.0 , [10.0 , 60.0 , 70.0 ]),
@@ -55,21 +95,51 @@ def assert_copy(array: Array, copy: bool | None) -> Generator[None, None, None]:
5595 (_AtOp .MAX , 25.0 , [10.0 , 25.0 , 30.0 ]),
5696 ],
5797)
98+ @pytest .mark .parametrize (
99+ ("bool_mask" , "shaped_y" ),
100+ [
101+ (False , False ),
102+ pytest .param (
103+ True ,
104+ False ,
105+ marks = (
106+ pytest .mark .skip_xp_backend (Backend .JAX , reason = "TODO special case" ),
107+ pytest .mark .skip_xp_backend (Backend .DASK , reason = "TODO special case" ),
108+ ),
109+ ),
110+ pytest .param (
111+ True ,
112+ True ,
113+ marks = (
114+ pytest .mark .skip_xp_backend (
115+ Backend .JAX , reason = "bool mask update with shaped rhs"
116+ ),
117+ pytest .mark .skip_xp_backend (
118+ Backend .DASK , reason = "bool mask update with shaped rhs"
119+ ),
120+ ),
121+ ),
122+ ],
123+ )
58124def test_update_ops (
59125 xp : ModuleType ,
60126 kwargs : dict [str , bool | None ],
61127 expect_copy : bool | None ,
62128 op : _AtOp ,
63- arg : float ,
129+ y : float ,
64130 expect : list [float ],
131+ bool_mask : bool ,
132+ shaped_y : bool ,
65133):
66- array = xp .asarray ([10.0 , 20.0 , 30.0 ])
134+ x = xp .asarray ([10.0 , 20.0 , 30.0 ])
135+ idx = xp .asarray ([False , True , True ]) if bool_mask else slice (1 , None )
136+ if shaped_y :
137+ y = xp .asarray ([y , y ])
67138
68- with assert_copy (array , expect_copy ):
69- func = cast (Callable [..., Array ], getattr (at (array )[1 :], op .value )) # type: ignore[no-any-explicit]
70- y = func (arg , ** kwargs )
71- assert isinstance (y , type (array ))
72- xp_assert_equal (y , xp .asarray (expect ))
139+ with assert_copy (x , expect_copy ):
140+ z = at_op (x , idx , op , y , ** kwargs ) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
141+ assert isinstance (z , type (x ))
142+ xp_assert_equal (z , xp .asarray (expect ))
73143
74144
75145def test_copy_invalid ():
@@ -121,7 +191,6 @@ def test_iops_incompatible_dtype(op: _AtOp, copy: bool):
121191 UFuncTypeError: Cannot cast ufunc 'divide' output from dtype('float64')
122192 to dtype('int64') with casting rule 'same_kind'
123193 """
124- a = np .asarray ([2 , 4 ])
125- func = cast (Callable [..., Array ], getattr (at (a )[:], op .value )) # type: ignore[no-any-explicit]
194+ x = np .asarray ([2 , 4 ])
126195 with pytest .raises (TypeError , match = "Cannot cast ufunc" ):
127- func ( 1.1 , copy = copy )
196+ at_op ( x , slice ( None ), op , 1.1 , copy = copy )
0 commit comments