Skip to content

Commit 98f07d0

Browse files
JAX 0.7.2 compat
1 parent e26de73 commit 98f07d0

File tree

3 files changed

+24
-3
lines changed

3 files changed

+24
-3
lines changed

jaxtyping/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,11 @@ def __getattr__(item):
186186
elif item == "ArrayLike":
187187
import jax.typing
188188

189-
return jax.typing.ArrayLike
189+
if jax.__version__ == "0.7.2":
190+
# Fix for https://github.com/jax-ml/jax/issues/31989
191+
return jax.typing.ArrayLike | jax._src.literals.LiteralArray
192+
else:
193+
return jax.typing.ArrayLike
190194
elif item == "PRNGKeyArray":
191195
# New-style `jax.random.key` have scalar shape and dtype `key<foo>`.
192196
# Old-style `jax.random.PRNGKey` have shape `(2,)` and dtype

test/test_array.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,12 @@
1717
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
1818
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
1919

20+
import contextlib
2021
import dataclasses as dc
2122
import sys
2223
from typing import Any, get_args, get_origin, TypeVar, Union
2324

25+
import jax
2426
import jax.numpy as jnp
2527
import jax.random as jr
2628
import numpy as np
@@ -554,6 +556,16 @@ def test_arraylike(typecheck, getkey):
554556
floatlike2 = Float[ArrayLike, ""]
555557
floatlike3 = Float32[ArrayLike, "4"]
556558

559+
def _literal(dtype, dimstr):
560+
out = []
561+
with contextlib.suppress(Exception):
562+
# JAX ==0.7.2
563+
out.append(dtype[jax._src.literals.LiteralArray, dimstr])
564+
with contextlib.suppress(Exception):
565+
# JAX > 0.7.2
566+
out.append(dtype[jax._src.literals.TypedNdArray, dimstr])
567+
return out
568+
557569
assert get_origin(floatlike1) is Union
558570
assert get_origin(floatlike2) is Union
559571
assert get_origin(floatlike3) is Union
@@ -564,6 +576,7 @@ def test_arraylike(typecheck, getkey):
564576
Float32[np.number, ""],
565577
float,
566578
]
579+
+ _literal(Float32, "")
567580
)
568581
assert _to_set(get_args(floatlike2)) == _to_set(
569582
[
@@ -572,12 +585,14 @@ def test_arraylike(typecheck, getkey):
572585
Float[np.number, ""],
573586
float,
574587
]
588+
+ _literal(Float, "")
575589
)
576590
assert _to_set(get_args(floatlike3)) == _to_set(
577591
[
578592
Float32[Array, "4"],
579593
Float32[np.ndarray, "4"],
580594
]
595+
+ _literal(Float32, "4")
581596
)
582597

583598
shaped1 = Shaped[ArrayLike, ""]
@@ -595,12 +610,14 @@ def test_arraylike(typecheck, getkey):
595610
float,
596611
complex,
597612
]
613+
+ _literal(Shaped, "")
598614
)
599615
assert _to_set(get_args(shaped2)) == _to_set(
600616
[
601617
Shaped[Array, "4"],
602618
Shaped[np.ndarray, "4"],
603619
]
620+
+ _literal(Shaped, "4")
604621
)
605622

606623

test/test_tf_dtype.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
# Tensorflow dependency kept in a separate file, so that we can optionally exclude it
22
# more easily.
3-
import tensorflow as tf
4-
53
from jaxtyping import UInt
64

75

86
def test_tf_dtype():
7+
import tensorflow as tf
8+
99
x = tf.constant(1, dtype=tf.uint8)
1010
y = tf.constant(1, dtype=tf.float32)
1111
hint = UInt[tf.Tensor, "..."]

0 commit comments

Comments
 (0)