Skip to content

Commit 1e6addd

Browse files
committed
bring argmax test
1 parent 880dd5c commit 1e6addd

File tree

1 file changed

+12
-0
lines changed

1 file changed

+12
-0
lines changed

tests/link/mlx/test_math.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import pytensor
55
import pytensor.tensor as pt
6+
from pytensor.tensor.math import Argmax, Max
67
from tests.link.mlx.test_basic import compare_mlx_and_py, mx
78

89

@@ -87,3 +88,14 @@ def test_elemwise_two_inputs(op) -> None:
8788
x_test = mx.array([1.0, 2.0, 3.0])
8889
y_test = mx.array([4.0, 5.0, 6.0])
8990
compare_mlx_and_py([x, y], out, [x_test, y_test])
91+
92+
93+
@pytest.mark.xfail(reason="Argmax not implemented yet")
94+
def test_mlx_max_and_argmax():
95+
# Test that a single output of a multi-output `Op` can be used as input to
96+
# another `Op`
97+
x = pt.dvector()
98+
mx = Max([0])(x)
99+
amx = Argmax([0])(x)
100+
out = mx * amx
101+
compare_mlx_and_py([x], [out], [np.r_[1, 2]])

0 commit comments

Comments
 (0)