Skip to content

Commit 04c3e1a

Browse files
committed
Code review
1 parent dc00bfd commit 04c3e1a

File tree

2 files changed

+38
-22
lines changed

2 files changed

+38
-22
lines changed

src/array_api_extra/_lib/_lazy.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -179,9 +179,9 @@ def lazy_apply( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
179179
180180
Raises
181181
------
182-
jax.errors.TracerArrayConversionError
183-
When ``xp=jax.numpy``, `shape` is unknown (it contains None on one or more axes)
184-
and this function was called inside ``jax.jit``.
182+
ValueError
183+
When ``xp=jax.numpy``, the output `shape` is unknown (it contains ``None`` on
184+
one or more axes) and this function was called inside ``jax.jit``.
185185
RuntimeError
186186
When ``xp=sparse`` and auto-densification is disabled.
187187
Exception (backend-specific)
@@ -272,6 +272,10 @@ def lazy_apply( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
272272
# https://github.com/jax-ml/jax/issues/26102
273273
import jax
274274

275+
if any(None in shape for shape in shapes):
276+
msg = "Output shape must be fully known when running inside jax.jit"
277+
raise ValueError(msg)
278+
275279
# Shield eager kwargs from being coerced into JAX arrays.
276280
# jax.pure_callback calls jax.jit under the hood, but without the chance of
277281
# passing static_argnames / static_argnums.

tests/test_lazy.py

Lines changed: 31 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -191,26 +191,40 @@ def f(x: Array) -> Array:
191191
xp_assert_equal(y.compute(), x_cp + 1) # type: ignore[attr-defined] # pyright: ignore[reportUnknownArgumentType,reportAttributeAccessIssue]
192192

193193

194-
@pytest.mark.xfail_xp_backend(Backend.JAX, reason="unknown shape")
195194
def test_lazy_apply_none_shape_in_args(xp: ModuleType, library: Backend):
196195
x = xp.asarray([1, 1, 2, 2, 2])
197196

198-
xp2 = np if library is Backend.DASK else xp
199-
200-
# Single output
201-
values = lazy_apply(xp2.unique_values, x, shape=(None,))
202-
xp_assert_equal(values, xp.asarray([1, 2]))
203-
204-
# Multi output
197+
# TODO mxp = meta_namespace(x, xp=xp)
198+
mxp = np if library is Backend.DASK else xp
205199
int_type = xp.asarray(0).dtype
206-
values, counts = lazy_apply(
207-
xp2.unique_counts,
208-
x,
209-
shape=((None,), (None,)),
210-
dtype=(x.dtype, int_type),
211-
)
212-
xp_assert_equal(values, xp.asarray([1, 2]))
213-
xp_assert_equal(counts, xp.asarray([2, 3]))
200+
201+
if library is Backend.JAX:
202+
# Single output
203+
with pytest.raises(ValueError, match="Output shape must be fully known"):
204+
_ = lazy_apply(mxp.unique_values, x, shape=(None,))
205+
206+
# Multi output
207+
with pytest.raises(ValueError, match="Output shape must be fully known"):
208+
_ = lazy_apply(
209+
mxp.unique_counts,
210+
x,
211+
shape=((None,), (None,)),
212+
dtype=(x.dtype, int_type),
213+
)
214+
else:
215+
# Single output
216+
values = lazy_apply(mxp.unique_values, x, shape=(None,))
217+
xp_assert_equal(values, xp.asarray([1, 2]))
218+
219+
# Multi output
220+
values, counts = lazy_apply(
221+
mxp.unique_counts,
222+
x,
223+
shape=((None,), (None,)),
224+
dtype=(x.dtype, int_type),
225+
)
226+
xp_assert_equal(values, xp.asarray([1, 2]))
227+
xp_assert_equal(counts, xp.asarray([2, 3]))
214228

215229

216230
def check_lazy_apply_none_shape_broadcast(x: Array) -> Array:
@@ -349,10 +363,8 @@ def eager(
349363
def test_lazy_apply_kwargs(xp: ModuleType, library: Backend, as_numpy: bool):
350364
"""When as_numpy=True, search and replace arrays in the (nested) keywords arguments
351365
with numpy arrays, and leave the rest untouched."""
352-
expect_cls = (
353-
np.ndarray if as_numpy or library is Backend.DASK else type(xp.asarray(0))
354-
)
355366
x = xp.asarray(0)
367+
expect_cls = np.ndarray if as_numpy or library is Backend.DASK else type(x)
356368
actual = check_lazy_apply_kwargs(x, expect_cls, as_numpy) # pyright: ignore[reportUnknownArgumentType]
357369
xp_assert_equal(actual, x + 1)
358370

0 commit comments

Comments
 (0)