Skip to content

Commit edacc0e

Browse files
committed
add test for dot
1 parent d25f214 commit edacc0e

File tree

1 file changed

+19
-0
lines changed

1 file changed

+19
-0
lines changed
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import numpy as np
2+
3+
import pytensor
4+
from pytensor.tensor.type import matrix
5+
6+
7+
def test_mlx_dot():
8+
x = matrix("x")
9+
y = matrix("y")
10+
11+
out = x.dot(y)
12+
fn = pytensor.function([x, y], out, mode="MLX")
13+
14+
test_x = np.random.normal(size=(3, 2))
15+
test_y = np.random.normal(size=(2, 4))
16+
np.testing.assert_allclose(
17+
fn(test_x, test_y),
18+
np.dot(test_x, test_y),
19+
)

0 commit comments

Comments
 (0)