11import contextlib
22import math
33import warnings
4- from collections .abc import Callable
54from types import ModuleType
65
76import hypothesis
3635# some xp backends are untyped
3736# mypy: disable-error-code=no-untyped-def
3837
38+ lazy_xp_function (apply_where , static_argnums = (2 , 3 ), static_argnames = "xp" )
3939lazy_xp_function (atleast_nd , static_argnames = ("ndim" , "xp" ))
4040lazy_xp_function (cov , static_argnames = "xp" )
4141# FIXME .device attribute https://github.com/data-apis/array-api-compat/pull/238
5050lazy_xp_function (sinc , jax_jit = False , static_argnames = "xp" )
5151
5252
53- def apply_where_jit ( # type: ignore[no-any-explicit]
54- cond : Array ,
55- f1 : Callable [..., Array ],
56- f2 : Callable [..., Array ] | None ,
57- args : Array | tuple [Array , ...],
58- fill_value : Array | int | float | complex | bool | None = None ,
59- xp : ModuleType | None = None ,
60- ) -> Array :
61- """
62- Work around jax.jit's inability to handle variadic positional arguments.
63-
64- This is a lazy_xp_function artefact for when jax.jit is applied directly
65- to apply_where, which would not happen in real life.
66- """
67- if f2 is None :
68- return apply_where (cond , f1 , args , fill_value = fill_value , xp = xp )
69- assert fill_value is None
70- return apply_where (cond , f1 , f2 , args , xp = xp )
71-
72-
73- lazy_xp_function (apply_where_jit , static_argnames = ("f1" , "f2" , "xp" ))
74-
75-
7653class TestApplyWhere :
7754 @staticmethod
7855 def f1 (x : Array , y : Array | int = 10 ) -> Array :
@@ -86,27 +63,27 @@ def f2(x: Array, y: Array | int = 10) -> Array:
8663 def test_f1_f2 (self , xp : ModuleType ):
8764 x = xp .asarray ([1 , 2 , 3 , 4 ])
8865 cond = x % 2 == 0
89- actual = apply_where_jit (cond , self .f1 , self .f2 , x )
66+ actual = apply_where (cond , x , self .f1 , self .f2 )
9067 expect = xp .where (cond , self .f1 (x ), self .f2 (x ))
9168 xp_assert_equal (actual , expect )
9269
9370 @pytest .mark .xfail_xp_backend (Backend .SPARSE , reason = "read-only without .at" )
9471 def test_fill_value (self , xp : ModuleType ):
9572 x = xp .asarray ([1 , 2 , 3 , 4 ])
9673 cond = x % 2 == 0
97- actual = apply_where_jit (x % 2 == 0 , self .f1 , None , x , fill_value = 0 )
74+ actual = apply_where (x % 2 == 0 , x , self .f1 , fill_value = 0 )
9875 expect = xp .where (cond , self .f1 (x ), xp .asarray (0 ))
9976 xp_assert_equal (actual , expect )
10077
101- actual = apply_where_jit (x % 2 == 0 , self .f1 , None , x , fill_value = xp .asarray (0 ))
78+ actual = apply_where (x % 2 == 0 , x , self .f1 , fill_value = xp .asarray (0 ))
10279 xp_assert_equal (actual , expect )
10380
10481 @pytest .mark .xfail_xp_backend (Backend .SPARSE , reason = "read-only without .at" )
10582 def test_args_tuple (self , xp : ModuleType ):
10683 x = xp .asarray ([1 , 2 , 3 , 4 ])
10784 y = xp .asarray ([10 , 20 , 30 , 40 ])
10885 cond = x % 2 == 0
109- actual = apply_where_jit (cond , self .f1 , self .f2 , ( x , y ) )
86+ actual = apply_where (cond , ( x , y ), self .f1 , self .f2 )
11087 expect = xp .where (cond , self .f1 (x , y ), self .f2 (x , y ))
11188 xp_assert_equal (actual , expect )
11289
@@ -116,21 +93,21 @@ def test_broadcast(self, xp: ModuleType):
11693 y = xp .asarray ([[10 ], [20 ], [30 ]])
11794 cond = xp .broadcast_to (xp .asarray (True ), (4 , 1 , 1 ))
11895
119- actual = apply_where_jit (cond , self .f1 , self .f2 , ( x , y ) )
96+ actual = apply_where (cond , ( x , y ), self .f1 , self .f2 )
12097 expect = xp .where (cond , self .f1 (x , y ), self .f2 (x , y ))
12198 xp_assert_equal (actual , expect )
12299
123- actual = apply_where_jit (
100+ actual = apply_where (
124101 cond ,
102+ (x , y ),
125103 lambda x , _ : x , # pyright: ignore[reportUnknownArgumentType]
126104 lambda _ , y : y , # pyright: ignore[reportUnknownArgumentType]
127- (x , y ),
128105 )
129106 expect = xp .where (cond , x , y )
130107 xp_assert_equal (actual , expect )
131108
132109 # Shaped fill_value
133- actual = apply_where_jit (cond , self .f1 , None , x , fill_value = y )
110+ actual = apply_where (cond , x , self .f1 , fill_value = y )
134111 expect = xp .where (cond , self .f1 (x ), y )
135112 xp_assert_equal (actual , expect )
136113
@@ -141,15 +118,15 @@ def test_dtype_propagation(self, xp: ModuleType, library: Backend):
141118 cond = x % 2 == 0
142119
143120 mxp = np if library is Backend .DASK else xp
144- actual = apply_where_jit (
121+ actual = apply_where (
145122 cond ,
123+ (x , y ),
146124 self .f1 ,
147125 lambda x , y : mxp .astype (x - y , xp .int64 ), # pyright: ignore[reportUnknownArgumentType]
148- (x , y ),
149126 )
150127 assert actual .dtype == xp .int64
151128
152- actual = apply_where_jit (cond , self .f1 , None , y , fill_value = 5 )
129+ actual = apply_where (cond , y , self .f1 , fill_value = 5 )
153130 assert actual .dtype == xp .int16
154131
155132 @pytest .mark .xfail_xp_backend (Backend .SPARSE , reason = "read-only without .at" )
@@ -168,14 +145,14 @@ def test_dtype_propagation_fill_value(
168145 cond = x % 2 == 0
169146 fill_value = xp .asarray (fill_value_raw , dtype = getattr (xp , fill_value_dtype ))
170147
171- actual = apply_where_jit (cond , self .f1 , None , x , fill_value = fill_value )
148+ actual = apply_where (cond , x , self .f1 , fill_value = fill_value )
172149 assert actual .dtype == getattr (xp , expect_dtype )
173150
174151 @pytest .mark .xfail_xp_backend (Backend .SPARSE , reason = "read-only without .at" )
175152 def test_dont_overwrite_fill_value (self , xp : ModuleType ):
176153 x = xp .asarray ([1 , 2 ])
177154 fill_value = xp .asarray ([100 , 200 ])
178- actual = apply_where_jit (x % 2 == 0 , self .f1 , None , x , fill_value = fill_value )
155+ actual = apply_where (x % 2 == 0 , x , self .f1 , fill_value = fill_value )
179156 xp_assert_equal (actual , xp .asarray ([100 , 12 ]))
180157 xp_assert_equal (fill_value , xp .asarray ([100 , 200 ]))
181158
@@ -184,11 +161,11 @@ def test_dont_run_on_false(self, xp: ModuleType):
184161 x = xp .asarray ([1.0 , 2.0 , 0.0 ])
185162 y = xp .asarray ([0.0 , 3.0 , 4.0 ])
186163 # On NumPy, division by zero will trigger warnings
187- actual = apply_where_jit (
164+ actual = apply_where (
188165 x == 0 ,
166+ (x , y ),
189167 lambda x , y : x / y , # pyright: ignore[reportUnknownArgumentType]
190168 lambda x , y : y / x , # pyright: ignore[reportUnknownArgumentType]
191- (x , y ),
192169 )
193170 xp_assert_equal (actual , xp .asarray ([0.0 , 1.5 , 0.0 ]))
194171
@@ -197,29 +174,28 @@ def test_bad_args(self, xp: ModuleType):
197174 cond = x % 2 == 0
198175 # Neither f2 nor fill_value
199176 with pytest .raises (TypeError , match = "Exactly one of" ):
200- apply_where (cond , self .f1 , x ) # type: ignore[call-overload] # pyright: ignore[reportCallIssue]
177+ apply_where (cond , x , self .f1 ) # type: ignore[call-overload] # pyright: ignore[reportCallIssue]
201178 # Both f2 and fill_value
202179 with pytest .raises (TypeError , match = "Exactly one of" ):
203- apply_where (cond , self .f1 , self .f2 , x , fill_value = 0 ) # type: ignore[call-overload] # pyright: ignore[reportCallIssue]
204- # Multiple args; forgot to wrap them in a tuple
205- with pytest .raises (TypeError , match = "takes from 3 to 4 positional arguments" ):
206- apply_where (cond , self .f1 , self .f2 , x , x ) # type: ignore[call-overload] # pyright: ignore[reportCallIssue]
207- with pytest .raises (TypeError , match = "callable" ):
208- apply_where (cond , self .f1 , x , x , fill_value = 0 ) # type: ignore[call-overload] # pyright: ignore[reportCallIssue]
180+ apply_where (cond , x , self .f1 , self .f2 , fill_value = 0 ) # type: ignore[call-overload] # pyright: ignore[reportCallIssue]
209181
210182 @pytest .mark .skip_xp_backend (Backend .NUMPY_READONLY , reason = "xp=xp" )
211183 @pytest .mark .xfail_xp_backend (Backend .SPARSE , reason = "read-only without .at" )
212184 def test_xp (self , xp : ModuleType ):
213185 x = xp .asarray ([1 , 2 , 3 , 4 ])
214186 cond = x % 2 == 0
215- actual = apply_where_jit (cond , self .f1 , self .f2 , x , xp = xp )
187+ actual = apply_where (cond , x , self .f1 , self .f2 , xp = xp )
216188 expect = xp .where (cond , self .f1 (x ), self .f2 (x ))
217189 xp_assert_equal (actual , expect )
218190
219191 @pytest .mark .xfail_xp_backend (Backend .SPARSE , reason = "read-only without .at" )
220192 def test_device (self , xp : ModuleType , device : Device ):
221193 x = xp .asarray ([1 , 2 , 3 , 4 ], device = device )
222- y = apply_where_jit (x % 2 == 0 , self .f1 , self .f2 , x )
194+ y = apply_where (x % 2 == 0 , x , self .f1 , self .f2 )
195+ assert get_device (y ) == device
196+ y = apply_where (x % 2 == 0 , x , self .f1 , fill_value = 0 )
197+ assert get_device (y ) == device
198+ y = apply_where (x % 2 == 0 , x , self .f1 , fill_value = x )
223199 assert get_device (y ) == device
224200
225201 # skip instead of xfail in order not to waste time
@@ -273,10 +249,9 @@ def f2(*args: Array) -> Array:
273249 rng = np .random .default_rng (rng_seed )
274250 cond = xp .asarray (rng .random (size = cond_shape ) > p )
275251
276- # Use apply_where instead of apply_where_jit to speed the test up
277- res1 = apply_where (cond , f1 , arrays , fill_value = fill_value )
278- res2 = apply_where (cond , f1 , f2 , arrays )
279- res3 = apply_where (cond , f1 , arrays , fill_value = float_fill_value )
252+ res1 = apply_where (cond , arrays , f1 , fill_value = fill_value )
253+ res2 = apply_where (cond , arrays , f1 , f2 )
254+ res3 = apply_where (cond , arrays , f1 , fill_value = float_fill_value )
280255
281256 ref1 = xp .where (cond , f1 (* arrays ), fill_value )
282257 ref2 = xp .where (cond , f1 (* arrays ), f2 (* arrays ))
0 commit comments