Skip to content

Commit d607e23

Browse files
committed
Do not generate C code for BatchedDot when BLAS flags are missing
1 parent 3f960de commit d607e23

File tree

2 files changed

+29
-0
lines changed

2 files changed

+29
-0
lines changed

pytensor/tensor/blas.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1795,6 +1795,10 @@ def c_header_dirs(self, **kwargs):
17951795
return ldflags(libs=False, include_dir=True)
17961796

17971797
def c_code(self, node, name, inp, out, sub):
1798+
# Can only compile if linked to blas libraries
1799+
if len(self.c_libraries()) <= 0:
1800+
raise NotImplementedError()
1801+
17981802
_x, _y = inp
17991803
(_z,) = out
18001804
fail = sub["fail"]

tests/tensor/test_blas.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from pytensor.tensor import inplace
2424
from pytensor.tensor.basic import as_tensor_variable
2525
from pytensor.tensor.blas import (
26+
BatchedDot,
2627
Dot22,
2728
Dot22Scalar,
2829
Gemm,
@@ -2700,6 +2701,30 @@ def check_first_dim(inverted):
27002701
check_first_dim(inverted)
27012702

27022703

2704+
def test_batched_dot_blas_flags():
2705+
"""Test that BatchedDot works regardless of presence of BLAS flags"""
2706+
mode = "FAST_RUN"
2707+
rng = np.random.default_rng(2708)
2708+
2709+
x = tensor("x", shape=(2, 5, 3))
2710+
y = tensor("y", shape=(2, 3, 1))
2711+
out = batched_dot(x, y)
2712+
assert isinstance(out.owner.op, BatchedDot)
2713+
x_test = rng.normal(size=x.type.shape).astype(x.type.dtype)
2714+
y_test = rng.normal(size=y.type.shape).astype(y.type.dtype)
2715+
2716+
fn = function([x, y], out, mode=mode)
2717+
[batched_dot_thunk] = fn.vm.thunks
2718+
assert hasattr(batched_dot_thunk, "cthunk")
2719+
np.testing.assert_allclose(fn(x_test, y_test), x_test @ y_test)
2720+
2721+
with config.change_flags(blas__ldflags=""):
2722+
fn = function([x, y], out, mode=mode)
2723+
[batched_dot_thunk] = fn.vm.thunks
2724+
assert not hasattr(batched_dot_thunk, "cthunk")
2725+
np.testing.assert_allclose(fn(x_test, y_test), x_test @ y_test)
2726+
2727+
27032728
def test_batched_tensordot():
27042729
rng = np.random.default_rng(unittest_tools.fetch_seed())
27052730
first = tensor4("first")

0 commit comments

Comments
 (0)