Skip to content

Commit d70e03c

Browse files
committed
finalize
1 parent edf2ce8 commit d70e03c

File tree

2 files changed

+64
-11
lines changed

2 files changed

+64
-11
lines changed

src/array_api_extra/_lib/_lazy.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -222,21 +222,19 @@ def lazy_apply( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
222222
dtypes = [xp.result_type(*args)] * len(shapes)
223223
elif multi_output:
224224
if not isinstance(dtype, Sequence):
225-
msg = "Got sequence of shapes but only one dtype"
226-
raise TypeError(msg)
225+
msg = "Got multiple shapes but only one dtype"
226+
raise ValueError(msg)
227227
dtypes = list(dtype) # pyright: ignore[reportUnknownArgumentType]
228228
else:
229229
if isinstance(dtype, Sequence):
230230
msg = "Got single shape but multiple dtypes"
231-
raise TypeError(msg)
231+
raise ValueError(msg)
232+
232233
dtypes = [dtype]
233234

234235
if len(shapes) != len(dtypes):
235236
msg = f"Got {len(shapes)} shapes and {len(dtypes)} dtypes"
236237
raise ValueError(msg)
237-
if len(shapes) == 0:
238-
msg = "func must return one or more output arrays"
239-
raise ValueError(msg)
240238
del shape
241239
del dtype
242240

tests/test_lazy.py

Lines changed: 60 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,48 @@ def f(x: Array) -> tuple[Array, Array]:
9090
xp_assert_equal(actual[1], expect[1])
9191

9292

93+
@pytest.mark.parametrize(
94+
"as_numpy",
95+
[
96+
pytest.param(
97+
False,
98+
marks=[
99+
pytest.mark.xfail_xp_backend(
100+
Backend.TORCH, reason="illegal dtype promotion"
101+
),
102+
],
103+
),
104+
pytest.param(
105+
True,
106+
marks=[
107+
pytest.mark.skip_xp_backend(Backend.CUPY, reason="device->host copy"),
108+
pytest.mark.skip_xp_backend(Backend.SPARSE, reason="densification"),
109+
],
110+
),
111+
],
112+
)
113+
def test_lazy_apply_multi_output_broadcast_dtype(xp: ModuleType, as_numpy: bool):
114+
"""
115+
If dtype is omitted and there are multiple shapes, use the same
116+
dtype for all output arrays, broadcasted from the inputs
117+
"""
118+
119+
def f(x: Array, y: Array) -> tuple[Array, Array]:
120+
return x + y, x - y
121+
122+
x = xp.asarray([1, 2], dtype=xp.float32)
123+
y = xp.asarray(3, dtype=xp.float64)
124+
expect = (
125+
xp.asarray([4, 5], dtype=xp.float64),
126+
xp.asarray([-2, -1], dtype=xp.float64),
127+
)
128+
actual = lazy_apply(f, x, y, shape=((2,), (2,)), as_numpy=as_numpy)
129+
assert isinstance(actual, tuple)
130+
assert len(actual) == 2
131+
xp_assert_equal(actual[0], expect[0])
132+
xp_assert_equal(actual[1], expect[1])
133+
134+
93135
def test_lazy_apply_core_indices(da: ModuleType):
94136
"""
95137
Test that a function that performs reductions along axes does so
@@ -199,11 +241,6 @@ def f(x: Array) -> Array:
199241
assert _compat.device(y) == device
200242

201243

202-
def test_lazy_apply_no_args(xp: ModuleType):
203-
with pytest.raises(ValueError, match="at least one argument"):
204-
lazy_apply(lambda: xp.zeros(1), shape=(1,), dtype=xp.zeros(1).dtype, xp=xp)
205-
206-
207244
class NT(NamedTuple):
208245
a: Array
209246

@@ -292,3 +329,21 @@ def test_lazy_apply_raises(xp: ModuleType) -> None:
292329
# exception not to be raised.
293330
# However, lazy_xp_function will do it for us on function exit.
294331
raises(x)
332+
333+
334+
def test_invalid_args():
335+
def f(x: Array) -> Array:
336+
return x
337+
338+
x = np.asarray(1)
339+
340+
with pytest.raises(ValueError, match="at least one argument"):
341+
_ = lazy_apply(f, shape=(1,), dtype=np.int32, xp=np)
342+
with pytest.raises(ValueError, match="at least one argument"):
343+
_ = lazy_apply(f, shape=(1,), dtype=np.int32)
344+
with pytest.raises(ValueError, match="multiple shapes but only one dtype"):
345+
_ = lazy_apply(f, x, shape=[(1,), (2,)], dtype=np.int32) # type: ignore[call-overload] # pyright: ignore[reportCallIssue,reportArgumentType]
346+
with pytest.raises(ValueError, match="single shape but multiple dtypes"):
347+
_ = lazy_apply(f, x, shape=(1,), dtype=[np.int32, np.int64])
348+
with pytest.raises(ValueError, match="2 shapes and 1 dtypes"):
349+
_ = lazy_apply(f, x, shape=[(1,), (2,)], dtype=[np.int32])

0 commit comments

Comments
 (0)