Skip to content

Commit 1d4f1cb

Browse files
committed
Fixed JAX unit tests for new structure
1 parent 8eb150b commit 1d4f1cb

File tree

2 files changed

+15
-14
lines changed

2 files changed

+15
-14
lines changed

pytensor/link/jax/dispatch/__init__.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,20 @@
22
from pytensor.link.jax.dispatch.basic import jax_funcify, jax_typify
33

44
# Load dispatch specializations
5-
import pytensor.link.jax.dispatch.scalar
6-
import pytensor.link.jax.dispatch.tensor_basic
7-
import pytensor.link.jax.dispatch.subtensor
8-
import pytensor.link.jax.dispatch.shape
5+
import pytensor.link.jax.dispatch.blas
6+
import pytensor.link.jax.dispatch.blockwise
7+
import pytensor.link.jax.dispatch.elemwise
98
import pytensor.link.jax.dispatch.extra_ops
9+
import pytensor.link.jax.dispatch.math
1010
import pytensor.link.jax.dispatch.nlinalg
11-
import pytensor.link.jax.dispatch.slinalg
1211
import pytensor.link.jax.dispatch.random
13-
import pytensor.link.jax.dispatch.elemwise
12+
import pytensor.link.jax.dispatch.scalar
1413
import pytensor.link.jax.dispatch.scan
15-
import pytensor.link.jax.dispatch.sparse
16-
import pytensor.link.jax.dispatch.blockwise
14+
import pytensor.link.jax.dispatch.shape
15+
import pytensor.link.jax.dispatch.slinalg
1716
import pytensor.link.jax.dispatch.sort
17+
import pytensor.link.jax.dispatch.sparse
18+
import pytensor.link.jax.dispatch.subtensor
19+
import pytensor.link.jax.dispatch.tensor_basic
1820

1921
# isort: on

pytensor/link/jax/dispatch/math.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import jax.numpy as jnp
2+
import numpy as np
23

34
from pytensor.link.jax.dispatch import jax_funcify
45
from pytensor.tensor.math import Argmax, Dot, Max
@@ -36,12 +37,10 @@ def argmax(x):
3637

3738
# NumPy does not support multiple axes for argmax; this is a
3839
# work-around
39-
keep_axes = jnp.array(
40-
[i for i in range(x.ndim) if i not in axes], dtype="int64"
41-
)
40+
keep_axes = np.array([i for i in range(x.ndim) if i not in axes], dtype="int64")
4241
# Not-reduced axes in front
4342
transposed_x = jnp.transpose(
44-
x, jnp.concatenate((keep_axes, jnp.array(axes, dtype="int64")))
43+
x, tuple(np.concatenate((keep_axes, np.array(axes, dtype="int64"))))
4544
)
4645
kept_shape = transposed_x.shape[: len(keep_axes)]
4746
reduced_shape = transposed_x.shape[len(keep_axes) :]
@@ -50,9 +49,9 @@ def argmax(x):
5049
# Otherwise reshape would complain citing float arg
5150
new_shape = (
5251
*kept_shape,
53-
jnp.prod(jnp.array(reduced_shape, dtype="int64"), dtype="int64"),
52+
np.prod(np.array(reduced_shape, dtype="int64"), dtype="int64"),
5453
)
55-
reshaped_x = transposed_x.reshape(new_shape)
54+
reshaped_x = transposed_x.reshape(tuple(new_shape))
5655

5756
max_idx_res = jnp.argmax(reshaped_x, axis=-1).astype("int64")
5857

0 commit comments

Comments
 (0)