Skip to content

Commit 82bb964

Browse files
committed
tests for elemwise
1 parent 67bb8da commit 82bb964

File tree

1 file changed

+31
-4
lines changed

1 file changed

+31
-4
lines changed

tests/link/mlx/test_math.py

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
import numpy as np
2+
import pytest
23

34
import pytensor
4-
from pytensor.tensor.type import matrix
5-
from tests.link.mlx.test_basic import mx
5+
import pytensor.tensor as pt
6+
from tests.link.mlx.test_basic import compare_mlx_and_py, mx
67

78

89
def test_dot():
9-
x = matrix("x")
10-
y = matrix("y")
10+
x = pt.matrix("x")
11+
y = pt.matrix("y")
1112

1213
out = x.dot(y)
1314
fn = pytensor.function([x, y], out, mode="MLX")
@@ -22,3 +23,29 @@ def test_dot():
2223
assert isinstance(actual, mx.array)
2324
expected = np.dot(test_x, test_y)
2425
np.testing.assert_allclose(actual, expected, rtol=1e-6)
26+
27+
28+
@pytest.mark.parametrize(
29+
"op",
30+
[pt.exp, pt.log, pt.sin, pt.cos],
31+
ids=["exp", "log", "sin", "cos"],
32+
)
33+
def test_elemwise_one_input(op) -> None:
34+
x = pt.vector("x")
35+
out = op(x)
36+
x_test = mx.array([1.0, 2.0, 3.0])
37+
compare_mlx_and_py([x], out, [x_test])
38+
39+
40+
@pytest.mark.parametrize(
41+
"op",
42+
[pt.add, pt.sub, pt.mul],
43+
ids=["add", "sub", "mul"],
44+
)
45+
def test_elemwise_two_inputs(op) -> None:
46+
x = pt.vector("x")
47+
y = pt.vector("y")
48+
out = op(x, y)
49+
x_test = mx.array([1.0, 2.0, 3.0])
50+
y_test = mx.array([4.0, 5.0, 6.0])
51+
compare_mlx_and_py([x, y], out, [x_test, y_test])

0 commit comments

Comments
 (0)