Skip to content

Commit bb45939

Browse files
committed
Remove blas_scipy
1 parent a6b729e commit bb45939

File tree

6 files changed

+19
-152
lines changed

6 files changed

+19
-152
lines changed

pytensor/tensor/basic.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1801,8 +1801,7 @@ def do_constant_folding(self, fgraph, node):
18011801
| pytensor.tensor.blas.Gemv
18021802
| pytensor.tensor.blas_c.CGemv
18031803
| pytensor.tensor.blas.Ger
1804-
| pytensor.tensor.blas_c.CGer
1805-
| pytensor.tensor.blas_scipy.ScipyGer,
1804+
| pytensor.tensor.blas_c.CGer,
18061805
)
18071806
):
18081807
# Ops that will work inplace on the Alloc. So if they

pytensor/tensor/blas.py

Lines changed: 11 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@
8383
from pathlib import Path
8484

8585
import numpy as np
86+
from scipy.linalg import get_blas_funcs
8687

8788
from pytensor.graph import vectorize_graph
8889
from pytensor.npy_2_compat import normalize_axis_tuple
@@ -288,18 +289,15 @@ def make_node(self, A, alpha, x, y):
288289

289290
return Apply(self, inputs, [A.type()])
290291

291-
def perform(self, node, inp, out):
292-
cA, calpha, cx, cy = inp
293-
(cZ,) = out
294-
if self.destructive:
295-
A = cA
296-
else:
297-
A = cA.copy()
298-
if calpha != 1:
299-
A += calpha * np.outer(cx, cy)
292+
def perform(self, node, inputs, output_storage):
293+
A, alpha, x, y = inputs
294+
ger_func = get_blas_funcs("ger", dtype=A.dtype)
295+
if A.flags["C_CONTIGUOUS"]:
296+
# Work on transposed system to avoid copying
297+
A = ger_func(alpha, y, x, a=A.T, overwrite_a=self.destructive).T
300298
else:
301-
A += np.outer(cx, cy)
302-
cZ[0] = A
299+
A = ger_func(alpha, x, y, a=A, overwrite_a=self.destructive)
300+
output_storage[0][0] = A
303301

304302
def infer_shape(self, fgraph, node, input_shapes):
305303
return [input_shapes[0]]
@@ -1128,16 +1126,8 @@ def make_node(self, x, y):
11281126
outputs = [tensor(dtype=x.type.dtype, shape=(x.type.shape[0], y.type.shape[1]))]
11291127
return Apply(self, [x, y], outputs)
11301128

1131-
def perform(self, node, inp, out):
1132-
x, y = inp
1133-
(z,) = out
1134-
try:
1135-
z[0] = np.asarray(np.dot(x, y))
1136-
except ValueError as e:
1137-
# The error raised by numpy has no shape information, we mean to
1138-
# add that
1139-
e.args = (*e.args, x.shape, y.shape)
1140-
raise
1129+
def perform(self, node, inputs, output_storage):
1130+
output_storage[0][0] = np.dot(*inputs)
11411131

11421132
def infer_shape(self, fgraph, node, input_shapes):
11431133
return [[input_shapes[0][0], input_shapes[1][1]]]

pytensor/tensor/blas_scipy.py

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,31 +2,22 @@
22
Implementations of BLAS Ops based on scipy's BLAS bindings.
33
"""
44

5+
from scipy.linalg.blas import get_blas_funcs
6+
57
from pytensor.tensor.blas import Ger
68

79

810
class ScipyGer(Ger):
911
def perform(self, node, inputs, output_storage):
10-
from scipy.linalg.blas import get_blas_funcs
11-
1212
cA, calpha, cx, cy = inputs
1313
(cZ,) = output_storage
14-
# N.B. some versions of scipy (e.g. mine) don't actually work
15-
# in-place on a, even when I tell it to.
1614
A = cA
17-
local_ger = get_blas_funcs("ger", dtype=cA.dtype)
18-
if A.size == 0:
19-
# We don't have to compute anything, A is empty.
20-
# We need this special case because Numpy considers it
21-
# C-contiguous, which is confusing.
22-
if not self.destructive:
23-
# Sometimes numpy thinks empty matrices can share memory,
24-
# so here to stop DebugMode from complaining.
25-
A = A.copy()
26-
elif A.flags["C_CONTIGUOUS"]:
27-
A = local_ger(calpha, cy, cx, a=A.T, overwrite_a=int(self.destructive)).T
15+
ger_func = get_blas_funcs("ger", dtype=cA.dtype)
16+
if A.flags["C_CONTIGUOUS"]:
17+
# Work on transposed system to avoid copying
18+
A = ger_func(calpha, cy, cx, a=A.T, overwrite_a=self.destructive).T
2819
else:
29-
A = local_ger(calpha, cx, cy, a=A, overwrite_a=int(self.destructive))
20+
A = ger_func(calpha, cx, cy, a=A, overwrite_a=self.destructive)
3021
cZ[0] = A
3122

3223

pytensor/tensor/rewriting/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import pytensor.tensor.rewriting.basic
22
import pytensor.tensor.rewriting.blas
33
import pytensor.tensor.rewriting.blas_c
4-
import pytensor.tensor.rewriting.blas_scipy
54
import pytensor.tensor.rewriting.blockwise
65
import pytensor.tensor.rewriting.einsum
76
import pytensor.tensor.rewriting.elemwise

pytensor/tensor/rewriting/blas_scipy.py

Lines changed: 0 additions & 37 deletions
This file was deleted.

tests/tensor/test_blas_scipy.py

Lines changed: 0 additions & 75 deletions
This file was deleted.

0 commit comments

Comments
 (0)