Skip to content

Commit fb5eb52

Browse files
committed
make compare function
1 parent 67a74fb commit fb5eb52

File tree

1 file changed

+74
-0
lines changed

1 file changed

+74
-0
lines changed

tests/link/mlx/test_basic.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,78 @@
1+
from collections.abc import Callable, Iterable
2+
from functools import partial
3+
4+
import numpy as np
15
import pytest
26

7+
from pytensor.compile.function import function
8+
from pytensor.compile.mode import Mode
9+
from pytensor.graph.basic import Variable
10+
from pytensor.link.mlx import MLXLinker
11+
312

413
mx = pytest.importorskip("mlx.core")
14+
15+
mlx_mode = Mode(linker=MLXLinker())
16+
py_mode = Mode(linker="py", optimizer=None)
17+
18+
19+
def compare_mlx_and_py(
20+
graph_inputs: Iterable[Variable],
21+
graph_outputs: Variable | Iterable[Variable],
22+
test_inputs: Iterable,
23+
*,
24+
assert_fn: Callable | None = None,
25+
must_be_device_array: bool = True,
26+
mlx_mode=mlx_mode,
27+
py_mode=py_mode,
28+
):
29+
"""Function to compare python function output and mlx compiled output for testing equality
30+
31+
The inputs and outputs are then passed to this function which then compiles the given function in both
32+
mlx and python, runs the calculation in both and checks if the results are the same
33+
34+
Parameters
35+
----------
36+
graph_inputs:
37+
Symbolic inputs to the graph
38+
outputs:
39+
Symbolic outputs of the graph
40+
test_inputs: iter
41+
Numerical inputs for testing the function.
42+
assert_fn: func, opt
43+
Assert function used to check for equality between python and mlx. If not
44+
provided uses np.testing.assert_allclose
45+
must_be_device_array: Bool
46+
Checks for instance of jax.interpreters.xla.DeviceArray. For testing purposes
47+
if this device array is found it indicates if the result was computed by jax
48+
49+
Returns
50+
-------
51+
mlx_res
52+
53+
"""
54+
if assert_fn is None:
55+
assert_fn = partial(np.testing.assert_allclose, rtol=1e-4)
56+
57+
if any(inp.owner is not None for inp in graph_inputs):
58+
raise ValueError("Inputs must be root variables")
59+
60+
pytensor_mlx_fn = function(graph_inputs, graph_outputs, mode=mlx_mode)
61+
mlx_res = pytensor_mlx_fn(*test_inputs)
62+
63+
if must_be_device_array:
64+
if isinstance(mlx_res, list):
65+
assert all(isinstance(res, mx.array) for res in mlx_res)
66+
else:
67+
assert isinstance(mlx_res, mx.array)
68+
69+
pytensor_py_fn = function(graph_inputs, graph_outputs, mode=py_mode)
70+
py_res = pytensor_py_fn(*test_inputs)
71+
72+
if isinstance(graph_outputs, list | tuple):
73+
for j, p in zip(mlx_res, py_res, strict=True):
74+
assert_fn(j, p)
75+
else:
76+
assert_fn(mlx_res, py_res)
77+
78+
return pytensor_mlx_fn, mlx_res

0 commit comments

Comments
 (0)