Skip to content

Commit fd90c5b

Browse files
authored
fix(deps): use jnp.moveaxis instead of jax.batching.moveaxis (#72)
[Jax 0.7.1](https://docs.jax.dev/en/latest/changelog.html) deprecates `jax.batching.moveaxis` and recommends switching to `jnp.moveaxis` #### Description of changes Followed the recommendation! #### Testing done CI passes with jax 0.7.1 (and whatever version the lockfile specifies on github)
1 parent 5bfc9d0 commit fd90c5b

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

tesseract_jax/primitive.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from collections.abc import Sequence
66
from typing import Any, TypeVar
77

8+
import jax.numpy as jnp
89
import jax.tree
910
import numpy as np
1011
from jax import ShapeDtypeStruct, dtypes, extend
@@ -246,7 +247,7 @@ def tesseract_dispatch_batching(
246247
) -> Any:
247248
"""Defines how to dispatch batch operations such as vmap (which is used by jax.jacobian)."""
248249
new_args = [
249-
arg if ax is batching.not_mapped else batching.moveaxis(arg, ax, 0)
250+
arg if ax is batching.not_mapped else jnp.moveaxis(arg, ax, 0)
250251
for arg, ax in zip(array_args, axes, strict=True)
251252
]
252253

0 commit comments

Comments
 (0)