1+ import pickle
12from collections .abc import Callable , Generator
23from contextlib import contextmanager
4+ from numbers import Number
35from types import ModuleType
46from typing import cast
57
1113from array_api_extra ._lib ._at import _AtOp
1214from array_api_extra ._lib ._testing import xp_assert_equal
1315from array_api_extra ._lib ._utils ._compat import array_namespace , is_writeable_array
14- from array_api_extra ._lib ._utils ._typing import Array
16+ from array_api_extra ._lib ._utils ._typing import Array , Index
17+ from array_api_extra .testing import lazy_xp_function
18+
19+
20+ def at_op (
21+ x : Array ,
22+ idx : Index ,
23+ op : _AtOp ,
24+ y : Array | Number ,
25+ copy : bool | None = None ,
26+ xp : ModuleType | None = None ,
27+ ) -> Array :
28+ """
29+ Wrapper around at(x, idx).op(y, copy=copy, xp=xp).
30+
31+ This is a hack to allow wrapping `at()` with `lazy_xp_function`.
32+ For clarity, at() itself works inside jax.jit without hacks; this is
33+ just a workaround for when one wants to apply jax.jit to `at()` directly,
34+ which is not a common use case.
35+ """
36+ if isinstance (idx , (slice | tuple )):
37+ return _at_op (x , None , pickle .dumps (idx ), op , y , copy = copy , xp = xp )
38+ return _at_op (x , idx , None , op , y , copy = copy , xp = xp )
39+
40+
41+ def _at_op (
42+ x : Array ,
43+ idx : Index | None ,
44+ idx_pickle : bytes | None ,
45+ op : _AtOp ,
46+ y : Array | Number ,
47+ copy : bool | None = None ,
48+ xp : ModuleType | None = None ,
49+ ) -> Array :
50+ """jitted helper of at_op"""
51+ if idx_pickle :
52+ idx = pickle .loads (idx_pickle )
53+ meth = cast (Callable [..., Array ], getattr (at (x , idx ), op .value )) # type: ignore[no-any-explicit]
54+ return meth (y , copy = copy , xp = xp )
55+
56+
57+ lazy_xp_function (_at_op , static_argnames = ("op" , "idx_pickle" , "copy" , "xp" ))
1558
1659
1760@contextmanager
@@ -43,7 +86,7 @@ def assert_copy(array: Array, copy: bool | None) -> Generator[None, None, None]:
4386 ],
4487)
4588@pytest .mark .parametrize (
46- ("op" , "arg " , "expect" ),
89+ ("op" , "y " , "expect" ),
4790 [
4891 (_AtOp .SET , 40.0 , [10.0 , 40.0 , 40.0 ]),
4992 (_AtOp .ADD , 40.0 , [10.0 , 60.0 , 70.0 ]),
@@ -55,21 +98,52 @@ def assert_copy(array: Array, copy: bool | None) -> Generator[None, None, None]:
5598 (_AtOp .MAX , 25.0 , [10.0 , 25.0 , 30.0 ]),
5699 ],
57100)
101+ @pytest .mark .parametrize (
102+ ("bool_mask" , "shaped_y" ),
103+ [
104+ (False , False ),
105+ (False , True ),
106+ pytest .param (
107+ True ,
108+ False ,
109+ marks = (
110+ pytest .mark .skip_xp_backend (Backend .JAX , reason = "TODO special case" ),
111+ pytest .mark .skip_xp_backend (Backend .DASK , reason = "TODO special case" ),
112+ ),
113+ ),
114+ pytest .param (
115+ True ,
116+ True ,
117+ marks = (
118+ pytest .mark .skip_xp_backend (
119+ Backend .JAX , reason = "bool mask update with shaped rhs"
120+ ),
121+ pytest .mark .skip_xp_backend (
122+ Backend .DASK , reason = "bool mask update with shaped rhs"
123+ ),
124+ ),
125+ ),
126+ ],
127+ )
58128def test_update_ops (
59129 xp : ModuleType ,
60130 kwargs : dict [str , bool | None ],
61131 expect_copy : bool | None ,
62132 op : _AtOp ,
63- arg : float ,
133+ y : float ,
64134 expect : list [float ],
135+ bool_mask : bool ,
136+ shaped_y : bool ,
65137):
66- array = xp .asarray ([10.0 , 20.0 , 30.0 ])
138+ x = xp .asarray ([10.0 , 20.0 , 30.0 ])
139+ idx = xp .asarray ([False , True , True ]) if bool_mask else slice (1 , None )
140+ if shaped_y :
141+ y = xp .asarray ([y , y ])
67142
68- with assert_copy (array , expect_copy ):
69- func = cast (Callable [..., Array ], getattr (at (array )[1 :], op .value )) # type: ignore[no-any-explicit]
70- y = func (arg , ** kwargs )
71- assert isinstance (y , type (array ))
72- xp_assert_equal (y , xp .asarray (expect ))
143+ with assert_copy (x , expect_copy ):
144+ z = at_op (x , idx , op , y , ** kwargs ) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
145+ assert isinstance (z , type (x ))
146+ xp_assert_equal (z , xp .asarray (expect ))
73147
74148
75149def test_copy_invalid ():
@@ -121,7 +195,6 @@ def test_iops_incompatible_dtype(op: _AtOp, copy: bool):
121195 UFuncTypeError: Cannot cast ufunc 'divide' output from dtype('float64')
122196 to dtype('int64') with casting rule 'same_kind'
123197 """
124- a = np .asarray ([2 , 4 ])
125- func = cast (Callable [..., Array ], getattr (at (a )[:], op .value )) # type: ignore[no-any-explicit]
198+ x = np .asarray ([2 , 4 ])
126199 with pytest .raises (TypeError , match = "Cannot cast ufunc" ):
127- func ( 1.1 , copy = copy )
200+ at_op ( x , slice ( None ), op , 1.1 , copy = copy )
0 commit comments