|
23 | 23 | from pytensor.tensor import inplace
|
24 | 24 | from pytensor.tensor.basic import as_tensor_variable
|
25 | 25 | from pytensor.tensor.blas import (
|
| 26 | + BatchedDot, |
26 | 27 | Dot22,
|
27 | 28 | Dot22Scalar,
|
28 | 29 | Gemm,
|
@@ -2700,6 +2701,30 @@ def check_first_dim(inverted):
|
2700 | 2701 | check_first_dim(inverted)
|
2701 | 2702 |
|
2702 | 2703 |
|
def test_batched_dot_blas_flags():
    """BatchedDot must compile and evaluate correctly with or without BLAS flags."""
    rng = np.random.default_rng(2708)

    lhs = tensor("x", shape=(2, 5, 3))
    rhs = tensor("y", shape=(2, 3, 1))
    result = batched_dot(lhs, rhs)
    assert isinstance(result.owner.op, BatchedDot)

    lhs_val = rng.normal(size=lhs.type.shape).astype(lhs.type.dtype)
    rhs_val = rng.normal(size=rhs.type.shape).astype(rhs.type.dtype)
    expected = lhs_val @ rhs_val

    # With BLAS flags present (default), the op should compile to a C thunk.
    compiled = function([lhs, rhs], result, mode="FAST_RUN")
    [thunk] = compiled.vm.thunks
    assert hasattr(thunk, "cthunk")
    np.testing.assert_allclose(compiled(lhs_val, rhs_val), expected)

    # With BLAS flags cleared, the op must fall back to a non-C thunk
    # yet still produce the same numerical result.
    with config.change_flags(blas__ldflags=""):
        compiled = function([lhs, rhs], result, mode="FAST_RUN")
        [thunk] = compiled.vm.thunks
        assert not hasattr(thunk, "cthunk")
        np.testing.assert_allclose(compiled(lhs_val, rhs_val), expected)
| 2727 | + |
2703 | 2728 | def test_batched_tensordot():
|
2704 | 2729 | rng = np.random.default_rng(unittest_tools.fetch_seed())
|
2705 | 2730 | first = tensor4("first")
|
|
0 commit comments