 import numpy as np
 import pytest
+import scipy
 
-import pytensor.tensor as pt
+from pytensor import config, function
+from pytensor.tensor.basic import switch
+from pytensor.tensor.math import all as pt_all
+from pytensor.tensor.math import any as pt_any
+from pytensor.tensor.math import exp, isinf, log, mul, prod
+from pytensor.tensor.math import max as pt_max
+from pytensor.tensor.math import min as pt_min
+from pytensor.tensor.math import sum as pt_sum
+from pytensor.tensor.special import SoftmaxGrad, softmax
+from pytensor.tensor.type import matrix, vector, vectors
 from tests.link.mlx.test_basic import compare_mlx_and_py
 
 
 mx = pytest.importorskip("mlx.core")
 
 
-@pytest.mark.parametrize("op", [pt.any, pt.all, pt.max, pt.min])
+@pytest.mark.parametrize("op", [pt_any, pt_all, pt_max, pt_min])
 def test_input(op) -> None:
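+    # `x > 0` yields a boolean tensor; each parametrized reduction must
+    # accept bool input under the MLX backend.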
-    x = pt.vector("x")
+    x = vector("x")
     out = op(x > 0)
     x_test = mx.array([1.0, 2.0, 3.0])
 
     compare_mlx_and_py([x], out, [x_test])
 
 
-def test_elemwise_operations() -> None:
-    """Test elemwise operations (IntDiv, IsNan, Erfc, Erfcx, Softplus) in elemwise context"""
-    x = pt.vector("x")
-    y = pt.vector("y")
-
-    # Test int_div in an elemwise expression
-    out_int_div = pt.int_div(x, y) + 1
-    x_test = mx.array([10.0, 15.0, 20.0])
-    y_test = mx.array([3.0, 4.0, 6.0])
-    compare_mlx_and_py([x, y], out_int_div, [x_test, y_test])
-
-    # Test isnan in an elemwise expression
-    z = pt.vector("z")
-    out_isnan = pt.isnan(z).astype("float32") * 10
-    z_test = mx.array([1.0, np.nan, 3.0])
-    compare_mlx_and_py([z], out_isnan, [z_test])
-
-    # Test erfc in an elemwise expression
-    w = pt.vector("w")
-    out_erfc = pt.erfc(w) * 2.0
-    w_test = mx.array([0.0, 0.5, 1.0])
-    compare_mlx_and_py([w], out_erfc, [w_test])
-
-    # Test erfcx in an elemwise expression
-    v = pt.vector("v")
-    out_erfcx = pt.erfcx(v) + 0.1
-    v_test = mx.array([0.0, 1.0, 2.0])
-    compare_mlx_and_py([v], out_erfcx, [v_test])
-
-    # Test softplus in an elemwise expression
-    u = pt.vector("u")
-    out_softplus = pt.softplus(u) - 0.5
-    u_test = mx.array([0.0, 1.0, -1.0])
-    compare_mlx_and_py([u], out_softplus, [u_test])
+def test_mlx_CAReduce():
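+    # sum, prod, and all each lower to CAReduce nodes; exercise both full
+    # (axis=None) and per-axis reductions through the MLX backend.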
+    a_pt = vector("a")
+    a_pt.tag.test_value = np.r_[1, 2, 3].astype(config.floatX)
+
+    x = pt_sum(a_pt, axis=None)
+
+    compare_mlx_and_py([a_pt], [x], [np.r_[1, 2, 3].astype(config.floatX)])
+
+    a_pt = matrix("a")
+    a_pt.tag.test_value = np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)
+
+    x = pt_sum(a_pt, axis=0)
+
+    compare_mlx_and_py([a_pt], [x], [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)])
+
+    x = pt_sum(a_pt, axis=1)
+
+    compare_mlx_and_py([a_pt], [x], [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)])
+
+    a_pt = matrix("a")
+    a_pt.tag.test_value = np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)
+
+    x = prod(a_pt, axis=0)
+
+    compare_mlx_and_py([a_pt], [x], [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)])
+
+    x = pt_all(a_pt)
+
+    compare_mlx_and_py([a_pt], [x], [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)])
+
+
+@pytest.mark.parametrize("axis", [None, 0, 1])
+def test_softmax(axis):
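+    # With axis=None the softmax normalizes over every element of the
+    # matrix rather than along a single dimension.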
+    x = matrix("x")
+    x_test_value = np.arange(6, dtype=config.floatX).reshape(2, 3)
+    out = softmax(x, axis=axis)
+    compare_mlx_and_py([x], [out], [x_test_value])
+
+
+@pytest.mark.parametrize("axis", [None, 0, 1])
+def test_softmax_grad(axis):
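+    # SoftmaxGrad is the explicit gradient Op: it takes the upstream
+    # gradient `dy` and the softmax output `sm`, not the original input.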
+    dy = matrix("dy")
+    dy_test_value = np.array([[1, 1, 1], [0, 0, 0]], dtype=config.floatX)
+    sm = matrix("sm")
+    sm_test_value = np.arange(6, dtype=config.floatX).reshape(2, 3)
+    out = SoftmaxGrad(axis=axis)(dy, sm)
+
+    compare_mlx_and_py([dy, sm], [out], [dy_test_value, sm_test_value])
+
+
+@pytest.mark.parametrize("size", [(10, 10), (1000, 1000)])
+@pytest.mark.parametrize("axis", [0, 1])
+def test_logsumexp_benchmark(size, axis, benchmark):
+    X = matrix("X")
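+    # Numerically stable log-sum-exp: subtract the per-axis max before
+    # exponentiating; the switch() zeroes out infinite maxima so the
+    # subtraction cannot produce nan (inf - inf).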
+    X_max = pt_max(X, axis=axis, keepdims=True)
+    X_max = switch(isinf(X_max), 0, X_max)
+    X_lse = log(pt_sum(exp(X - X_max), axis=axis, keepdims=True)) + X_max
+
+    rng = np.random.default_rng(23920)
+    X_val = rng.normal(size=size)
+
+    X_lse_fn = function([X], X_lse, mode="MLX")
+
+    # JIT compile first
+    _ = X_lse_fn(X_val)
+
+    res = benchmark(X_lse_fn, X_val)
+
+    exp_res = scipy.special.logsumexp(X_val, axis=axis, keepdims=True)
+    np.testing.assert_array_almost_equal(res, exp_res)
+
+
+def test_multiple_input_multiply():
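+    # `mul` is variadic, so the MLX Elemwise conversion must handle more
+    # than two inputs in a single node.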
+    x, y, z = vectors("x", "y", "z")
+    out = mul(x, y, z)
+    compare_mlx_and_py([x, y, z], [out], test_inputs=[[1.5], [2.5], [3.5]])