Skip to content

Commit 8090d8a

Browse files
committed
Remove os.path in blas_headers.py
1 parent 465b3b3 commit 8090d8a

File tree

1 file changed

+22
-31
lines changed

1 file changed

+22
-31
lines changed

pytensor/tensor/blas_headers.py

Lines changed: 22 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import os
1010
import sys
1111
import textwrap
12-
from os.path import dirname
12+
from pathlib import Path
1313

1414
from pytensor.configdefaults import config
1515
from pytensor.link.c.cmodule import GCC_compiler
@@ -743,36 +743,27 @@ def blas_header_text():
743743
blas_code = ""
744744
if not config.blas__ldflags:
745745
# Include the Numpy version implementation of [sd]gemm_.
746-
current_filedir = dirname(__file__)
747-
blas_common_filepath = os.path.join(
748-
current_filedir, "c_code", "alt_blas_common.h"
749-
)
750-
blas_template_filepath = os.path.join(
751-
current_filedir, "c_code", "alt_blas_template.c"
752-
)
753-
common_code = ""
754-
sblas_code = ""
755-
dblas_code = ""
756-
with open(blas_common_filepath) as code:
757-
common_code = code.read()
758-
with open(blas_template_filepath) as code:
759-
template_code = code.read()
760-
sblas_code = template_code % {
761-
"float_type": "float",
762-
"float_size": 4,
763-
"npy_float": "NPY_FLOAT32",
764-
"precision": "s",
765-
}
766-
dblas_code = template_code % {
767-
"float_type": "double",
768-
"float_size": 8,
769-
"npy_float": "NPY_FLOAT64",
770-
"precision": "d",
771-
}
772-
if not common_code or not template_code:
773-
raise OSError(
774-
"Unable to load NumPy implementation of BLAS functions from C source files."
775-
)
746+
current_filedir = Path(__file__).parent
747+
blas_common_filepath = current_filedir / "c_code/alt_blas_common.h"
748+
blas_template_filepath = current_filedir / "c_code/alt_blas_template.c"
749+
try:
750+
common_code = blas_common_filepath.read_text(encoding="utf-8")
751+
template_code = blas_template_filepath.read_text(encoding="utf-8")
752+
except OSError as err:
753+
msg = "Unable to load NumPy implementation of BLAS functions from C source files."
754+
raise OSError(msg) from err
755+
sblas_code = template_code % {
756+
"float_type": "float",
757+
"float_size": 4,
758+
"npy_float": "NPY_FLOAT32",
759+
"precision": "s",
760+
}
761+
dblas_code = template_code % {
762+
"float_type": "double",
763+
"float_size": 8,
764+
"npy_float": "NPY_FLOAT64",
765+
"precision": "d",
766+
}
776767
blas_code += common_code
777768
blas_code += sblas_code
778769
blas_code += dblas_code

0 commit comments

Comments
 (0)