|
3 | 3 |
|
4 | 4 | import pytensor |
5 | 5 | import pytensor.tensor as pt |
6 | | -from pytensor.graph.replace import vectorize_node |
7 | | -from pytensor.tensor import tensor |
8 | | -from pytensor.tensor.blockwise import Blockwise |
9 | | -from pytensor.tensor.nlinalg import MatrixInverse |
10 | 6 | from pytensor.tensor.shape import specify_broadcastable |
11 | 7 |
|
12 | 8 |
|
13 | 9 | torch = pytest.importorskip("torch") |
14 | 10 |
|
15 | 11 |
|
16 | | -def test_vectorize_blockwise(): |
17 | | - mat = tensor(shape=(None, None)) |
18 | | - tns = tensor(shape=(None, None, None)) |
19 | | - |
20 | | - # Something that falls back to Blockwise |
21 | | - node = MatrixInverse()(mat).owner |
22 | | - vect_node = vectorize_node(node, tns) |
23 | | - assert isinstance(vect_node.op, Blockwise) and isinstance( |
24 | | - vect_node.op.core_op, MatrixInverse |
25 | | - ) |
26 | | - assert vect_node.op.signature == ("(m,m)->(m,m)") |
27 | | - assert vect_node.inputs[0] is tns |
28 | | - |
29 | | - # Useless blockwise |
30 | | - tns4 = tensor(shape=(5, None, None, None)) |
31 | | - new_vect_node = vectorize_node(vect_node, tns4) |
32 | | - assert new_vect_node.op is vect_node.op |
33 | | - assert isinstance(new_vect_node.op, Blockwise) and isinstance( |
34 | | - new_vect_node.op.core_op, MatrixInverse |
35 | | - ) |
36 | | - assert new_vect_node.inputs[0] is tns4 |
37 | | - |
38 | | - |
39 | 12 | def test_blockwise_broadcast(): |
40 | 13 | _x = np.random.rand(5, 1, 2, 3) |
41 | 14 | _y = np.random.rand(3, 3, 2) |
|
0 commit comments