Skip to content

Commit fbdf806

Browse files
Add transposed argument to solve and solve_triangular
1 parent 757a10c commit fbdf806

File tree

4 files changed

+47
-22
lines changed

4 files changed

+47
-22
lines changed

pytensor/link/jax/dispatch/slinalg.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,6 @@ def solve(a, b, lower=lower):
5353
@jax_funcify.register(SolveTriangular)
5454
def jax_funcify_SolveTriangular(op, **kwargs):
5555
lower = op.lower
56-
trans = op.trans
5756
unit_diagonal = op.unit_diagonal
5857
check_finite = op.check_finite
5958

@@ -62,7 +61,7 @@ def solve_triangular(A, b):
6261
A,
6362
b,
6463
lower=lower,
65-
trans=trans,
64+
trans=0, # this is handled by explicitly transposing A, so it will always be 0 when we get to here.
6665
unit_diagonal=unit_diagonal,
6766
check_finite=check_finite,
6867
)

pytensor/link/numba/dispatch/slinalg.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,6 @@ def impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b):
180180

181181
@numba_funcify.register(SolveTriangular)
182182
def numba_funcify_SolveTriangular(op, node, **kwargs):
183-
trans = bool(op.trans)
184183
lower = op.lower
185184
unit_diagonal = op.unit_diagonal
186185
check_finite = op.check_finite
@@ -208,7 +207,7 @@ def solve_triangular(a, b):
208207
res = _solve_triangular(
209208
a,
210209
b,
211-
trans=trans,
210+
trans=0, # transposing is handled explicitly on the graph, so we never use this argument
212211
lower=lower,
213212
unit_diagonal=unit_diagonal,
214213
overwrite_b=overwrite_b,

pytensor/tensor/slinalg.py

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -296,13 +296,12 @@ def L_op(self, inputs, outputs, output_gradients):
296296
# We need to return (dC/d[inv(A)], dC/db)
297297
c_bar = output_gradients[0]
298298

299-
trans_solve_op = type(self)(
300-
**{
301-
k: (not getattr(self, k) if k == "lower" else getattr(self, k))
302-
for k in self.__props__
303-
}
304-
)
305-
b_bar = trans_solve_op(A.T, c_bar)
299+
props_dict = self._props_dict()
300+
props_dict["lower"] = not self.lower
301+
302+
solve_op = type(self)(**props_dict)
303+
304+
b_bar = solve_op(A.T, c_bar)
306305
# force outer product if vector second input
307306
A_bar = -ptm.outer(b_bar, c) if c.ndim == 1 else -b_bar.dot(c.T)
308307

@@ -385,19 +384,17 @@ class SolveTriangular(SolveBase):
385384
"""Solve a system of linear equations."""
386385

387386
__props__ = (
388-
"trans",
389387
"unit_diagonal",
390388
"lower",
391389
"check_finite",
392390
"b_ndim",
393391
"overwrite_b",
394392
)
395393

396-
def __init__(self, *, trans=0, unit_diagonal=False, **kwargs):
394+
def __init__(self, *, unit_diagonal=False, **kwargs):
397395
if kwargs.get("overwrite_a", False):
398396
raise ValueError("overwrite_a is not supported for SolverTriangulare")
399397
super().__init__(**kwargs)
400-
self.trans = trans
401398
self.unit_diagonal = unit_diagonal
402399

403400
def perform(self, node, inputs, outputs):
@@ -406,7 +403,7 @@ def perform(self, node, inputs, outputs):
406403
A,
407404
b,
408405
lower=self.lower,
409-
trans=self.trans,
406+
trans=0,
410407
unit_diagonal=self.unit_diagonal,
411408
check_finite=self.check_finite,
412409
overwrite_b=self.overwrite_b,
@@ -445,9 +442,9 @@ def solve_triangular(
445442
446443
Parameters
447444
----------
448-
a
445+
a: TensorVariable
449446
Square input data
450-
b
447+
b: TensorVariable
451448
Input data for the right hand side.
452449
lower : bool, optional
453450
Use only data contained in the lower triangle of `a`. Default is to use upper triangle.
@@ -468,10 +465,17 @@ def solve_triangular(
468465
This will influence how batched dimensions are interpreted.
469466
"""
470467
b_ndim = _default_b_ndim(b, b_ndim)
468+
469+
if trans in [1, "T", True]:
470+
a = a.mT
471+
lower = not lower
472+
if trans in [2, "C"]:
473+
a = a.conj().mT
474+
lower = not lower
475+
471476
ret = Blockwise(
472477
SolveTriangular(
473478
lower=lower,
474-
trans=trans,
475479
unit_diagonal=unit_diagonal,
476480
check_finite=check_finite,
477481
b_ndim=b_ndim,
@@ -534,6 +538,7 @@ def solve(
534538
*,
535539
assume_a="gen",
536540
lower=False,
541+
transposed=False,
537542
check_finite=True,
538543
b_ndim: int | None = None,
539544
):
@@ -564,8 +569,10 @@ def solve(
564569
b : (..., N, NRHS) array_like
565570
Input data for the right hand side.
566571
lower : bool, optional
567-
If True, only the data contained in the lower triangle of `a`. Default
572+
If True, use only the data contained in the lower triangle of `a`. Default
568573
is to use upper triangle. (ignored for ``'gen'``)
574+
transposed: bool, optional
575+
If True, solves the system A^T x = b. Default is False.
569576
check_finite : bool, optional
570577
Whether to check that the input matrices contain only finite numbers.
571578
Disabling may give a performance gain, but may result in problems
@@ -577,6 +584,11 @@ def solve(
577584
This will influence how batched dimensions are interpreted.
578585
"""
579586
b_ndim = _default_b_ndim(b, b_ndim)
587+
588+
if transposed:
589+
a = a.mT
590+
lower = not lower
591+
580592
return Blockwise(
581593
Solve(
582594
lower=lower,

tests/tensor/test_slinalg.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,7 @@ def test__repr__(self):
205205
y = self.SolveTest(b_ndim=2)(A, b)
206206
assert (
207207
y.__repr__()
208-
== "SolveTest{lower=False, check_finite=True, b_ndim=2, overwrite_a=False, overwrite_b=False}.0"
208+
== "SolveTest{lower=False, transposed=False, check_finite=True, b_ndim=2, overwrite_a=False, overwrite_b=False}.0"
209209
)
210210

211211

@@ -275,11 +275,26 @@ def A_func(x):
275275
@pytest.mark.parametrize(
276276
"b_size", [(5, 1), (5, 5), (5,)], ids=["b_col_vec", "b_matrix", "b_vec"]
277277
)
278-
@pytest.mark.parametrize("assume_a", ["gen", "sym", "pos"], ids=str)
278+
@pytest.mark.parametrize(
279+
"assume_a, lower, transposed",
280+
[
281+
("gen", False, False),
282+
("gen", False, True),
283+
("sym", False, False),
284+
("sym", True, False),
285+
("sym", True, True),
286+
("pos", False, False),
287+
("pos", True, False),
288+
("pos", True, True),
289+
],
290+
ids=str,
291+
)
279292
@pytest.mark.skipif(
280293
config.floatX == "float32", reason="Gradients not numerically stable in float32"
281294
)
282-
def test_solve_gradient(self, b_size: tuple[int], assume_a: str):
295+
def test_solve_gradient(
296+
self, b_size: tuple[int], assume_a: str, lower: bool, transposed: bool
297+
):
283298
rng = np.random.default_rng(utt.fetch_seed())
284299

285300
eps = 2e-8 if config.floatX == "float64" else None

0 commit comments

Comments
 (0)