Skip to content

Commit bad7c90

Browse files
Copy jax CARReduce test
1 parent 9f41a4e commit bad7c90

File tree

1 file changed

+89
-37
lines changed

1 file changed

+89
-37
lines changed

tests/link/mlx/test_elemwise.py

Lines changed: 89 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,53 +1,105 @@
11
import numpy as np
22
import pytest
3+
import scipy
34

4-
import pytensor.tensor as pt
5+
from pytensor import config, function
6+
from pytensor.tensor.basic import switch
7+
from pytensor.tensor.math import all as pt_all
8+
from pytensor.tensor.math import any as pt_any
9+
from pytensor.tensor.math import exp, isinf, log, mul, prod
10+
from pytensor.tensor.math import max as pt_max
11+
from pytensor.tensor.math import min as pt_min
12+
from pytensor.tensor.math import sum as pt_sum
13+
from pytensor.tensor.special import SoftmaxGrad, softmax
14+
from pytensor.tensor.type import matrix, vector, vectors
515
from tests.link.mlx.test_basic import compare_mlx_and_py
616

717

818
mx = pytest.importorskip("mlx.core")
919

1020

11-
@pytest.mark.parametrize("op", [pt.any, pt.all, pt.max, pt.min])
21+
@pytest.mark.parametrize("op", [pt_any, pt_all, pt_max, pt_min])
1222
def test_input(op) -> None:
13-
x = pt.vector("x")
23+
x = vector("x")
1424
out = op(x > 0)
1525
x_test = mx.array([1.0, 2.0, 3.0])
1626

1727
compare_mlx_and_py([x], out, [x_test])
1828

1929

20-
def test_elemwise_operations() -> None:
21-
"""Test elemwise operations (IntDiv, IsNan, Erfc, Erfcx, Softplus) in elemwise context"""
22-
x = pt.vector("x")
23-
y = pt.vector("y")
24-
25-
# Test int_div in an elemwise expression
26-
out_int_div = pt.int_div(x, y) + 1
27-
x_test = mx.array([10.0, 15.0, 20.0])
28-
y_test = mx.array([3.0, 4.0, 6.0])
29-
compare_mlx_and_py([x, y], out_int_div, [x_test, y_test])
30-
31-
# Test isnan in an elemwise expression
32-
z = pt.vector("z")
33-
out_isnan = pt.isnan(z).astype("float32") * 10
34-
z_test = mx.array([1.0, np.nan, 3.0])
35-
compare_mlx_and_py([z], out_isnan, [z_test])
36-
37-
# Test erfc in an elemwise expression
38-
w = pt.vector("w")
39-
out_erfc = pt.erfc(w) * 2.0
40-
w_test = mx.array([0.0, 0.5, 1.0])
41-
compare_mlx_and_py([w], out_erfc, [w_test])
42-
43-
# Test erfcx in an elemwise expression
44-
v = pt.vector("v")
45-
out_erfcx = pt.erfcx(v) + 0.1
46-
v_test = mx.array([0.0, 1.0, 2.0])
47-
compare_mlx_and_py([v], out_erfcx, [v_test])
48-
49-
# Test softplus in an elemwise expression
50-
u = pt.vector("u")
51-
out_softplus = pt.softplus(u) - 0.5
52-
u_test = mx.array([0.0, 1.0, -1.0])
53-
compare_mlx_and_py([u], out_softplus, [u_test])
30+
def test_mlx_CAReduce():
31+
a_pt = vector("a")
32+
a_pt.tag.test_value = np.r_[1, 2, 3].astype(config.floatX)
33+
34+
x = pt_sum(a_pt, axis=None)
35+
36+
compare_mlx_and_py([a_pt], [x], [np.r_[1, 2, 3].astype(config.floatX)])
37+
38+
a_pt = matrix("a")
39+
a_pt.tag.test_value = np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)
40+
41+
x = pt_sum(a_pt, axis=0)
42+
43+
compare_mlx_and_py([a_pt], [x], [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)])
44+
45+
x = pt_sum(a_pt, axis=1)
46+
47+
compare_mlx_and_py([a_pt], [x], [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)])
48+
49+
a_pt = matrix("a")
50+
a_pt.tag.test_value = np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)
51+
52+
x = prod(a_pt, axis=0)
53+
54+
compare_mlx_and_py([a_pt], [x], [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)])
55+
56+
x = pt_all(a_pt)
57+
58+
compare_mlx_and_py([a_pt], [x], [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)])
59+
60+
61+
@pytest.mark.parametrize("axis", [None, 0, 1])
62+
def test_softmax(axis):
63+
x = matrix("x")
64+
x_test_value = np.arange(6, dtype=config.floatX).reshape(2, 3)
65+
out = softmax(x, axis=axis)
66+
compare_mlx_and_py([x], [out], [x_test_value])
67+
68+
69+
@pytest.mark.parametrize("axis", [None, 0, 1])
70+
def test_softmax_grad(axis):
71+
dy = matrix("dy")
72+
dy_test_value = np.array([[1, 1, 1], [0, 0, 0]], dtype=config.floatX)
73+
sm = matrix("sm")
74+
sm_test_value = np.arange(6, dtype=config.floatX).reshape(2, 3)
75+
out = SoftmaxGrad(axis=axis)(dy, sm)
76+
77+
compare_mlx_and_py([dy, sm], [out], [dy_test_value, sm_test_value])
78+
79+
80+
@pytest.mark.parametrize("size", [(10, 10), (1000, 1000)])
81+
@pytest.mark.parametrize("axis", [0, 1])
82+
def test_logsumexp_benchmark(size, axis, benchmark):
83+
X = matrix("X")
84+
X_max = pt_max(X, axis=axis, keepdims=True)
85+
X_max = switch(isinf(X_max), 0, X_max)
86+
X_lse = log(pt_sum(exp(X - X_max), axis=axis, keepdims=True)) + X_max
87+
88+
rng = np.random.default_rng(23920)
89+
X_val = rng.normal(size=size)
90+
91+
X_lse_fn = function([X], X_lse, mode="MLX")
92+
93+
# JIT compile first
94+
_ = X_lse_fn(X_val)
95+
96+
res = benchmark(X_lse_fn, X_val)
97+
98+
exp_res = scipy.special.logsumexp(X_val, axis=axis, keepdims=True)
99+
np.testing.assert_array_almost_equal(res, exp_res)
100+
101+
102+
def test_multiple_input_multiply():
103+
x, y, z = vectors("xyz")
104+
out = mul(x, y, z)
105+
compare_mlx_and_py([x, y, z], [out], test_inputs=[[1.5], [2.5], [3.5]])

0 commit comments

Comments
 (0)