11import numpy as np
2+ import pytest
23
34import pytensor
4- from pytensor .tensor . type import matrix
5- from tests .link .mlx .test_basic import mx
5+ import pytensor .tensor as pt
6+ from tests .link .mlx .test_basic import compare_mlx_and_py , mx
67
78
89def test_dot ():
9- x = matrix ("x" )
10- y = matrix ("y" )
10+ x = pt . matrix ("x" )
11+ y = pt . matrix ("y" )
1112
1213 out = x .dot (y )
1314 fn = pytensor .function ([x , y ], out , mode = "MLX" )
@@ -22,3 +23,29 @@ def test_dot():
2223 assert isinstance (actual , mx .array )
2324 expected = np .dot (test_x , test_y )
2425 np .testing .assert_allclose (actual , expected , rtol = 1e-6 )
26+
27+
28+ @pytest .mark .parametrize (
29+ "op" ,
30+ [pt .exp , pt .log , pt .sin , pt .cos ],
31+ ids = ["exp" , "log" , "sin" , "cos" ],
32+ )
33+ def test_elemwise_one_input (op ) -> None :
34+ x = pt .vector ("x" )
35+ out = op (x )
36+ x_test = mx .array ([1.0 , 2.0 , 3.0 ])
37+ compare_mlx_and_py ([x ], out , [x_test ])
38+
39+
40+ @pytest .mark .parametrize (
41+ "op" ,
42+ [pt .add , pt .sub , pt .mul ],
43+ ids = ["add" , "sub" , "mul" ],
44+ )
45+ def test_elemwise_two_inputs (op ) -> None :
46+ x = pt .vector ("x" )
47+ y = pt .vector ("y" )
48+ out = op (x , y )
49+ x_test = mx .array ([1.0 , 2.0 , 3.0 ])
50+ y_test = mx .array ([4.0 , 5.0 , 6.0 ])
51+ compare_mlx_and_py ([x , y ], out , [x_test , y_test ])
0 commit comments