1
1
import jax .numpy as jnp
2
+ import numpy as np
2
3
3
4
from pytensor .link .jax .dispatch import jax_funcify
4
5
from pytensor .tensor .math import Argmax , Dot , Max
@@ -36,12 +37,10 @@ def argmax(x):
36
37
37
38
# NumPy does not support multiple axes for argmax; this is a
38
39
# work-around
39
- keep_axes = jnp .array (
40
- [i for i in range (x .ndim ) if i not in axes ], dtype = "int64"
41
- )
40
+ keep_axes = np .array ([i for i in range (x .ndim ) if i not in axes ], dtype = "int64" )
42
41
# Not-reduced axes in front
43
42
transposed_x = jnp .transpose (
44
- x , jnp .concatenate ((keep_axes , jnp .array (axes , dtype = "int64" )))
43
+ x , tuple ( np .concatenate ((keep_axes , np .array (axes , dtype = "int64" ) )))
45
44
)
46
45
kept_shape = transposed_x .shape [: len (keep_axes )]
47
46
reduced_shape = transposed_x .shape [len (keep_axes ) :]
@@ -50,9 +49,9 @@ def argmax(x):
50
49
# Otherwise reshape would complain citing float arg
51
50
new_shape = (
52
51
* kept_shape ,
53
- jnp .prod (jnp .array (reduced_shape , dtype = "int64" ), dtype = "int64" ),
52
+ np .prod (np .array (reduced_shape , dtype = "int64" ), dtype = "int64" ),
54
53
)
55
- reshaped_x = transposed_x .reshape (new_shape )
54
+ reshaped_x = transposed_x .reshape (tuple ( new_shape ) )
56
55
57
56
max_idx_res = jnp .argmax (reshaped_x , axis = - 1 ).astype ("int64" )
58
57
0 commit comments