We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 052fdc2 commit a9ecad0Copy full SHA for a9ecad0
tests/link/mlx/dispatch/test_math.py
@@ -1,3 +1,4 @@
1
+import mlx.core as mx
2
import numpy as np
3
4
import pytensor
@@ -11,8 +12,8 @@ def test_mlx_dot():
11
12
out = x.dot(y)
13
fn = pytensor.function([x, y], out, mode="MLX")
14
- test_x = np.random.normal(size=(3, 2))
15
- test_y = np.random.normal(size=(2, 4))
+ test_x = mx.array(np.random.normal(size=(3, 2)))
16
+ test_y = mx.array(np.random.normal(size=(2, 4)))
17
np.testing.assert_allclose(
18
fn(test_x, test_y),
19
np.dot(test_x, test_y),
0 commit comments