Skip to content

Commit 701236c

Browse files
Use alias for scipy.linalg (improves linting)
1 parent 93438dd commit 701236c

File tree

1 file changed

+18
-18
lines changed

1 file changed

+18
-18
lines changed

pytensor/tensor/slinalg.py

Lines changed: 18 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,7 @@
55
from typing import Literal, cast
66

77
import numpy as np
8-
import scipy.linalg
8+
import scipy.linalg as scipy_linalg
99

1010
import pytensor
1111
import pytensor.tensor as pt
@@ -58,7 +58,7 @@ def make_node(self, x):
5858
f"Cholesky only allowed on matrix (2-D) inputs, got {x.type.ndim}-D input"
5959
)
6060
# Call scipy to find output dtype
61-
dtype = scipy.linalg.cholesky(np.eye(1, dtype=x.type.dtype)).dtype
61+
dtype = scipy_linalg.cholesky(np.eye(1, dtype=x.type.dtype)).dtype
6262
return Apply(self, [x], [tensor(shape=x.type.shape, dtype=dtype)])
6363

6464
def perform(self, node, inputs, outputs):
@@ -68,21 +68,21 @@ def perform(self, node, inputs, outputs):
6868
# Scipy cholesky only makes use of overwrite_a when it is F_CONTIGUOUS
6969
# If we have a `C_CONTIGUOUS` array we transpose to benefit from it
7070
if self.overwrite_a and x.flags["C_CONTIGUOUS"]:
71-
out[0] = scipy.linalg.cholesky(
71+
out[0] = scipy_linalg.cholesky(
7272
x.T,
7373
lower=not self.lower,
7474
check_finite=self.check_finite,
7575
overwrite_a=True,
7676
).T
7777
else:
78-
out[0] = scipy.linalg.cholesky(
78+
out[0] = scipy_linalg.cholesky(
7979
x,
8080
lower=self.lower,
8181
check_finite=self.check_finite,
8282
overwrite_a=self.overwrite_a,
8383
)
8484

85-
except scipy.linalg.LinAlgError:
85+
except scipy_linalg.LinAlgError:
8686
if self.on_error == "raise":
8787
raise
8888
else:
@@ -334,7 +334,7 @@ def __init__(self, **kwargs):
334334

335335
def perform(self, node, inputs, output_storage):
336336
C, b = inputs
337-
rval = scipy.linalg.cho_solve(
337+
rval = scipy_linalg.cho_solve(
338338
(C, self.lower),
339339
b,
340340
check_finite=self.check_finite,
@@ -401,7 +401,7 @@ def __init__(self, *, trans=0, unit_diagonal=False, **kwargs):
401401

402402
def perform(self, node, inputs, outputs):
403403
A, b = inputs
404-
outputs[0][0] = scipy.linalg.solve_triangular(
404+
outputs[0][0] = scipy_linalg.solve_triangular(
405405
A,
406406
b,
407407
lower=self.lower,
@@ -502,7 +502,7 @@ def __init__(self, *, assume_a="gen", **kwargs):
502502

503503
def perform(self, node, inputs, outputs):
504504
a, b = inputs
505-
outputs[0][0] = scipy.linalg.solve(
505+
outputs[0][0] = scipy_linalg.solve(
506506
a=a,
507507
b=b,
508508
lower=self.lower,
@@ -619,9 +619,9 @@ def make_node(self, a, b):
619619
def perform(self, node, inputs, outputs):
620620
(w,) = outputs
621621
if len(inputs) == 2:
622-
w[0] = scipy.linalg.eigvalsh(a=inputs[0], b=inputs[1], lower=self.lower)
622+
w[0] = scipy_linalg.eigvalsh(a=inputs[0], b=inputs[1], lower=self.lower)
623623
else:
624-
w[0] = scipy.linalg.eigvalsh(a=inputs[0], b=None, lower=self.lower)
624+
w[0] = scipy_linalg.eigvalsh(a=inputs[0], b=None, lower=self.lower)
625625

626626
def grad(self, inputs, g_outputs):
627627
a, b = inputs
@@ -675,7 +675,7 @@ def make_node(self, a, b, gw):
675675

676676
def perform(self, node, inputs, outputs):
677677
(a, b, gw) = inputs
678-
w, v = scipy.linalg.eigh(a, b, lower=self.lower)
678+
w, v = scipy_linalg.eigh(a, b, lower=self.lower)
679679
gA = v.dot(np.diag(gw).dot(v.T))
680680
gB = -v.dot(np.diag(gw * w).dot(v.T))
681681

@@ -718,7 +718,7 @@ def make_node(self, A):
718718
def perform(self, node, inputs, outputs):
719719
(A,) = inputs
720720
(expm,) = outputs
721-
expm[0] = scipy.linalg.expm(A)
721+
expm[0] = scipy_linalg.expm(A)
722722

723723
def grad(self, inputs, outputs):
724724
(A,) = inputs
@@ -758,8 +758,8 @@ def perform(self, node, inputs, outputs):
758758
# this expression.
759759
(A, gA) = inputs
760760
(out,) = outputs
761-
w, V = scipy.linalg.eig(A, right=True)
762-
U = scipy.linalg.inv(V).T
761+
w, V = scipy_linalg.eig(A, right=True)
762+
U = scipy_linalg.inv(V).T
763763

764764
exp_w = np.exp(w)
765765
X = np.subtract.outer(exp_w, exp_w) / np.subtract.outer(w, w)
@@ -800,7 +800,7 @@ def perform(self, node, inputs, output_storage):
800800
X = output_storage[0]
801801

802802
out_dtype = node.outputs[0].type.dtype
803-
X[0] = scipy.linalg.solve_continuous_lyapunov(A, B).astype(out_dtype)
803+
X[0] = scipy_linalg.solve_continuous_lyapunov(A, B).astype(out_dtype)
804804

805805
def infer_shape(self, fgraph, node, shapes):
806806
return [shapes[0]]
@@ -870,7 +870,7 @@ def perform(self, node, inputs, output_storage):
870870
X = output_storage[0]
871871

872872
out_dtype = node.outputs[0].type.dtype
873-
X[0] = scipy.linalg.solve_discrete_lyapunov(A, B, method="bilinear").astype(
873+
X[0] = scipy_linalg.solve_discrete_lyapunov(A, B, method="bilinear").astype(
874874
out_dtype
875875
)
876876

@@ -992,7 +992,7 @@ def perform(self, node, inputs, output_storage):
992992
Q = 0.5 * (Q + Q.T)
993993

994994
out_dtype = node.outputs[0].type.dtype
995-
X[0] = scipy.linalg.solve_discrete_are(A, B, Q, R).astype(out_dtype)
995+
X[0] = scipy_linalg.solve_discrete_are(A, B, Q, R).astype(out_dtype)
996996

997997
def infer_shape(self, fgraph, node, shapes):
998998
return [shapes[0]]
@@ -1118,7 +1118,7 @@ def make_node(self, *matrices):
11181118

11191119
def perform(self, node, inputs, output_storage, params=None):
11201120
dtype = node.outputs[0].type.dtype
1121-
output_storage[0][0] = scipy.linalg.block_diag(*inputs).astype(dtype)
1121+
output_storage[0][0] = scipy_linalg.block_diag(*inputs).astype(dtype)
11221122

11231123

11241124
def block_diag(*matrices: TensorVariable):

0 commit comments

Comments (0)