Skip to content

Commit 2fc81bc

Browse files
committed
fix for carlos
1 parent 5c97bc8 commit 2fc81bc

File tree

2 files changed

+16
-5
lines changed

2 files changed

+16
-5
lines changed

tests/link/mlx/test_elemwise.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
1+
import pytest
2+
13
import pytensor.tensor as pt
24
from tests.link.mlx.test_basic import compare_mlx_and_py, mx
35

46

5-
def test_all() -> None:
7+
@pytest.mark.parametrize("op", [pt.any, pt.all, pt.max, pt.min])
8+
def test_input(op) -> None:
69
x = pt.vector("x")
7-
8-
out = pt.all(x > 0)
9-
10-
x_test = mx.array([-1.0, 2.0, 3.0])
10+
out = op(x > 0)
11+
x_test = mx.array([1.0, 2.0, 3.0])
1112

1213
compare_mlx_and_py([x], out, [x_test])

tests/link/mlx/test_math.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,16 @@ def test_switch() -> None:
5454
compare_mlx_and_py([x, y], out, [x_test, y_test])
5555

5656

57+
@pytest.mark.parametrize("op", [pt.sum, pt.prod])
58+
def test_input(op) -> None:
59+
x = pt.vector("x")
60+
y = pt.vector("y")
61+
out = op([x, y, x + y])
62+
x_test = mx.array([1.0, 2.0, 3.0])
63+
y_test = mx.array([4.0, 5.0, 6.0])
64+
compare_mlx_and_py([x, y], out, [x_test, y_test])
65+
66+
5767
@pytest.mark.parametrize(
5868
"op",
5969
[

0 commit comments

Comments
 (0)