Skip to content

Commit 242aba7

Browse files
committed
add more tests
1 parent 60acb8d commit 242aba7

File tree

1 file changed

+17
-4
lines changed

1 file changed

+17
-4
lines changed

tests/link/mlx/test_math.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,13 @@ def test_dot():
2727

2828
@pytest.mark.parametrize(
2929
"op",
30-
[pt.exp, pt.log, pt.sin, pt.cos],
31-
ids=["exp", "log", "sin", "cos"],
30+
[
31+
pytest.param(pt.exp, id="exp"),
32+
pytest.param(pt.log, id="log"),
33+
pytest.param(pt.sin, id="sin"),
34+
pytest.param(pt.cos, id="cos"),
35+
pytest.param(pt.sigmoid, id="sigmoid"),
36+
],
3237
)
3338
def test_elemwise_one_input(op) -> None:
3439
x = pt.vector("x")
@@ -39,8 +44,16 @@ def test_elemwise_one_input(op) -> None:
3944

4045
@pytest.mark.parametrize(
4146
"op",
42-
[pt.add, pt.sub, pt.mul],
43-
ids=["add", "sub", "mul"],
47+
[
48+
pytest.param(pt.add, id="add"),
49+
pytest.param(pt.sub, id="sub"),
50+
pytest.param(pt.mul, id="mul"),
51+
pytest.param(pt.power, id="power"),
52+
pytest.param(pt.le, id="le"),
53+
pytest.param(pt.lt, id="lt"),
54+
pytest.param(pt.ge, id="ge"),
55+
pytest.param(pt.gt, id="gt"),
56+
],
4457
)
4558
def test_elemwise_two_inputs(op) -> None:
4659
x = pt.vector("x")

0 commit comments

Comments
 (0)