Skip to content

Commit c5f7bec

Browse files
Revert adding transpose argument to Solve
1 parent 86c5539 commit c5f7bec

File tree

1 file changed

+2
-187
lines changed

1 file changed

+2
-187
lines changed

pytensor/tensor/slinalg.py

Lines changed: 2 additions & 187 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import logging
22
import typing
33
import warnings
4-
from collections.abc import Sequence
54
from functools import reduce
65
from typing import Literal, cast
76

@@ -10,8 +9,6 @@
109

1110
import pytensor
1211
import pytensor.tensor as pt
13-
from pytensor import Variable
14-
from pytensor.gradient import DisconnectedType
1512
from pytensor.graph.basic import Apply
1613
from pytensor.graph.op import Op
1714
from pytensor.tensor import TensorLike, as_tensor_variable
@@ -28,6 +25,8 @@
2825

2926

3027
class Cholesky(Op):
28+
# TODO: LAPACK wrapper with in-place behavior, for solve also
29+
3130
__props__ = ("lower", "check_finite", "on_error", "overwrite_a")
3231
gufunc_signature = "(m,m)->(m,m)"
3332

@@ -397,186 +396,6 @@ def cho_solve(c_and_lower, b, *, check_finite=True, b_ndim: int | None = None):
397396
)(A, b)
398397

399398

