We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 5bfc9d0 commit 5531ddaCopy full SHA for 5531dda
tesseract_jax/primitive.py
@@ -5,6 +5,7 @@
5
from collections.abc import Sequence
6
from typing import Any, TypeVar
7
8
+import jax.numpy as jnp
9
import jax.tree
10
import numpy as np
11
from jax import ShapeDtypeStruct, dtypes, extend
@@ -246,7 +247,7 @@ def tesseract_dispatch_batching(
246
247
) -> Any:
248
"""Defines how to dispatch batch operations such as vmap (which is used by jax.jacobian)."""
249
new_args = [
- arg if ax is batching.not_mapped else batching.moveaxis(arg, ax, 0)
250
+ arg if ax is batching.not_mapped else jnp.moveaxis(arg, ax, 0)
251
for arg, ax in zip(array_args, axes, strict=True)
252
]
253
0 commit comments