Skip to content

Commit b7aa9f8

Browse files
remove unused code
1 parent c9af37a commit b7aa9f8

File tree

2 files changed

+4
-34
lines changed

2 files changed

+4
-34
lines changed

pytensor/tensor/slinalg.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import pytensor.tensor as pt
1313
from pytensor.compile.builders import OpFromGraph
1414
from pytensor.gradient import DisconnectedType
15-
from pytensor.graph.basic import Apply
15+
from pytensor.graph.basic import Apply, Variable
1616
from pytensor.graph.op import Op
1717
from pytensor.tensor import TensorLike, as_tensor_variable
1818
from pytensor.tensor import basic as ptb
@@ -741,7 +741,8 @@ class LUSolve(OpFromGraph):
741741

742742
def __init__(
743743
self,
744-
*args,
744+
inputs: list[Variable],
745+
outputs: list[Variable],
745746
trans: bool = False,
746747
b_ndim: int | None = None,
747748
check_finite: bool = False,
@@ -753,7 +754,7 @@ def __init__(
753754
self.check_finite = check_finite
754755
self.overwrite_b = overwrite_b
755756

756-
super().__init__(*args, **kwargs)
757+
super().__init__(inputs=inputs, outputs=outputs, **kwargs)
757758

758759

759760
def lu_solve(

tests/tensor/test_slinalg.py

Lines changed: 0 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
from pytensor.tensor.slinalg import (
1616
Cholesky,
1717
CholeskySolve,
18-
LUSolve,
1918
Solve,
2019
SolveBase,
2120
SolveTriangular,
@@ -1033,33 +1032,3 @@ def test_block_diagonal_blockwise():
10331032
B = np.random.normal(size=(1, batch_size, 4, 4)).astype(config.floatX)
10341033
result = block_diag(A, B).eval()
10351034
assert result.shape == (10, batch_size, 6, 6)
1036-
1037-
1038-
def lu_solve_1(A, b):
1039-
lu, pivots = pt.linalg.lu_factor(A)
1040-
return pt.linalg.lu_solve((lu, pivots), b)
1041-
1042-
1043-
def lu_solve_2(A, b, b_ndim=1, trans=0, check_finite=False):
1044-
lu, pivots = pt.linalg.lu_factor(A)
1045-
return LUSolve(b_ndim=1, trans=0, check_finite=False)(lu, pivots, b)
1046-
1047-
1048-
@pytest.mark.parametrize(
1049-
"op", [lu_solve_1, lu_solve_2, pt.linalg.solve], ids=["lu_1", "lu_2", "solve"]
1050-
)
1051-
@pytest.mark.parametrize("n", [500])
1052-
def test_solve_methods(op, n, benchmark):
1053-
A = pt.tensor("A", shape=(n, n))
1054-
b = pt.tensor("b", shape=(n,))
1055-
1056-
x = op(A, b)
1057-
gx = pt.grad(x.sum(), [A, b])
1058-
f = pytensor.function([A, b], [x, *gx])
1059-
1060-
A_val = np.random.normal(size=(n, n)).astype(config.floatX)
1061-
b_val = np.random.normal(size=(n,)).astype(config.floatX)
1062-
1063-
# Trigger compilation if we're a jit mode
1064-
f(A_val, b_val)
1065-
benchmark(f, A_val, b_val)

0 commit comments

Comments
 (0)