Skip to content

Commit 2589de4

Browse files
Use importorskip in mlx tests
1 parent ed0d687 commit 2589de4

File tree

4 files changed

+12
-7
lines changed

4 files changed

+12
-7
lines changed

.github/workflows/test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,7 @@ jobs:
289289
shell: micromamba-shell {0}
290290
run: |
291291
export PYTENSOR_FLAGS=mode=FAST_COMPILE,warn__ignore_bug_before=all,on_opt_error=raise,on_shape_error=raise,gcc__cxxflags=-pipe
292-
python -m pytest --runslow --benchmark-only --benchmark-json output.json --ignore tests/link/mlx
292+
python -m pytest --runslow --benchmark-only --benchmark-json output.json
293293
- name: Store benchmark result
294294
uses: benchmark-action/github-action-benchmark@v1
295295
with:

tests/link/mlx/test_basic.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from collections.abc import Callable, Iterable
66
from functools import partial
77

8-
import mlx.core as mx
98
import numpy as np
109
import pytest
1110

@@ -16,13 +15,12 @@
1615
from pytensor.graph import RewriteDatabaseQuery
1716
from pytensor.graph.basic import Variable
1817
from pytensor.link.mlx import MLXLinker
19-
from pytensor.link.mlx.dispatch.core import (
20-
mlx_funcify_Alloc,
21-
)
2218
from pytensor.raise_op import assert_op
2319
from pytensor.tensor.basic import Alloc
2420

2521

22+
mx = pytest.importorskip("mlx.core")
23+
2624
optimizer = RewriteDatabaseQuery(include=["mlx"], exclude=MLX._optimizer.exclude)
2725
mlx_mode = Mode(linker=MLXLinker(), optimizer=optimizer)
2826
mlx_mode_no_compile = Mode(linker=MLXLinker(use_compile=False), optimizer=optimizer)
@@ -196,6 +194,9 @@ def test_alloc_with_different_shape_types():
196194
This addresses the TypeError that occurred when shape parameters
197195
contained MLX arrays instead of Python integers.
198196
"""
197+
from pytensor.link.mlx.dispatch.core import (
198+
mlx_funcify_Alloc,
199+
)
199200

200201
# Create a mock node (we don't need a real node for this test)
201202
class MockNode:

tests/link/mlx/test_elemwise.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
1-
import mlx.core as mx
21
import numpy as np
32
import pytest
43

54
import pytensor.tensor as pt
65
from tests.link.mlx.test_basic import compare_mlx_and_py
76

87

8+
mx = pytest.importorskip("mlx.core")
9+
10+
911
@pytest.mark.parametrize("op", [pt.any, pt.all, pt.max, pt.min])
1012
def test_input(op) -> None:
1113
x = pt.vector("x")

tests/link/mlx/test_math.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import mlx.core as mx
21
import numpy as np
32
import pytest
43

@@ -8,6 +7,9 @@
87
from tests.link.mlx.test_basic import compare_mlx_and_py
98

109

10+
mx = pytest.importorskip("mlx.core")
11+
12+
1113
def test_dot():
1214
x = pt.matrix("x")
1315
y = pt.matrix("y")

0 commit comments

Comments
 (0)