Skip to content

Commit 89af81b

Browse files
committed
Make Blas flags check lazy
It replaces the old warning that does not actually apply by a more informative and actionable one. This warning was for Ops that might use the alternative blas_headers, which rely on the Numpy C-API. However, regular PyTensor user has not used this for a while. The only Op that would use C-code with this alternative headers is the GEMM Op which is not included in current rewrites. Instead Dot22 or Dot22Scalar are introduced, which refuse to generate C-code altogether if the blas flags are missing.
1 parent a0fe30d commit 89af81b

File tree

2 files changed

+12
-4
lines changed

2 files changed

+12
-4
lines changed

pytensor/link/c/cmodule.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2947,6 +2947,13 @@ def check_libs(
29472947
except Exception as e:
29482948
_logger.debug(e)
29492949
_logger.debug("Failed to identify blas ldflags. Will leave them empty.")
2950+
warnings.warn(
2951+
"PyTensor could not link to a BLAS installation. Operations that might benefit from BLAS will be severely degraded.\n"
2952+
"This usually happens when PyTensor is installed via pip. We recommend it be installed via conda/mamba/pixi instead.\n"
2953+
"Alternatively, you can use a experimental backend such as Numba or JAX that perform their own BLAS optimizations, "
2954+
"by setting `pytensor.config.mode == 'NUMBA'` or passing `mode='NUMBA'` when compiling a PyTensor function.",
2955+
UserWarning,
2956+
)
29502957
return ""
29512958

29522959

pytensor/tensor/blas_headers.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -742,6 +742,11 @@ def blas_header_text():
742742

743743
blas_code = ""
744744
if not config.blas__ldflags:
745+
# This code can only be reached by compiling a function with a manually specified GEMM Op.
746+
# Normal PyTensor usage will end up with Dot22 or Dot22Scalar instead,
747+
# which opt out of C-code completely if the blas flags are missing
748+
_logger.warning("Using NumPy C-API based implementation for BLAS functions.")
749+
745750
# Include the Numpy version implementation of [sd]gemm_.
746751
current_filedir = Path(__file__).parent
747752
blas_common_filepath = current_filedir / "c_code/alt_blas_common.h"
@@ -1003,10 +1008,6 @@ def blas_header_text():
10031008
return header + blas_code
10041009

10051010

1006-
if not config.blas__ldflags:
1007-
_logger.warning("Using NumPy C-API based implementation for BLAS functions.")
1008-
1009-
10101011
def mkl_threads_text():
10111012
"""C header for MKL threads interface"""
10121013
header = """

0 commit comments

Comments
 (0)