Skip to content

Commit c8d1fe1

Browse files
Add logic for SolveTriangular trans argument in SolveBase.L_op and expand test coverage
1 parent e2f87e8 commit c8d1fe1

File tree

2 files changed

+28
-25
lines changed

2 files changed

+28
-25
lines changed

pytensor/tensor/slinalg.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -300,7 +300,17 @@ def L_op(self, inputs, outputs, output_gradients):
300300
c_bar = output_gradients[0]
301301

302302
props_dict = self._props_dict()
303-
props_dict["transposed"] = not self.transposed
303+
304+
if isinstance(self, SolveTriangular):
305+
# SolveTriangular has a special trans argument we have to handle
306+
transposed = props_dict.pop("trans") in [1, "T"]
307+
props_dict["trans"] = not transposed
308+
else:
309+
transposed = props_dict.pop("transposed")
310+
props_dict["transposed"] = not transposed
311+
312+
# TODO: We were flipping lower before, but it doesn't appear we need to -- all tests pass without taking it into
313+
# account.
304314
# props_dict['lower'] = not self.lower
305315

306316
solve_op = type(self)(**props_dict)
@@ -309,7 +319,7 @@ def L_op(self, inputs, outputs, output_gradients):
309319
# force outer product if vector second input
310320
A_bar = -ptm.outer(b_bar, c) if c.ndim == 1 else -b_bar.dot(c.T)
311321

312-
if self.transposed:
322+
if transposed:
313323
A_bar = A_bar.T
314324

315325
return [A_bar, b_bar]
@@ -402,9 +412,10 @@ class SolveTriangular(SolveBase):
402412
def __init__(self, *, trans=0, unit_diagonal=False, **kwargs):
403413
if kwargs.get("overwrite_a", False):
404414
raise ValueError("overwrite_a is not supported for SolverTriangulare")
415+
405416
super().__init__(**kwargs)
406-
self.trans = trans
407417
self.unit_diagonal = unit_diagonal
418+
self.trans = trans
408419

409420
def perform(self, node, inputs, outputs):
410421
A, b = inputs

tests/tensor/test_slinalg.py

Lines changed: 14 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -367,56 +367,48 @@ def test_infer_shape(self, b_shape):
367367
warn=False,
368368
)
369369

370+
@pytest.mark.parametrize("b_shape", [(5, 1), (5,)])
370371
@pytest.mark.parametrize("lower", [True, False])
371-
def test_correctness(self, lower):
372+
@pytest.mark.parametrize("trans", ["N", "T"])
373+
def test_correctness(self, b_shape: tuple[int], lower, trans):
372374
rng = np.random.default_rng(utt.fetch_seed())
373375

374-
b_val = np.asarray(rng.random((5, 1)), dtype=config.floatX)
375-
376+
b_val = np.asarray(rng.random(b_shape), dtype=config.floatX)
376377
A_val = np.asarray(rng.random((5, 5)), dtype=config.floatX)
377378
A_val = np.dot(A_val.transpose(), A_val)
378379

379380
C_val = scipy.linalg.cholesky(A_val, lower=lower)
380381

381382
A = matrix()
382-
b = matrix()
383+
b = pt.tensor("b", shape=b_shape)
383384

384385
cholesky = Cholesky(lower=lower)
385386
C = cholesky(A)
386-
y_lower = solve_triangular(C, b, lower=lower)
387+
y_lower = solve_triangular(C, b, lower=lower, trans=trans)
387388
lower_solve_func = pytensor.function([C, b], y_lower)
388389

389390
assert np.allclose(
390-
scipy.linalg.solve_triangular(C_val, b_val, lower=lower),
391+
scipy.linalg.solve_triangular(C_val, b_val, lower=lower, trans=trans),
391392
lower_solve_func(C_val, b_val),
392393
)
393394

394-
@pytest.mark.parametrize(
395-
"m, n, lower",
396-
[
397-
(5, None, False),
398-
(5, None, True),
399-
(4, 2, False),
400-
(4, 2, True),
401-
],
402-
)
403-
def test_solve_grad(self, m, n, lower):
395+
@pytest.mark.parametrize("b_shape", [(5, 1), (5,)])
396+
@pytest.mark.parametrize("lower", [True, False])
397+
@pytest.mark.parametrize("trans", ["N", "T"])
398+
def test_solve_grad(self, b_shape: tuple[int], lower, trans):
404399
rng = np.random.default_rng(utt.fetch_seed())
400+
m = b_shape[0]
405401

406402
# Ensure diagonal elements of `A` are relatively large to avoid
407403
# numerical precision issues
408404
A_val = (rng.normal(size=(m, m)) * 0.5 + np.eye(m)).astype(config.floatX)
409-
410-
if n is None:
411-
b_val = rng.normal(size=m).astype(config.floatX)
412-
else:
413-
b_val = rng.normal(size=(m, n)).astype(config.floatX)
405+
b_val = rng.normal(size=b_shape).astype(config.floatX)
414406

415407
eps = None
416408
if config.floatX == "float64":
417409
eps = 2e-8
418410

419-
solve_op = SolveTriangular(lower=lower, b_ndim=1 if n is None else 2)
411+
solve_op = SolveTriangular(lower=lower, b_ndim=len(b_shape), trans=trans)
420412
utt.verify_grad(solve_op, [A_val, b_val], 3, rng, eps=eps)
421413

422414

0 commit comments

Comments
 (0)