Skip to content

Commit a9ecad0

Browse files
committed
wrap in mx.array
1 parent 052fdc2 commit a9ecad0

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

tests/link/mlx/dispatch/test_math.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import mlx.core as mx
12
import numpy as np
23

34
import pytensor
@@ -11,8 +12,8 @@ def test_mlx_dot():
1112
out = x.dot(y)
1213
fn = pytensor.function([x, y], out, mode="MLX")
1314

14-
test_x = np.random.normal(size=(3, 2))
15-
test_y = np.random.normal(size=(2, 4))
15+
test_x = mx.array(np.random.normal(size=(3, 2)))
16+
test_y = mx.array(np.random.normal(size=(2, 4)))
1617
np.testing.assert_allclose(
1718
fn(test_x, test_y),
1819
np.dot(test_x, test_y),

0 commit comments

Comments
 (0)