
Commit 7fa8d58

Scipy blas is no longer optional
1 parent d4a2b2b commit 7fa8d58

3 files changed: +35 −76 lines changed

3 files changed

+35
-76
lines changed

pytensor/tensor/blas.py

Lines changed: 14 additions & 44 deletions
@@ -111,50 +111,19 @@
 
 _logger = logging.getLogger("pytensor.tensor.blas")
 
-try:
-    import scipy.linalg.blas
-
-    have_fblas = True
-    try:
-        fblas = scipy.linalg.blas.fblas
-    except AttributeError:
-        # A change merged in Scipy development version on 2012-12-02 replaced
-        # `scipy.linalg.blas.fblas` with `scipy.linalg.blas`.
-        # See http://github.com/scipy/scipy/pull/358
-        fblas = scipy.linalg.blas
-    _blas_gemv_fns = {
-        np.dtype("float32"): fblas.sgemv,
-        np.dtype("float64"): fblas.dgemv,
-        np.dtype("complex64"): fblas.cgemv,
-        np.dtype("complex128"): fblas.zgemv,
-    }
-except ImportError as e:
-    have_fblas = False
-    # This is used in Gemv and ScipyGer. We use CGemv and CGer
-    # when config.blas__ldflags is defined. So we don't need a
-    # warning in that case.
-    if not config.blas__ldflags:
-        _logger.warning(
-            "Failed to import scipy.linalg.blas, and "
-            "PyTensor flag blas__ldflags is empty. "
-            "Falling back on slower implementations for "
-            "dot(matrix, vector), dot(vector, matrix) and "
-            f"dot(vector, vector) ({e!s})"
-        )
-
 
 # If check_init_y() == True we need to initialize y when beta == 0.
 def check_init_y():
+    # TODO: What is going on here?
+    from scipy.linalg.blas import get_blas_funcs
+
     if check_init_y._result is None:
-        if not have_fblas:  # pragma: no cover
-            check_init_y._result = False
-        else:
-            y = float("NaN") * np.ones((2,))
-            x = np.ones((2,))
-            A = np.ones((2, 2))
-            gemv = _blas_gemv_fns[y.dtype]
-            gemv(1.0, A.T, x, 0.0, y, overwrite_y=True, trans=True)
-            check_init_y._result = np.isnan(y).any()
+        y = float("NaN") * np.ones((2,))
+        x = np.ones((2,))
+        A = np.ones((2, 2))
+        gemv = get_blas_funcs("gemv", dtype=y.dtype)
+        gemv(1.0, A.T, x, 0.0, y, overwrite_y=True, trans=True)
+        check_init_y._result = np.isnan(y).any()
 
     return check_init_y._result
 
@@ -211,14 +180,15 @@ def make_node(self, y, alpha, A, x, beta):
         return Apply(self, inputs, [y.type()])
 
     def perform(self, node, inputs, out_storage):
+        from scipy.linalg.blas import get_blas_funcs
+
         y, alpha, A, x, beta = inputs
         if (
-            have_fblas
-            and y.shape[0] != 0
+            y.shape[0] != 0
             and x.shape[0] != 0
-            and y.dtype in _blas_gemv_fns
+            and y.dtype in {"float32", "float64", "complex64", "complex128"}
         ):
-            gemv = _blas_gemv_fns[y.dtype]
+            gemv = get_blas_funcs("gemv", dtype=y.dtype)
 
             if A.shape[0] != y.shape[0] or A.shape[1] != x.shape[0]:
                 raise ValueError(

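The diff above trades the import-time fblas dispatch table for call-time lookups through scipy.linalg.blas.get_blas_funcs, scipy's public API for resolving a dtype-specific BLAS routine. A minimal standalone sketch of that lookup and of the NaN probe in check_init_y (runs outside pytensor; variable names are illustrative):

    import numpy as np
    from scipy.linalg.blas import get_blas_funcs

    y = float("NaN") * np.ones((2,))
    x = np.ones((2,))
    A = np.ones((2, 2))

    # Resolves the routine matching the dtype: dgemv for float64, sgemv for
    # float32, and the complex variants for complex64/complex128.
    gemv = get_blas_funcs("gemv", dtype=y.dtype)

    # With beta == 0 a BLAS may either ignore the initial contents of y or
    # multiply y by beta; the latter propagates the NaNs into the output.
    gemv(1.0, A.T, x, 0.0, y, overwrite_y=True, trans=True)
    print(np.isnan(y).any())  # True here means y must be zero-initialized when beta == 0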
pytensor/tensor/blas_scipy.py

Lines changed: 4 additions & 15 deletions
@@ -2,30 +2,19 @@
 Implementations of BLAS Ops based on scipy's BLAS bindings.
 """
 
-import numpy as np
-
-from pytensor.tensor.blas import Ger, have_fblas
-
-
-if have_fblas:
-    from pytensor.tensor.blas import fblas
-
-    _blas_ger_fns = {
-        np.dtype("float32"): fblas.sger,
-        np.dtype("float64"): fblas.dger,
-        np.dtype("complex64"): fblas.cgeru,
-        np.dtype("complex128"): fblas.zgeru,
-    }
+from pytensor.tensor.blas import Ger
 
 
 class ScipyGer(Ger):
     def perform(self, node, inputs, output_storage):
+        from scipy.linalg.blas import get_blas_funcs
+
         cA, calpha, cx, cy = inputs
         (cZ,) = output_storage
         # N.B. some versions of scipy (e.g. mine) don't actually work
         # in-place on a, even when I tell it to.
         A = cA
-        local_ger = _blas_ger_fns[cA.dtype]
+        local_ger = get_blas_funcs("ger", dtype=cA.dtype)
         if A.size == 0:
             # We don't have to compute anything, A is empty.
             # We need this special case because Numpy considers it
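As in blas.py, the per-dtype _blas_ger_fns table becomes a get_blas_funcs lookup inside perform. A standalone sketch of what the resolved ger routine computes, the rank-1 update A <- A + alpha * outer(x, y) (array names are illustrative):

    import numpy as np
    from scipy.linalg.blas import get_blas_funcs

    x = np.array([1.0, 2.0])
    y = np.array([3.0, 4.0, 5.0])
    A = np.zeros((2, 3), order="F")  # Fortran order lets BLAS update in place

    ger = get_blas_funcs("ger", dtype=A.dtype)   # dger for float64
    out = ger(1.0, x, y, a=A, overwrite_a=True)  # A <- A + 1.0 * outer(x, y)
    print(np.allclose(out, np.outer(x, y)))      # True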
pytensor/tensor/rewriting/blas_scipy.py

Lines changed: 17 additions & 17 deletions
@@ -1,5 +1,5 @@
 from pytensor.graph.rewriting.basic import in2out
-from pytensor.tensor.blas import ger, ger_destructive, have_fblas
+from pytensor.tensor.blas import ger, ger_destructive
 from pytensor.tensor.blas_scipy import scipy_ger_inplace, scipy_ger_no_inplace
 from pytensor.tensor.rewriting.blas import blas_optdb, node_rewriter, optdb
 
@@ -19,19 +19,19 @@ def make_ger_destructive(fgraph, node):
 use_scipy_blas = in2out(use_scipy_ger)
 make_scipy_blas_destructive = in2out(make_ger_destructive)
 
-if have_fblas:
-    # scipy_blas is scheduled in the blas_optdb very late, because scipy sortof
-    # sucks, but it is almost always present.
-    # C implementations should be scheduled earlier than this, so that they take
-    # precedence. Once the original Ger is replaced, then these optimizations
-    # have no effect.
-    blas_optdb.register("scipy_blas", use_scipy_blas, "fast_run", position=100)
-
-    # this matches the InplaceBlasOpt defined in blas.py
-    optdb.register(
-        "make_scipy_blas_destructive",
-        make_scipy_blas_destructive,
-        "fast_run",
-        "inplace",
-        position=50.2,
-    )
+
+# scipy_blas is scheduled in the blas_optdb very late, because scipy sortof
+# sucks [citation needed], but it is almost always present.
+# C implementations should be scheduled earlier than this, so that they take
+# precedence. Once the original Ger is replaced, then these optimizations
+# have no effect.
+blas_optdb.register("scipy_blas", use_scipy_blas, "fast_run", position=100)
+
+# this matches the InplaceBlasOpt defined in blas.py
+optdb.register(
+    "make_scipy_blas_destructive",
+    make_scipy_blas_destructive,
+    "fast_run",
+    "inplace",
+    position=50.2,
+)
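With scipy now a hard dependency, the scipy_blas and make_scipy_blas_destructive registrations run unconditionally instead of hiding behind have_fblas. A hedged usage sketch of the pattern these rewrites target; which concrete Op lands in the compiled graph (CGer vs ScipyGer) depends on whether a C BLAS is configured via blas__ldflags:

    import pytensor
    import pytensor.tensor as pt

    x = pt.vector("x")
    y = pt.vector("y")
    A = pt.matrix("A")

    # A + outer(x, y) matches the Ger pattern rewritten by use_scipy_ger.
    f = pytensor.function([A, x, y], A + pt.outer(x, y))
    pytensor.dprint(f)  # expect CGer or ScipyGer in place of the plain Ger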
