@@ -31,14 +31,6 @@ class P: # pylint: disable=missing-class-docstring
3131 kwargs : dict
3232
3333
34- class UnknownShapeError (ValueError ):
35- """
36- `shape` contains one or more None elements.
37-
38- This is unsupported when running inside `jax.jit`.
39- """
40-
41-
4234@overload
4335def apply_numpy_func ( # type: ignore[valid-type]
4436 func : Callable [P , NumPyObject ],
@@ -148,15 +140,14 @@ def apply_numpy_func( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
148140
149141 Raises
150142 ------
151- UnknownShapeError
152- When `shape` is unknown (one or more sizes are None) and this function was
153- called inside `jax.jit`.
154-
155- Exception (varies)
156-
157- - When the backend disallows implicit device to host transfers and the input
158- arrays are on a device, e.g. on GPU;
159- - When the backend is sparse and auto-densification is disabled.
143+ jax.errors.TracerArrayConversionError
144+ When `xp=jax.numpy`, `shape` is unknown (it contains None on one or more axes)
145+ and this function was called inside `jax.jit`.
146+ RuntimeError
147+ When `xp=sparse` and auto-densification is disabled.
148+ Exception (backend-specific)
149+ When the backend disallows implicit device to host transfers and the input
150+ arrays are on a device, e.g. on GPU.
160151
161152 See Also
162153 --------
@@ -241,15 +232,8 @@ def apply_numpy_func( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
241232 if any (s is None for shape in shapes for s in shape ):
242233 # Unknown output shape. Won't work with jax.jit, but it
243234 # can work with eager jax.
244- try :
245- out = wrapped (* args , ** kwargs )
246- except jax .errors .TracerArrayConversionError :
247- msg = (
248- "jax.jit can't delay application of numpy functions when the shape "
249- "of the returned array(s) is unknown. "
250- f"shape={ shapes if multi_output else shapes [0 ]} "
251- )
252- raise UnknownShapeError (msg ) from None
235+ # Raises jax.errors.TracerArrayConversionError if we're inside jax.jit.
236+ out = wrapped (* args , ** kwargs )
253237
254238 else :
255239 out = cast (
0 commit comments