Skip to content

Commit bc98e09

Browse files
committed
test for switch with mlx
1 parent 6cb47fc commit bc98e09

File tree

2 files changed

+16
-2
lines changed

2 files changed

+16
-2
lines changed

tests/link/mlx/test_basic.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,16 @@
55
import pytest
66

77
from pytensor.compile.function import function
8-
from pytensor.compile.mode import Mode
8+
from pytensor.compile.mode import MLX, Mode
9+
from pytensor.graph import RewriteDatabaseQuery
910
from pytensor.graph.basic import Variable
1011
from pytensor.link.mlx import MLXLinker
1112

1213

1314
mx = pytest.importorskip("mlx.core")
1415

15-
mlx_mode = Mode(linker=MLXLinker())
16+
optimizer = RewriteDatabaseQuery(include=["mlx"], exclude=MLX._optimizer.exclude)
17+
mlx_mode = Mode(linker=MLXLinker(), optimizer=optimizer)
1618
py_mode = Mode(linker="py", optimizer=None)
1719

1820

tests/link/mlx/test_math.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,18 @@ def test_elemwise_one_input(op) -> None:
4242
compare_mlx_and_py([x], out, [x_test])
4343

4444

45+
def test_switch() -> None:
46+
x = pt.vector("x")
47+
y = pt.vector("y")
48+
49+
out = pt.switch(x > 0, y, x)
50+
51+
x_test = mx.array([-1.0, 2.0, 3.0])
52+
y_test = mx.array([4.0, 5.0, 6.0])
53+
54+
compare_mlx_and_py([x, y], out, [x_test, y_test])
55+
56+
4557
@pytest.mark.parametrize(
4658
"op",
4759
[

0 commit comments

Comments
 (0)