400-
class LU(Op):
401-
"""Decompose a matrix into lower and upper triangular matrices."""
402-
403-
__props__ = ("permute_l", "overwrite_a", "check_finite", "p_indices")
404-
405-
def __init__(
406-
self, *, permute_l=False, overwrite_a=False, check_finite=True, p_indices=False
407-
):
408-
self.permute_l = permute_l
409-
self.check_finite = check_finite
410-
self.p_indices = p_indices
411-
self.overwrite_a = overwrite_a
412-
413-
if self.permute_l:
414-
# permute_l overrides p_indices in the scipy function. We can copy that behavior
415-
self.gufunc_signature = "(m,m)->(m,m),(m,m)"
416-
elif self.p_indices:
417-
self.gufunc_signature = "(m,m)->(m),(m,m),(m,m)"
418-
else:
419-
self.gufunc_signature = "(m,m)->(m,m),(m,m),(m,m)"
420-
421-
if self.overwrite_a:
422-
self.destroy_map = {0: [0]}
423-
424-
def infer_shape(self, fgraph, node, shapes):
425-
n = shapes[0][0]
426-
if self.permute_l:
427-
return [(n, n), (n, n)]
428-
elif self.p_indices:
429-
return [(n,), (n, n), (n, n)]
430-
else:
431-
return [(n, n), (n, n), (n, n)]
432-
433-
def make_node(self, x):
434-
x = as_tensor_variable(x)
435-
if x.type.ndim != 2:
436-
raise TypeError(
437-
f"LU only allowed on matrix (2-D) inputs, got {x.type.ndim}-D input"
438-
)
439-
440-
real_dtype = "f" if np.dtype(x.type.dtype).char in "fF" else "d"
441-
p_dtype = "int32" if self.p_indices else np.dtype(real_dtype)
442-
443-
L = tensor(shape=x.type.shape, dtype=real_dtype)
444-
U = tensor(shape=x.type.shape, dtype=real_dtype)
445-
446-
if self.permute_l:
447-
# In this case, L is actually P @ L
448-
return Apply(self, inputs=[x], outputs=[L, U])
449-
elif self.p_indices:
450-
p = tensor(shape=(x.type.shape[0],), dtype=p_dtype)
451-
return Apply(self, inputs=[x], outputs=[p, L, U])
452-
else:
453-
P = tensor(shape=x.type.shape, dtype=p_dtype)
454-
return Apply(self, inputs=[x], outputs=[P, L, U])
455-
456-
def perform(self, node, inputs, outputs):
457-
[A] = inputs
458-
459-
out = scipy.linalg.lu(
460-
A,
461-
permute_l=self.permute_l,
462-
overwrite_a=self.overwrite_a,
463-
check_finite=self.check_finite,
464-
p_indices=self.p_indices,
465-
)
466-
467-
outputs[0][0] = out[0]
468-
outputs[1][0] = out[1]
469-
470-
if not self.permute_l:
471-
# In all cases except permute_l, there are three returns
472-
outputs[2][0] = out[2]
473-
474-
def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op":
475-
if 0 in allowed_inplace_inputs:
476-
new_props = self._props_dict() # type: ignore
477-
new_props["overwrite_a"] = True
478-
return type(self)(**new_props)
479-
else:
480-
return self
481-
482-
def L_op(
483-
self,
484-
inputs: Sequence[Variable],
485-
outputs: Sequence[Variable],
486-
output_grads: Sequence[Variable],
487-
) -> list[Variable]:
488-
r"""
489-
The derivation follows "Differentiation of Matrix Functionals Using Triangular Factorization" by
490-
F. R. De Hoog, R.S. Anderssen, M. A. Lukas
491-
"""
492-
[A] = inputs
493-
A = cast(TensorVariable, A)
494-
495-
if self.permute_l:
496-
PL_bar, U_bar = output_grads
497-
498-
# TODO: Rewrite into permute_l = False for graphs where we need to compute the gradient
499-
P, L, U = lu( # type: ignore
500-
A, permute_l=False, check_finite=self.check_finite, p_indices=False
501-
)
502-
503-
# Permutation matrix is orthogonal
504-
L_bar = (
505-
P.T @ PL_bar
506-
if not isinstance(PL_bar.type, DisconnectedType)
507-
else pt.zeros_like(A)
508-
)
509-
510-
elif self.p_indices:
511-
p, L, U = outputs
512-
513-
# TODO: rewrite to p_indices = False for graphs where we need to compute the gradient
514-
P = pt.eye(A.shape[0])[p]
515-
_, L_bar, U_bar = output_grads
516-
else:
517-
P, L, U = outputs
518-
_, L_bar, U_bar = output_grads
519-
520-
L_bar = (
521-
L_bar if not isinstance(L_bar.type, DisconnectedType) else pt.zeros_like(A)
522-
)
523-
U_bar = (
524-
U_bar if not isinstance(U_bar.type, DisconnectedType) else pt.zeros_like(A)
525-
)
526-
527-
x1 = ptb.tril(L.T @ L_bar, k=-1)
528-
x2 = ptb.triu(U_bar @ U.T)
529-
530-
L_inv_x = solve_triangular(L.T, x1 + x2, lower=False, unit_diagonal=True)
531-
A_bar = P @ solve_triangular(U, L_inv_x.T, lower=False).T
532-
533-
return [A_bar]
534-
535-
536-
def lu(
537-
a: TensorLike, permute_l=False, check_finite=True, p_indices=False
538-
) -> (
539-
tuple[TensorVariable, TensorVariable, TensorVariable]
540-
| tuple[TensorVariable, TensorVariable]
541-
):
542-
"""
543-
Factorize a matrix as the product of a unit lower triangular matrix and an upper triangular matrix:
544-
545-
.. math::
546-
547-
A = P L U
548-
549-
Where P is a permutation matrix, L is lower triangular with unit diagonal elements, and U is upper triangular.
550-
551-
Parameters
552-
----------
553-
a: TensorLike
554-
Matrix to be factorized
555-
permute_l: bool
556-
If True, L is a product of permutation and unit lower triangular matrices. Only two values, PL and U, will
557-
be returned in this case, and PL will not be lower triangular.
558-
check_finite: bool
559-
Whether to check that the input matrix contains only finite numbers.
560-
p_indices: bool
561-
If True, return the permutation as a 1-D array of integer row indices. Otherwise, return the permutation matrix
562-
itself.
563-
564-
Returns
565-
-------
566-
P: TensorVariable
567-
Permutation matrix, or array of integer indices for permutation matrix. Not returned if permute_l is True.
568-
L: TensorVariable
569-
Lower triangular matrix, or product of permutation and unit lower triangular matrices if permute_l is True.
570-
U: TensorVariable
571-
Upper triangular matrix
572-
"""
573-
return cast(
574-
tuple[TensorVariable, TensorVariable, TensorVariable]
575-
| tuple[TensorVariable, TensorVariable],
576-
LU(permute_l=permute_l, check_finite=check_finite, p_indices=p_indices)(a),
577-
)
578-
579-
580399
class SolveTriangular(SolveBase):
581400
"""Solve a system of linear equations."""
582401

@@ -734,7 +553,6 @@ def solve(
734553
assume_a="gen",
735554
lower=False,
736555
check_finite=True,
737-
transposed=False,
738556
b_ndim: int | None = None,
739557
):
740558
"""Solves the linear equation set ``a * x = b`` for the unknown ``x`` for square ``a`` matrix.
@@ -772,8 +590,6 @@ def solve(
772590
(crashes, non-termination) if the inputs do contain infinities or NaNs.
773591
assume_a : str, optional
774592
Valid entries are explained above.
775-
transposed: bool, optional
776-
If True, solve ``A.T @ x = b``
777593
b_ndim : int
778594
Whether the core case of b is a vector (1) or matrix (2).
779595
This will influence how batched dimensions are interpreted.
@@ -785,7 +601,6 @@ def solve(
785601
check_finite=check_finite,
786602
assume_a=assume_a,
787603
b_ndim=b_ndim,
788-
transposed=transposed,
789604
)
790605
)(a, b)
791606

0 commit comments

Comments
 (0)