1010from array_api_extra ._lib ._utils ._typing import Array
1111from array_api_extra .testing import lazy_xp_function
1212
13- skip_as_numpy = [
14- pytest .mark .skip_xp_backend (Backend .CUPY , reason = "device->host transfer" ),
15- pytest .mark .skip_xp_backend (Backend .SPARSE , reason = "densification" ),
16- ]
13+ as_numpy = pytest .mark .parametrize (
14+ "as_numpy" ,
15+ [
16+ False ,
17+ pytest .param (
18+ True ,
19+ marks = [
20+ pytest .mark .skip_xp_backend (Backend .CUPY , reason = "device->host copy" ),
21+ pytest .mark .skip_xp_backend (Backend .SPARSE , reason = "densification" ),
22+ ],
23+ ),
24+ ],
25+ )
26+
27+
28+ @as_numpy
29+ def test_lazy_apply_simple (xp : ModuleType , as_numpy : bool ):
30+ pytest .skip ("TODO" )
31+
32+
33+ @as_numpy
34+ def test_lazy_apply_broadcast (xp : ModuleType , as_numpy : bool ):
35+ pytest .skip ("TODO" )
36+
37+
38+ @as_numpy
39+ def test_lazy_apply_multi_output (xp : ModuleType , as_numpy : bool ):
40+ pytest .skip ("TODO" )
41+
42+
43+ def test_lazy_apply_core_indices (da : ModuleType ):
44+ """Test that a func that performs reductions along axes does so
45+ globally and not locally to each Dask chunk.
46+ """
47+ pytest .skip ("TODO" )
48+
49+
50+ def test_lazy_apply_dont_run_on_meta (da : ModuleType ):
51+ """Test that Dask won't try running func on the meta array,
52+ as it may have minimum size requirements.
53+ """
54+ pytest .skip ("TODO" )
55+
56+
57+ def test_lazy_apply_none_shape (da : ModuleType ):
58+ pytest .skip ("TODO" )
59+
60+
61+ @as_numpy
62+ def test_lazy_apply_device (xp : ModuleType , as_numpy : bool ):
63+ pytest .skip ("TODO" )
64+
65+
66+ @as_numpy
67+ def test_lazy_apply_no_args (xp : ModuleType , as_numpy : bool ):
68+ pytest .skip ("TODO" )
1769
1870
1971class NT (NamedTuple ):
2072 a : Array
2173
2274
23- def check_lazy_apply_kwargs (x : Array , expect : type , as_numpy : bool ) -> Array :
24- def f (
75+ def check_lazy_apply_kwargs (x : Array , expect_cls : type , as_numpy : bool ) -> Array :
76+ def eager (
2577 x : Array ,
2678 z : dict [str , list [Array ] | tuple [Array , ...] | NT ],
2779 msg : str ,
2880 msgs : list [str ],
81+ scalar : int ,
2982 ) -> Array :
30- assert isinstance (x , expect )
83+ assert isinstance (x , expect_cls )
84+ assert int (x ) == 0 # JAX will crash if x isn't material
85+ # Did we re-wrap the namedtuple correctly, or did it get
86+ # accidentally changed to a basic tuple?
3187 assert isinstance (z ["foo" ], NT )
32- assert isinstance (z ["foo" ].a , expect )
33- assert isinstance (z ["bar" ][0 ], expect )
34- assert isinstance (z ["baz" ][0 ], expect )
35- assert msg == "Hello World"
36- assert msgs [0 ] == "Hello World"
37- return x + 1
88+ assert isinstance (z ["foo" ].a , expect_cls )
89+ assert isinstance (z ["bar" ][0 ], expect_cls ) # list
90+ assert isinstance (z ["baz" ][0 ], expect_cls ) # tuple
91+ assert msg == "Hello World" # must be hidden from JAX
92+ assert msgs [0 ] == "Hello World" # must be hidden from JAX
93+ assert isinstance (msg , str )
94+ assert isinstance (msgs [0 ], str )
95+ assert scalar == 1 # must be hidden from JAX
96+ assert isinstance (scalar , int )
97+ return x + 1 # type: ignore[operator]
3898
3999 return lazy_apply ( # pyright: ignore[reportCallIssue]
40- f ,
100+ eager ,
41101 x ,
102+ # These kwargs can and should be passed through jax.pure_callback
42103 z = {"foo" : NT (x ), "bar" : [x ], "baz" : (x ,)},
104+ # These can't
43105 msg = "Hello World" ,
44106 msgs = ["Hello World" ],
107+ # This will be automatically cast to jax.Array if we don't wrap it
108+ scalar = 1 ,
45109 shape = x .shape ,
46110 dtype = x .dtype ,
47111 as_numpy = as_numpy ,
48112 )
49113
50- lazy_xp_function (check_lazy_apply_kwargs , static_argnames = ("expect" , "as_numpy" ))
114+
115+ lazy_xp_function (check_lazy_apply_kwargs , static_argnames = ("expect_cls" , "as_numpy" ))
51116
52117
53- @pytest . mark . parametrize ( " as_numpy" , [ False , pytest . param ( True , marks = skip_as_numpy )])
118+ @as_numpy
54119def test_lazy_apply_kwargs (xp : ModuleType , library : Backend , as_numpy : bool ) -> None :
55- expect = np .ndarray if as_numpy or library is Backend .DASK else type (xp .asarray (0 ))
120+ """When as_numpy=True, search and replace arrays in the (nested) keywords arguments
121+ with numpy arrays, and leave the rest untouched."""
122+ expect_cls = (
123+ np .ndarray if as_numpy or library is Backend .DASK else type (xp .asarray (0 ))
124+ )
56125 x = xp .asarray (0 )
57- actual = check_lazy_apply_kwargs (x , expect , as_numpy )
126+ actual = check_lazy_apply_kwargs (x , expect_cls , as_numpy ) # pyright: ignore[reportUnknownArgumentType]
58127 xp_assert_equal (actual , x + 1 )
59128
60129
@@ -69,10 +138,12 @@ def eager(_: Array) -> Array:
69138
70139 return lazy_apply (eager , x , shape = x .shape , dtype = x .dtype )
71140
72- lazy_xp_function (raises )
141+
142+ # jax.pure_callback does not support raising
143+ # https://github.com/jax-ml/jax/issues/26102
144+ lazy_xp_function (raises , jax_jit = False )
73145
74146
75- @pytest .mark .skip_xp_backend (Backend .JAX_JIT , reason = "no exception support" )
76147def test_lazy_apply_raises (xp : ModuleType ) -> None :
77148 x = xp .asarray (0 )
78149
0 commit comments