Skip to content

Commit 3744a18

Browse files
committed
test conv1d case
1 parent e308f83 commit 3744a18

File tree

1 file changed

+64
-0
lines changed

1 file changed

+64
-0
lines changed

tests/link/mlx/test_blockwise.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
import numpy as np
2+
3+
import pytensor.tensor as pt
4+
from pytensor.tensor import tensor
5+
from pytensor.tensor.blockwise import Blockwise
6+
from pytensor.tensor.math import Dot
7+
from tests.link.mlx.test_basic import compare_mlx_and_py
8+
9+
10+
# Equivalent blockwise to matmul but with dumb signature
11+
odd_matmul = Blockwise(Dot(), signature="(i00,i01),(i10,i11)->(o00,o01)")
12+
13+
14+
# @pytest.mark.parametrize("matmul_op", (matmul, odd_matmul))
15+
# def test_matmul(matmul_op):
16+
# rng = np.random.default_rng(14)
17+
# a = tensor("a", shape=(2, 3, 5))
18+
# b = tensor("b", shape=(2, 5, 3))
19+
# test_values = [
20+
# rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (a, b)
21+
# ]
22+
#
23+
# out = matmul_op(a, b)
24+
# assert isinstance(out.owner.op, Blockwise)
25+
# fn, _ = compare_mlx_and_py([a, b], [out], test_values)
26+
#
27+
## Check we are not adding any unnecessary stuff
28+
# jaxpr = str(jax.make_jaxpr(fn.vm.jit_fn)(*test_values))
29+
# jaxpr = jaxpr.replace("name=jax_funcified_fgraph", "name=matmul")
30+
# expected_jaxpr = str(jax.make_jaxpr(jax.jit(jax.numpy.matmul))(*test_values))
31+
# assert jaxpr == expected_jaxpr
32+
33+
34+
# conv1d
35+
# (2, 100)
36+
# (8, 100)
37+
# mode = valid
38+
39+
40+
def test_blockwise_conv1d():
41+
rng = np.random.default_rng(14)
42+
a = tensor("a", shape=(2, 100))
43+
b = tensor("b", shape=(2, 8))
44+
45+
# a_test = np.broadcast_to(np.arange(100), (2, 100))
46+
a_test = rng.normal(size=(2, 100))
47+
b_test = rng.normal(size=(2, 8))
48+
# b_test = np.concatenate(
49+
# [
50+
# np.ones((1, 8)),
51+
# np.zeros((1, 8)),
52+
# np.zeros((1, 8)),
53+
# np.array([1, 0, 0, 0, 0, 0, 0, 0]).reshape(1, 8),
54+
# np.array([1, 0, 0, 0, 0, 0, 0, 0]).reshape(1, 8),
55+
# ],
56+
# axis=0,
57+
# )
58+
59+
test_values = [a_test, b_test]
60+
61+
out = pt.signal.convolve1d(a, b, mode="valid")
62+
63+
# assert isinstance(out.owner.op, Blockwise)
64+
compare_mlx_and_py([a, b], [out], test_values, must_be_device_array=True)

0 commit comments

Comments
 (0)