1+ import pickle
12from collections .abc import Callable , Generator
23from contextlib import contextmanager
4+ from numbers import Number
35from types import ModuleType
46from typing import cast
57
1113from array_api_extra ._lib ._at import _AtOp
1214from array_api_extra ._lib ._testing import xp_assert_equal
1315from array_api_extra ._lib ._utils ._compat import array_namespace , is_writeable_array
14- from array_api_extra ._lib ._utils ._typing import Array
16+ from array_api_extra ._lib ._utils ._typing import Array , Index
17+ from array_api_extra .testing import lazy_xp_function
18+
19+
20+ def at_op (
21+ x : Array ,
22+ idx : Index ,
23+ op : _AtOp ,
24+ y : Array | Number ,
25+ copy : bool | None = None ,
26+ xp : ModuleType | None = None ,
27+ ) -> Array :
28+ """
29+ Wrapper around at(x, idx).op(y, copy=copy, xp=xp).
30+
31+ This is a hack to allow wrapping `at()` with `lazy_xp_function`.
32+ """
33+ if isinstance (idx , (slice | tuple )):
34+ return _at_op (x , None , pickle .dumps (idx ), op , y , copy = copy , xp = xp )
35+ return _at_op (x , idx , None , op , y , copy = copy , xp = xp )
36+
37+
38+ def _at_op (
39+ x : Array ,
40+ idx : Index | None ,
41+ idx_pickle : bytes | None ,
42+ op : _AtOp ,
43+ y : Array | Number ,
44+ copy : bool | None = None ,
45+ xp : ModuleType | None = None ,
46+ ) -> Array :
47+ """jitted helper of at_op"""
48+ if idx_pickle :
49+ idx = pickle .loads (idx_pickle )
50+ meth = cast (Callable [..., Array ], getattr (at (x , idx ), op .value )) # type: ignore[no-any-explicit]
51+ return meth (y , copy = copy , xp = xp )
52+
53+
54+ lazy_xp_function (_at_op , static_argnames = ("op" , "idx_pickle" , "copy" , "xp" ))
1555
1656
1757@contextmanager
@@ -43,7 +83,7 @@ def assert_copy(array: Array, copy: bool | None) -> Generator[None, None, None]:
4383 ],
4484)
4585@pytest .mark .parametrize (
46- ("op" , "arg " , "expect" ),
86+ ("op" , "y " , "expect" ),
4787 [
4888 (_AtOp .SET , 40.0 , [10.0 , 40.0 , 40.0 ]),
4989 (_AtOp .ADD , 40.0 , [10.0 , 60.0 , 70.0 ]),
@@ -55,21 +95,52 @@ def assert_copy(array: Array, copy: bool | None) -> Generator[None, None, None]:
5595 (_AtOp .MAX , 25.0 , [10.0 , 25.0 , 30.0 ]),
5696 ],
5797)
98+ @pytest .mark .parametrize (
99+ ("bool_mask" , "shaped_y" ),
100+ [
101+ (False , False ),
102+ (False , True ),
103+ pytest .param (
104+ True ,
105+ False ,
106+ marks = (
107+ pytest .mark .skip_xp_backend (Backend .JAX , reason = "TODO special case" ),
108+ pytest .mark .skip_xp_backend (Backend .DASK , reason = "TODO special case" ),
109+ ),
110+ ),
111+ pytest .param (
112+ True ,
113+ True ,
114+ marks = (
115+ pytest .mark .skip_xp_backend (
116+ Backend .JAX , reason = "bool mask update with shaped rhs"
117+ ),
118+ pytest .mark .skip_xp_backend (
119+ Backend .DASK , reason = "bool mask update with shaped rhs"
120+ ),
121+ ),
122+ ),
123+ ],
124+ )
58125def test_update_ops (
59126 xp : ModuleType ,
60127 kwargs : dict [str , bool | None ],
61128 expect_copy : bool | None ,
62129 op : _AtOp ,
63- arg : float ,
130+ y : float ,
64131 expect : list [float ],
132+ bool_mask : bool ,
133+ shaped_y : bool ,
65134):
66- array = xp .asarray ([10.0 , 20.0 , 30.0 ])
135+ x = xp .asarray ([10.0 , 20.0 , 30.0 ])
136+ idx = xp .asarray ([False , True , True ]) if bool_mask else slice (1 , None )
137+ if shaped_y :
138+ y = xp .asarray ([y , y ])
67139
68- with assert_copy (array , expect_copy ):
69- func = cast (Callable [..., Array ], getattr (at (array )[1 :], op .value )) # type: ignore[no-any-explicit]
70- y = func (arg , ** kwargs )
71- assert isinstance (y , type (array ))
72- xp_assert_equal (y , xp .asarray (expect ))
140+ with assert_copy (x , expect_copy ):
141+ z = at_op (x , idx , op , y , ** kwargs ) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
142+ assert isinstance (z , type (x ))
143+ xp_assert_equal (z , xp .asarray (expect ))
73144
74145
75146def test_copy_invalid ():
@@ -121,7 +192,6 @@ def test_iops_incompatible_dtype(op: _AtOp, copy: bool):
121192 UFuncTypeError: Cannot cast ufunc 'divide' output from dtype('float64')
122193 to dtype('int64') with casting rule 'same_kind'
123194 """
124- a = np .asarray ([2 , 4 ])
125- func = cast (Callable [..., Array ], getattr (at (a )[:], op .value )) # type: ignore[no-any-explicit]
195+ x = np .asarray ([2 , 4 ])
126196 with pytest .raises (TypeError , match = "Cannot cast ufunc" ):
127- func ( 1.1 , copy = copy )
197+ at_op ( x , slice ( None ), op , 1.1 , copy = copy )
0 commit comments