|
111 | 111 |
|
112 | 112 | _logger = logging.getLogger("pytensor.tensor.blas") |
113 | 113 |
|
114 | | -try: |
115 | | - import scipy.linalg.blas |
116 | | - |
117 | | - have_fblas = True |
118 | | - try: |
119 | | - fblas = scipy.linalg.blas.fblas |
120 | | - except AttributeError: |
121 | | - # A change merged in Scipy development version on 2012-12-02 replaced |
122 | | - # `scipy.linalg.blas.fblas` with `scipy.linalg.blas`. |
123 | | - # See http://github.com/scipy/scipy/pull/358 |
124 | | - fblas = scipy.linalg.blas |
125 | | - _blas_gemv_fns = { |
126 | | - np.dtype("float32"): fblas.sgemv, |
127 | | - np.dtype("float64"): fblas.dgemv, |
128 | | - np.dtype("complex64"): fblas.cgemv, |
129 | | - np.dtype("complex128"): fblas.zgemv, |
130 | | - } |
131 | | -except ImportError as e: |
132 | | - have_fblas = False |
133 | | - # This is used in Gemv and ScipyGer. We use CGemv and CGer |
134 | | - # when config.blas__ldflags is defined. So we don't need a |
135 | | - # warning in that case. |
136 | | - if not config.blas__ldflags: |
137 | | - _logger.warning( |
138 | | - "Failed to import scipy.linalg.blas, and " |
139 | | - "PyTensor flag blas__ldflags is empty. " |
140 | | - "Falling back on slower implementations for " |
141 | | - "dot(matrix, vector), dot(vector, matrix) and " |
142 | | - f"dot(vector, vector) ({e!s})" |
143 | | - ) |
144 | | - |
145 | 114 |
|
146 | 115 | # If check_init_y() == True we need to initialize y when beta == 0. |
def check_init_y():
    """Return True if ``y`` must be initialized before calling gemv with beta == 0.

    Some BLAS implementations compute ``beta * y + alpha * A @ x`` literally,
    so with ``beta == 0`` an uninitialized ``y`` containing NaN/Inf propagates
    garbage into the result (``0 * NaN == NaN``).  Others special-case
    ``beta == 0`` and never read ``y``.  This probe detects which behavior the
    linked BLAS has, and the answer is cached on the function itself in
    ``check_init_y._result`` (expected to be pre-set to ``None`` at module
    level — TODO confirm where ``_result`` is initialized).

    Returns
    -------
    bool
        True if callers must zero-initialize ``y`` when ``beta == 0``.
    """
    if check_init_y._result is None:
        # Lazy import: only needed for the one-time probe, and keeps the
        # scipy dependency off the cached fast path.
        from scipy.linalg.blas import get_blas_funcs

        # Probe: start y as all-NaN and run gemv with beta == 0.  If any NaN
        # survives, the BLAS scaled y by beta instead of ignoring it, so y
        # needs explicit initialization.
        y = np.full((2,), np.nan)
        x = np.ones((2,))
        A = np.ones((2, 2))
        gemv = get_blas_funcs("gemv", dtype=y.dtype)
        gemv(1.0, A.T, x, 0.0, y, overwrite_y=True, trans=True)
        # bool(...) so we cache a plain Python bool, not np.bool_.
        check_init_y._result = bool(np.isnan(y).any())

    return check_init_y._result
160 | 129 |
|
@@ -211,14 +180,15 @@ def make_node(self, y, alpha, A, x, beta): |
211 | 180 | return Apply(self, inputs, [y.type()]) |
212 | 181 |
|
213 | 182 | def perform(self, node, inputs, out_storage): |
| 183 | + from scipy.linalg.blas import get_blas_funcs |
| 184 | + |
214 | 185 | y, alpha, A, x, beta = inputs |
215 | 186 | if ( |
216 | | - have_fblas |
217 | | - and y.shape[0] != 0 |
| 187 | + y.shape[0] != 0 |
218 | 188 | and x.shape[0] != 0 |
219 | | - and y.dtype in _blas_gemv_fns |
| 189 | + and y.dtype in {"float32", "float64", "complex64", "complex128"} |
220 | 190 | ): |
221 | | - gemv = _blas_gemv_fns[y.dtype] |
| 191 | + gemv = get_blas_funcs("gemv", dtype=y.dtype) |
222 | 192 |
|
223 | 193 | if A.shape[0] != y.shape[0] or A.shape[1] != x.shape[0]: |
224 | 194 | raise ValueError( |
|
0 commit comments