Skip to content

Commit 5531dda

Browse files
committed
fix(deps): use jnp.moveaxis instead of jax.batching.moveaxis
1 parent 5bfc9d0 commit 5531dda

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

tesseract_jax/primitive.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from collections.abc import Sequence
66
from typing import Any, TypeVar
77

8+
import jax.numpy as jnp
89
import jax.tree
910
import numpy as np
1011
from jax import ShapeDtypeStruct, dtypes, extend
@@ -246,7 +247,7 @@ def tesseract_dispatch_batching(
246247
) -> Any:
247248
"""Defines how to dispatch batch operations such as vmap (which is used by jax.jacobian)."""
248249
new_args = [
249-
arg if ax is batching.not_mapped else batching.moveaxis(arg, ax, 0)
250+
arg if ax is batching.not_mapped else jnp.moveaxis(arg, ax, 0)
250251
for arg, ax in zip(array_args, axes, strict=True)
251252
]
252253

0 commit comments

Comments
 (0)