|
37 | 37 | vector,
|
38 | 38 | )
|
39 | 39 | from pytensor.tensor.basic import MakeVector, concatenate, expand_dims, make_vector
|
| 40 | +from pytensor.tensor.blas import Dot22, Gemv |
| 41 | +from pytensor.tensor.blas_c import CGemv |
40 | 42 | from pytensor.tensor.elemwise import DimShuffle, Elemwise
|
41 | 43 | from pytensor.tensor.math import sum as pt_sum
|
42 | 44 | from pytensor.tensor.rewriting.subtensor_lift import (
|
@@ -178,6 +180,48 @@ def test_local_subtensor_of_elemwise_multiple_clients(self):
|
178 | 180 | assert local_subtensor_of_elemwise.transform(fgraph, out2.owner) is not None
|
179 | 181 |
|
180 | 182 |
|
| 183 | +def test_local_subtensor_of_dot(): |
| 184 | + m1 = matrix() |
| 185 | + m2 = matrix() |
| 186 | + d1 = np.arange(6).reshape((3, 2)).astype(config.floatX) |
| 187 | + d2 = np.arange(8).reshape((2, 4)).astype(config.floatX) + 10 |
| 188 | + mode = get_default_mode().including("local_subtensor_of_dot") |
| 189 | + |
| 190 | + def test_equality(a, b): |
| 191 | + return a.shape == b.shape and np.allclose(a, b) |
| 192 | + |
| 193 | + # [cst] |
| 194 | + f = function([m1, m2], pt.dot(m1, m2)[1], mode=mode) |
| 195 | + topo = f.maker.fgraph.toposort() |
| 196 | + assert test_equality(f(d1, d2), np.dot(d1, d2)[1]) |
| 197 | + # DimShuffle happen in FAST_COMPILE |
| 198 | + assert isinstance(topo[-1].op, CGemv | Gemv | DimShuffle) |
| 199 | + |
| 200 | + # slice |
| 201 | + f = function([m1, m2], pt.dot(m1, m2)[1:2], mode=mode) |
| 202 | + topo = f.maker.fgraph.toposort() |
| 203 | + assert test_equality(f(d1, d2), np.dot(d1, d2)[1:2]) |
| 204 | + assert isinstance(topo[-1].op, Dot22) |
| 205 | + |
| 206 | + m1 = tensor3() |
| 207 | + m2 = tensor3() |
| 208 | + idx = iscalar() |
| 209 | + d1 = np.arange(30).reshape(2, 5, 3).astype(config.floatX) |
| 210 | + d2 = np.arange(72).reshape(4, 3, 6).astype(config.floatX) + 100 |
| 211 | + |
| 212 | + f = function([m1, m2, idx], pt.dot(m1, m2)[idx, 1:4, :, idx:], mode=mode) |
| 213 | + assert test_equality(f(d1, d2, 1), np.dot(d1, d2)[1, 1:4, :, 1:]) |
| 214 | + # if we return the gradients. We need to use same mode as before. |
| 215 | + assert check_stack_trace(f, ops_to_check="last") |
| 216 | + |
| 217 | + f = function([m1, m2, idx], pt.dot(m1, m2)[1:4, :, idx:, idx], mode=mode) |
| 218 | + assert test_equality(f(d1, d2, 1), np.dot(d1, d2)[1:4, :, 1:, 1]) |
| 219 | + |
| 220 | + # Now test that the stack trace is copied over properly, |
| 221 | + # if we return the gradients. We need to use same mode as before. |
| 222 | + assert check_stack_trace(f, ops_to_check="last") |
| 223 | + |
| 224 | + |
181 | 225 | @pytest.mark.parametrize(
|
182 | 226 | "original_fn, expected_fn",
|
183 | 227 | [
|
|
0 commit comments