Skip to content

Commit c3fa8d3

Browse files
committed
Remove blas_scipy
1 parent c3ff864 commit c3fa8d3

File tree

9 files changed

+15
-176
lines changed

9 files changed

+15
-176
lines changed

pytensor/tensor/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,6 @@ def _get_vector_length_Constant(op: Op | Variable, var: Constant) -> int:
107107
from pytensor.tensor import (
108108
blas,
109109
blas_c,
110-
blas_scipy,
111110
sharedvar,
112111
xlogx,
113112
)

pytensor/tensor/basic.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1801,8 +1801,7 @@ def do_constant_folding(self, fgraph, node):
18011801
| pytensor.tensor.blas.Gemv
18021802
| pytensor.tensor.blas_c.CGemv
18031803
| pytensor.tensor.blas.Ger
1804-
| pytensor.tensor.blas_c.CGer
1805-
| pytensor.tensor.blas_scipy.ScipyGer,
1804+
| pytensor.tensor.blas_c.CGer,
18061805
)
18071806
):
18081807
# Ops that will work inplace on the Alloc. So if they

pytensor/tensor/blas.py

Lines changed: 14 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@
8383
from pathlib import Path
8484

8585
import numpy as np
86+
from scipy.linalg import get_blas_funcs
8687

8788
from pytensor.graph import vectorize_graph
8889
from pytensor.npy_2_compat import normalize_axis_tuple
@@ -288,18 +289,17 @@ def make_node(self, A, alpha, x, y):
288289

289290
return Apply(self, inputs, [A.type()])
290291

291-
def perform(self, node, inp, out):
292-
cA, calpha, cx, cy = inp
293-
(cZ,) = out
294-
if self.destructive:
295-
A = cA
296-
else:
297-
A = cA.copy()
298-
if calpha != 1:
299-
A += calpha * np.outer(cx, cy)
300-
else:
301-
A += np.outer(cx, cy)
302-
cZ[0] = A
292+
def perform(self, node, inputs, output_storage):
293+
A, alpha, x, y = inputs
294+
if A.size:
295+
# GER doesn't handle zero-sized inputs
296+
ger_func = get_blas_funcs("ger", dtype=A.dtype)
297+
if A.flags["C_CONTIGUOUS"]:
298+
# Work on transposed system to avoid copying
299+
A = ger_func(alpha, y, x, a=A.T, overwrite_a=self.destructive).T
300+
else:
301+
A = ger_func(alpha, x, y, a=A, overwrite_a=self.destructive)
302+
output_storage[0][0] = A
303303

304304
def infer_shape(self, fgraph, node, input_shapes):
305305
return [input_shapes[0]]
@@ -1128,16 +1128,8 @@ def make_node(self, x, y):
11281128
outputs = [tensor(dtype=x.type.dtype, shape=(x.type.shape[0], y.type.shape[1]))]
11291129
return Apply(self, [x, y], outputs)
11301130

1131-
def perform(self, node, inp, out):
1132-
x, y = inp
1133-
(z,) = out
1134-
try:
1135-
z[0] = np.asarray(np.dot(x, y))
1136-
except ValueError as e:
1137-
# The error raised by numpy has no shape information, we mean to
1138-
# add that
1139-
e.args = (*e.args, x.shape, y.shape)
1140-
raise
1131+
def perform(self, node, inputs, output_storage):
1132+
output_storage[0][0] = np.dot(*inputs)
11411133

11421134
def infer_shape(self, fgraph, node, input_shapes):
11431135
return [[input_shapes[0][0], input_shapes[1][1]]]

pytensor/tensor/blas_scipy.py

Lines changed: 0 additions & 34 deletions
This file was deleted.

pytensor/tensor/rewriting/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import pytensor.tensor.rewriting.basic
22
import pytensor.tensor.rewriting.blas
33
import pytensor.tensor.rewriting.blas_c
4-
import pytensor.tensor.rewriting.blas_scipy
54
import pytensor.tensor.rewriting.blockwise
65
import pytensor.tensor.rewriting.einsum
76
import pytensor.tensor.rewriting.elemwise

pytensor/tensor/rewriting/blas_scipy.py

Lines changed: 0 additions & 37 deletions
This file was deleted.

tests/tensor/test_blas.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
import pytensor
1010
import pytensor.scalar as ps
1111
import pytensor.tensor as pt
12-
import pytensor.tensor.blas_scipy
1312
from pytensor.compile.function import function
1413
from pytensor.compile.io import In
1514
from pytensor.compile.mode import Mode

tests/tensor/test_blas_c.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from pytensor.tensor.basic import AllocEmpty
99
from pytensor.tensor.blas import Ger
1010
from pytensor.tensor.blas_c import CGemv, CGer, must_initialize_y_gemv
11-
from pytensor.tensor.blas_scipy import ScipyGer
1211
from pytensor.tensor.type import dmatrix, dvector, matrix, scalar, tensor, vector
1312
from tests import unittest_tools
1413
from tests.tensor.test_blas import BaseGemv, TestBlasStrides
@@ -68,8 +67,6 @@ def test_eq(self):
6867
assert CGer(False) == CGer(False)
6968
assert CGer(False) != CGer(True)
7069

71-
assert CGer(True) != ScipyGer(True)
72-
assert CGer(False) != ScipyGer(False)
7370
assert CGer(True) != Ger(True)
7471
assert CGer(False) != Ger(False)
7572

tests/tensor/test_blas_scipy.py

Lines changed: 0 additions & 75 deletions
This file was deleted.

0 commit comments

Comments
 (0)