33
44import numpy as np
55import pytest
6- from array_api_compat import array_namespace
76
87import array_api_extra as xpx # Let some tests bypass lazy_xp_function
98from array_api_extra import lazy_apply
109from array_api_extra ._lib import Backend
1110from array_api_extra ._lib ._testing import xp_assert_equal
12- from array_api_extra ._lib ._utils ._typing import Array
11+ from array_api_extra ._lib ._utils import _compat
12+ from array_api_extra ._lib ._utils ._compat import array_namespace
13+ from array_api_extra ._lib ._utils ._typing import Array , Device
1314from array_api_extra .testing import lazy_xp_function
1415
1516lazy_xp_function (
@@ -55,6 +56,8 @@ def f(x: Array) -> Array:
5556
5657@as_numpy
5758def test_lazy_apply_broadcast (xp : ModuleType , as_numpy : bool ):
59+ """Test that default shape and dtype are broadcasted from the inputs."""
60+
5861 def f (x : Array , y : Array ) -> Array :
5962 return x + y
6063
@@ -88,31 +91,117 @@ def f(x: Array) -> tuple[Array, Array]:
8891
8992
9093def test_lazy_apply_core_indices (da : ModuleType ):
91- """Test that a func that performs reductions along axes does so
94+ """
95+ Test that a function that performs reductions along axes does so
9296 globally and not locally to each Dask chunk.
9397 """
94- pytest .skip ("TODO" )
98+
99+ def f (x : Array ) -> Array :
100+ return x .sum (axis = 0 ) + x
101+
102+ x_np = np .arange (15 ).reshape (5 , 3 )
103+ expect = da .asarray (f (x_np ))
104+ x_da = da .asarray (x_np ).rechunk (3 )
105+
106+ # A naive map_blocks fails because it applies f to each chunk separately,
107+ # but f needs to reduce along axis 0 which is broken into multiple chunks.
108+ # axis 0 is a "core axis" or "core index" (from xarray.apply_ufunc's
109+ # "core dimension").
110+ with pytest .raises (AssertionError ):
111+ xp_assert_equal (da .map_blocks (f , x_da ), expect )
112+
113+ xp_assert_equal (lazy_apply (f , x_da ), expect )
95114
96115
97116def test_lazy_apply_dont_run_on_meta (da : ModuleType ):
98117 """Test that Dask won't try running func on the meta array,
99118 as it may have minimum size requirements.
100119 """
101- pytest .skip ("TODO" )
102120
121+ def f (x : Array ) -> Array :
122+ assert x .size
123+ return x + 1
103124
104- def test_lazy_apply_none_shape (da : ModuleType ):
105- pytest .skip ("TODO" )
125+ x = da .arange (10 )
126+ assert not x ._meta .size
127+ y = lazy_apply (f , x )
128+ xp_assert_equal (y , x + 1 )
106129
107130
108- @as_numpy
109- def test_lazy_apply_device (xp : ModuleType , as_numpy : bool ):
110- pytest . skip ( "TODO" )
131+ @pytest . mark . xfail_xp_backend ( Backend . JAX , reason = "unknown shape" )
132+ def test_lazy_apply_none_shape_in_args (xp : ModuleType , library : Backend ):
133+ x = xp . asarray ([ 1 , 1 , 2 , 2 , 2 ] )
111134
135+ xp2 = np if library is Backend .DASK else xp
112136
113- @as_numpy
114- def test_lazy_apply_no_args (xp : ModuleType , as_numpy : bool ):
115- pytest .skip ("TODO" )
137+ # Single output
138+ values = lazy_apply (xp2 .unique_values , x , shape = (None ,))
139+ xp_assert_equal (values , xp .asarray ([1 , 2 ]))
140+
141+ # Multi output
142+ int_type = xp .asarray (0 ).dtype
143+ values , counts = lazy_apply (
144+ xp2 .unique_counts ,
145+ x ,
146+ shape = ((None ,), (None ,)),
147+ dtype = (x .dtype , int_type ),
148+ )
149+ xp_assert_equal (values , xp .asarray ([1 , 2 ]))
150+ xp_assert_equal (counts , xp .asarray ([2 , 3 ]))
151+
152+
153+ def check_lazy_apply_none_shape_broadcast (x : Array ) -> Array :
154+ def f (x : Array ) -> Array :
155+ return x
156+
157+ x = x [x > 1 ]
158+ return lazy_apply (f , x )
159+
160+
161+ lazy_xp_function (check_lazy_apply_none_shape_broadcast )
162+
163+
164+ @pytest .mark .xfail_xp_backend (Backend .SPARSE , reason = "bool mask" )
165+ @pytest .mark .xfail_xp_backend (Backend .JAX , reason = "unknown shape" )
166+ def test_lazy_apply_none_shape_broadcast (xp : ModuleType ):
167+ """Broadcast from input array with unknown shape"""
168+ x = xp .asarray ([1 , 2 , 2 ])
169+ actual = check_lazy_apply_none_shape_broadcast (x )
170+ xp_assert_equal (actual , xp .asarray ([2 , 2 ]))
171+
172+
173+ @pytest .mark .parametrize (
174+ "as_numpy" ,
175+ [
176+ False ,
177+ pytest .param (
178+ True ,
179+ marks = [
180+ pytest .mark .skip_xp_backend (
181+ Backend .ARRAY_API_STRICT , reason = "device->host copy"
182+ ),
183+ pytest .mark .skip_xp_backend (Backend .CUPY , reason = "device->host copy" ),
184+ pytest .mark .skip_xp_backend (Backend .SPARSE , reason = "densification" ),
185+ ],
186+ ),
187+ ],
188+ )
189+ def test_lazy_apply_device (xp : ModuleType , as_numpy : bool , device : Device ):
190+ def f (x : Array ) -> Array :
191+ xp2 = array_namespace (x )
192+ # Deliberately forgetting to add device here to test that the
193+ # output is transferred to the right device. This is necessary when
194+ # as_numpy=True anyway.
195+ return xp2 .zeros (x .shape , dtype = x .dtype )
196+
197+ x = xp .asarray ([1 , 2 ], device = device )
198+ y = lazy_apply (f , x , as_numpy = as_numpy )
199+ assert _compat .device (y ) == device
200+
201+
202+ def test_lazy_apply_no_args (xp : ModuleType ):
203+ with pytest .raises (ValueError , match = "at least one argument" ):
204+ lazy_apply (lambda : xp .zeros (1 ), shape = (1 ,), dtype = xp .zeros (1 ).dtype , xp = xp )
116205
117206
118207class NT (NamedTuple ):
@@ -128,7 +217,8 @@ def eager(
128217 scalar : int ,
129218 ) -> Array :
130219 assert isinstance (x , expect_cls )
131- assert int (x ) == 0 # JAX will crash if x isn't material
220+ # JAX will crash if x isn't material
221+ assert int (x ) == 0 # type: ignore[call-overload]
132222 # Did we re-wrap the namedtuple correctly, or did it get
133223 # accidentally changed to a basic tuple?
134224 assert isinstance (z ["foo" ], NT )
0 commit comments