Skip to content

Commit 79eee67

Browse files
committed
Scipy blas is no longer optional
1 parent 6bdfbae commit 79eee67

File tree

4 files changed

+35
-78
lines changed

4 files changed

+35
-78
lines changed

pytensor/tensor/blas.py

Lines changed: 14 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -108,50 +108,19 @@
108108

109109
_logger = logging.getLogger("pytensor.tensor.blas")
110110

111-
try:
112-
import scipy.linalg.blas
113-
114-
have_fblas = True
115-
try:
116-
fblas = scipy.linalg.blas.fblas
117-
except AttributeError:
118-
# A change merged in Scipy development version on 2012-12-02 replaced
119-
# `scipy.linalg.blas.fblas` with `scipy.linalg.blas`.
120-
# See http://github.com/scipy/scipy/pull/358
121-
fblas = scipy.linalg.blas
122-
_blas_gemv_fns = {
123-
np.dtype("float32"): fblas.sgemv,
124-
np.dtype("float64"): fblas.dgemv,
125-
np.dtype("complex64"): fblas.cgemv,
126-
np.dtype("complex128"): fblas.zgemv,
127-
}
128-
except ImportError as e:
129-
have_fblas = False
130-
# This is used in Gemv and ScipyGer. We use CGemv and CGer
131-
# when config.blas__ldflags is defined. So we don't need a
132-
# warning in that case.
133-
if not config.blas__ldflags:
134-
_logger.warning(
135-
"Failed to import scipy.linalg.blas, and "
136-
"PyTensor flag blas__ldflags is empty. "
137-
"Falling back on slower implementations for "
138-
"dot(matrix, vector), dot(vector, matrix) and "
139-
f"dot(vector, vector) ({e!s})"
140-
)
141-
142111

143112
# If check_init_y() == True we need to initialize y when beta == 0.
144113
def check_init_y():
114+
# TODO: What is going on here?
115+
from scipy.linalg.blas import get_blas_funcs
116+
145117
if check_init_y._result is None:
146-
if not have_fblas: # pragma: no cover
147-
check_init_y._result = False
148-
else:
149-
y = float("NaN") * np.ones((2,))
150-
x = np.ones((2,))
151-
A = np.ones((2, 2))
152-
gemv = _blas_gemv_fns[y.dtype]
153-
gemv(1.0, A.T, x, 0.0, y, overwrite_y=True, trans=True)
154-
check_init_y._result = np.isnan(y).any()
118+
y = float("NaN") * np.ones((2,))
119+
x = np.ones((2,))
120+
A = np.ones((2, 2))
121+
gemv = get_blas_funcs("gemv", dtype=y.dtype)
122+
gemv(1.0, A.T, x, 0.0, y, overwrite_y=True, trans=True)
123+
check_init_y._result = np.isnan(y).any()
155124

156125
return check_init_y._result
157126

@@ -208,14 +177,15 @@ def make_node(self, y, alpha, A, x, beta):
208177
return Apply(self, inputs, [y.type()])
209178

210179
def perform(self, node, inputs, out_storage):
180+
from scipy.linalg.blas import get_blas_funcs
181+
211182
y, alpha, A, x, beta = inputs
212183
if (
213-
have_fblas
214-
and y.shape[0] != 0
184+
y.shape[0] != 0
215185
and x.shape[0] != 0
216-
and y.dtype in _blas_gemv_fns
186+
and y.dtype in {"float32", "float64", "complex64", "complex128"}
217187
):
218-
gemv = _blas_gemv_fns[y.dtype]
188+
gemv = get_blas_funcs("gemv", dtype=y.dtype)
219189

220190
if A.shape[0] != y.shape[0] or A.shape[1] != x.shape[0]:
221191
raise ValueError(

pytensor/tensor/blas_scipy.py

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,30 +2,19 @@
22
Implementations of BLAS Ops based on scipy's BLAS bindings.
33
"""
44

5-
import numpy as np
6-
7-
from pytensor.tensor.blas import Ger, have_fblas
8-
9-
10-
if have_fblas:
11-
from pytensor.tensor.blas import fblas
12-
13-
_blas_ger_fns = {
14-
np.dtype("float32"): fblas.sger,
15-
np.dtype("float64"): fblas.dger,
16-
np.dtype("complex64"): fblas.cgeru,
17-
np.dtype("complex128"): fblas.zgeru,
18-
}
5+
from pytensor.tensor.blas import Ger
196

207

218
class ScipyGer(Ger):
229
def perform(self, node, inputs, output_storage):
10+
from scipy.linalg.blas import get_blas_funcs
11+
2312
cA, calpha, cx, cy = inputs
2413
(cZ,) = output_storage
2514
# N.B. some versions of scipy (e.g. mine) don't actually work
2615
# in-place on a, even when I tell it to.
2716
A = cA
28-
local_ger = _blas_ger_fns[cA.dtype]
17+
local_ger = get_blas_funcs("ger", dtype=cA.dtype)
2918
if A.size == 0:
3019
# We don't have to compute anything, A is empty.
3120
# We need this special case because Numpy considers it
Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from pytensor.graph.rewriting.basic import in2out
2-
from pytensor.tensor.blas import ger, ger_destructive, have_fblas
2+
from pytensor.tensor.blas import ger, ger_destructive
33
from pytensor.tensor.blas_scipy import scipy_ger_inplace, scipy_ger_no_inplace
44
from pytensor.tensor.rewriting.blas import blas_optdb, node_rewriter, optdb
55

@@ -19,19 +19,19 @@ def make_ger_destructive(fgraph, node):
1919
use_scipy_blas = in2out(use_scipy_ger)
2020
make_scipy_blas_destructive = in2out(make_ger_destructive)
2121

22-
if have_fblas:
23-
# scipy_blas is scheduled in the blas_optdb very late, because scipy sortof
24-
# sucks, but it is almost always present.
25-
# C implementations should be scheduled earlier than this, so that they take
26-
# precedence. Once the original Ger is replaced, then these optimizations
27-
# have no effect.
28-
blas_optdb.register("scipy_blas", use_scipy_blas, "fast_run", position=100)
29-
30-
# this matches the InplaceBlasOpt defined in blas.py
31-
optdb.register(
32-
"make_scipy_blas_destructive",
33-
make_scipy_blas_destructive,
34-
"fast_run",
35-
"inplace",
36-
position=50.2,
37-
)
22+
23+
# scipy_blas is scheduled in the blas_optdb very late, because scipy sortof
24+
# sucks [citation needed], but it is almost always present.
25+
# C implementations should be scheduled earlier than this, so that they take
26+
# precedence. Once the original Ger is replaced, then these optimizations
27+
# have no effect.
28+
blas_optdb.register("scipy_blas", use_scipy_blas, "fast_run", position=100)
29+
30+
# this matches the InplaceBlasOpt defined in blas.py
31+
optdb.register(
32+
"make_scipy_blas_destructive",
33+
make_scipy_blas_destructive,
34+
"fast_run",
35+
"inplace",
36+
position=50.2,
37+
)

tests/tensor/test_blas_scipy.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import pickle
22

33
import numpy as np
4-
import pytest
54

65
import pytensor
76
from pytensor import tensor as pt
@@ -12,7 +11,6 @@
1211
from tests.unittest_tools import OptimizationTestMixin
1312

1413

15-
@pytest.mark.skipif(not pytensor.tensor.blas_scipy.have_fblas, reason="fblas needed")
1614
class TestScipyGer(OptimizationTestMixin):
1715
def setup_method(self):
1816
self.mode = pytensor.compile.get_default_mode()

0 commit comments

Comments (0)