|
9 | 9 | import os
|
10 | 10 | import sys
|
11 | 11 | import textwrap
|
12 |
| -from os.path import dirname |
| 12 | +from pathlib import Path |
13 | 13 |
|
14 | 14 | from pytensor.configdefaults import config
|
15 | 15 | from pytensor.link.c.cmodule import GCC_compiler
|
@@ -743,36 +743,27 @@ def blas_header_text():
|
743 | 743 | blas_code = ""
|
744 | 744 | if not config.blas__ldflags:
|
745 | 745 | # Include the Numpy version implementation of [sd]gemm_.
|
746 |
| - current_filedir = dirname(__file__) |
747 |
| - blas_common_filepath = os.path.join( |
748 |
| - current_filedir, "c_code", "alt_blas_common.h" |
749 |
| - ) |
750 |
| - blas_template_filepath = os.path.join( |
751 |
| - current_filedir, "c_code", "alt_blas_template.c" |
752 |
| - ) |
753 |
| - common_code = "" |
754 |
| - sblas_code = "" |
755 |
| - dblas_code = "" |
756 |
| - with open(blas_common_filepath) as code: |
757 |
| - common_code = code.read() |
758 |
| - with open(blas_template_filepath) as code: |
759 |
| - template_code = code.read() |
760 |
| - sblas_code = template_code % { |
761 |
| - "float_type": "float", |
762 |
| - "float_size": 4, |
763 |
| - "npy_float": "NPY_FLOAT32", |
764 |
| - "precision": "s", |
765 |
| - } |
766 |
| - dblas_code = template_code % { |
767 |
| - "float_type": "double", |
768 |
| - "float_size": 8, |
769 |
| - "npy_float": "NPY_FLOAT64", |
770 |
| - "precision": "d", |
771 |
| - } |
772 |
| - if not common_code or not template_code: |
773 |
| - raise OSError( |
774 |
| - "Unable to load NumPy implementation of BLAS functions from C source files." |
775 |
| - ) |
| 746 | + current_filedir = Path(__file__).parent |
| 747 | + blas_common_filepath = current_filedir / "c_code/alt_blas_common.h" |
| 748 | + blas_template_filepath = current_filedir / "c_code/alt_blas_template.c" |
| 749 | + try: |
| 750 | + common_code = blas_common_filepath.read_text(encoding="utf-8") |
| 751 | + template_code = blas_template_filepath.read_text(encoding="utf-8") |
| 752 | + except OSError as err: |
| 753 | + msg = "Unable to load NumPy implementation of BLAS functions from C source files." |
| 754 | + raise OSError(msg) from err |
| 755 | + sblas_code = template_code % { |
| 756 | + "float_type": "float", |
| 757 | + "float_size": 4, |
| 758 | + "npy_float": "NPY_FLOAT32", |
| 759 | + "precision": "s", |
| 760 | + } |
| 761 | + dblas_code = template_code % { |
| 762 | + "float_type": "double", |
| 763 | + "float_size": 8, |
| 764 | + "npy_float": "NPY_FLOAT64", |
| 765 | + "precision": "d", |
| 766 | + } |
776 | 767 | blas_code += common_code
|
777 | 768 | blas_code += sblas_code
|
778 | 769 | blas_code += dblas_code
|
|
0 commit comments