We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent ad29c17 commit ba29b37Copy full SHA for ba29b37
tests/link/mlx/test_math.py
@@ -1,4 +1,3 @@
1
-import mlx.core as mx
2
import numpy as np
3
4
import pytensor
@@ -12,9 +11,12 @@ def test_mlx_dot():
12
11
out = x.dot(y)
13
fn = pytensor.function([x, y], out, mode="MLX")
14
15
- test_x = mx.array(np.random.normal(size=(3, 2)))
16
- test_y = mx.array(np.random.normal(size=(2, 4)))
17
- np.testing.assert_allclose(
18
- fn(test_x, test_y),
19
- np.dot(test_x, test_y),
20
- )
+ seed = sum(map(ord, "test_mlx_dot"))
+ rng = np.random.default_rng(seed)
+
+ test_x = rng.normal(size=(3, 2))
+ test_y = rng.normal(size=(2, 4))
+ actual = fn(test_x, test_y)
21
+ expected = np.dot(test_x, test_y)
22
+ np.testing.assert_allclose(actual, expected)
0 commit comments