|
34 | 34 | from pytensor.tensor.basic import Alloc, constant, join, second, switch
|
35 | 35 | from pytensor.tensor.blas import Dot22, Gemv
|
36 | 36 | from pytensor.tensor.blas_c import CGemv
|
| 37 | +from pytensor.tensor.blockwise import Blockwise |
37 | 38 | from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
|
38 | 39 | from pytensor.tensor.math import Dot, MaxAndArgmax, Prod, Sum, _conj
|
39 | 40 | from pytensor.tensor.math import abs as pt_abs
|
@@ -4427,3 +4428,51 @@ def test_polygamma_specialization():
|
4427 | 4428 | assert isinstance(fn_outs[0].owner.op.scalar_op, Psi)
|
4428 | 4429 | assert isinstance(fn_outs[1].owner.op.scalar_op, TriGamma)
|
4429 | 4430 | assert isinstance(fn_outs[2].owner.op.scalar_op, PolyGamma)
|
| 4431 | + |
| 4432 | + |
@pytest.mark.skipif(
    config.mode == "FAST_COMPILE",
    reason="Rewrite is only relevant in FAST_RUN",
)
def test_local_batched_matmul_to_core_matmul():
    """Matmul with exactly one batched operand should be rewritten so that
    no ``Blockwise`` node survives compilation; when both operands are
    batched the rewrite does not apply, but the compiled function must
    still compute the correct product.
    """
    rng = np.random.default_rng(seed=4433)

    def assert_no_blockwise(compiled_fn):
        # After rewriting, the compiled graph must contain no Blockwise node.
        assert not any(
            isinstance(node.op, Blockwise)
            for node in compiled_fn.maker.fgraph.apply_nodes
        )

    # Case 1: only the left operand carries a batch dimension.
    x = pt.tensor("x", shape=(None, 3, 2), dtype="float64")
    y = pt.tensor("y", shape=(2, 2), dtype="float64")
    batched_out = x @ y
    assert isinstance(batched_out.owner.op, Blockwise)

    fn = pytensor.function([x, y], batched_out)
    assert_no_blockwise(fn)

    x_val = rng.normal(size=(5, 3, 2))
    y_val = rng.normal(size=(2, 2))
    np.testing.assert_allclose(fn(x_val, y_val), x_val @ y_val)

    # Case 2: only the right operand carries a (non-broadcastable) batch
    # dimension; the left operand's leading dimension is broadcastable (1).
    x = pt.tensor("x", shape=(1, 3, 2), dtype="float64")
    y = pt.tensor("y", shape=(5, 2, 2), dtype="float64")
    batched_out = x @ y
    assert isinstance(batched_out.owner.op, Blockwise)

    fn = pytensor.function([x, y], batched_out)
    assert_no_blockwise(fn)

    x_val = rng.normal(size=(1, 3, 2))
    y_val = rng.normal(size=(5, 2, 2))
    np.testing.assert_allclose(fn(x_val, y_val), x_val @ y_val)

    # Case 3: both operands are batched — the rewrite does not apply, so we
    # only verify numerical correctness of the compiled function.
    x = pt.tensor("x", shape=(None, 3, 2), dtype="float64")
    y = pt.tensor("y", shape=(5, 2, 2), dtype="float64")
    batched_out = x @ y

    fn = pytensor.function([x, y], batched_out)
    x_val = rng.normal(size=(5, 3, 2))
    y_val = rng.normal(size=(5, 2, 2))
    np.testing.assert_allclose(fn(x_val, y_val), x_val @ y_val)
0 commit comments