Commit f8c9d7e

Improving/debugging solve tests
1 parent bc9e7c0 commit f8c9d7e

File tree: 1 file changed (+86, -60)

tests/link/numba/test_slinalg.py

Lines changed: 86 additions & 60 deletions
@@ -1,9 +1,11 @@
 import re
 from functools import partial
+from typing import Literal

 import numpy as np
 import pytest
 from numpy.testing import assert_allclose
+from scipy import linalg as scipy_linalg

 import pytensor
 import pytensor.tensor as pt
@@ -31,59 +33,79 @@ def transpose_func(x, trans):


 @pytest.mark.parametrize(
-    "b_func, b_size",
-    [(pt.matrix, (5, 1)), (pt.matrix, (5, 5)), (pt.vector, (5,))],
+    "b_shape",
+    [(5, 1), (5, 5), (5,)],
     ids=["b_col_vec", "b_matrix", "b_vec"],
 )
 @pytest.mark.parametrize("lower", [True, False], ids=["lower=True", "lower=False"])
 @pytest.mark.parametrize("trans", [0, 1, 2], ids=["trans=N", "trans=C", "trans=T"])
 @pytest.mark.parametrize(
     "unit_diag", [True, False], ids=["unit_diag=True", "unit_diag=False"]
 )
-@pytest.mark.parametrize("complex", [True, False], ids=["complex", "real"])
+@pytest.mark.parametrize("is_complex", [True, False], ids=["complex", "real"])
 @pytest.mark.filterwarnings(
     'ignore:Cannot cache compiled function "numba_funcified_fgraph"'
 )
-def test_solve_triangular(b_func, b_size, lower, trans, unit_diag, complex):
-    if complex:
+def test_solve_triangular(b_shape: tuple[int], lower, trans, unit_diag, is_complex):
+    if is_complex:
         # TODO: Complex raises ValueError: To change to a dtype of a different size, the last axis must be contiguous,
         # why?
         pytest.skip("Complex inputs currently not supported to solve_triangular")

     complex_dtype = "complex64" if floatX.endswith("32") else "complex128"
-    dtype = complex_dtype if complex else floatX
+    dtype = complex_dtype if is_complex else floatX

     A = pt.matrix("A", dtype=dtype)
-    b = b_func("b", dtype=dtype)
+    b = pt.tensor("b", shape=b_shape, dtype=dtype)
+
+    def A_func(x):
+        x = x @ x.conj().T
+        x_tri = scipy_linalg.cholesky(x, lower=lower).astype(dtype)
+
+        if unit_diag:
+            x_tri[np.diag_indices_from(x_tri)] = 1.0

-    X = pt.linalg.solve_triangular(
-        A, b, lower=lower, trans=trans, unit_diagonal=unit_diag
+        return x_tri.astype(dtype)
+
+    solve_op = partial(
+        pt.linalg.solve_triangular, lower=lower, trans=trans, unit_diagonal=unit_diag
     )
+
+    X = solve_op(A, b)
     f = pytensor.function([A, b], X, mode="NUMBA")

     A_val = np.random.normal(size=(5, 5))
-    b_val = np.random.normal(size=b_size)
+    b_val = np.random.normal(size=b_shape)

-    if complex:
+    if is_complex:
         A_val = A_val + np.random.normal(size=(5, 5)) * 1j
-        b_val = b_val + np.random.normal(size=b_size) * 1j
-        A_sym = A_val @ A_val.conj().T
+        b_val = b_val + np.random.normal(size=b_shape) * 1j

-    A_tri = np.linalg.cholesky(A_sym).astype(dtype)
-    if unit_diag:
-        adj_mat = np.ones((5, 5))
-        adj_mat[np.diag_indices(5)] = 1 / np.diagonal(A_tri)
-        A_tri = A_tri * adj_mat
+    X_np = f(A_func(A_val.copy()), b_val.copy())

-    A_tri = A_tri.astype(dtype)
-    b_val = b_val.astype(dtype)
+    test_input = transpose_func(A_func(A_val.copy()), trans)
+    np.testing.assert_allclose(test_input @ X_np, b_val, atol=ATOL, rtol=RTOL)

-    if not lower:
-        A_tri = A_tri.T
+    compare_numba_and_py(f.maker.fgraph, [A_func(A_val.copy()), b_val.copy()])

-    X_np = f(A_tri, b_val)
-    np.testing.assert_allclose(
-        transpose_func(A_tri, trans) @ X_np, b_val, atol=ATOL, rtol=RTOL
+    # utt.verify_grad uses small perturbations to the input matrix to calculate the finite difference gradient. When
+    # a non-triangular matrix is passed to scipy.linalg.solve_triangular, no error is raised, but the result will be
+    # wrong, resulting in wrong gradients. As a result, it is necessary to add a mapping from the space of all matrices
+    # to the space of triangular matrices, and test the gradient of that entire graph.
+    def A_func_pt(x):
+        x = x @ x.conj().T
+        x_tri = pt.linalg.cholesky(x, lower=lower).astype(dtype)
+
+        if unit_diag:
+            n = A_val.shape[0]
+            x_tri = x_tri[np.diag_indices(n)].set(1.0)
+
+        return transpose_func(x_tri.astype(dtype), trans)
+
+    utt.verify_grad(
+        lambda A, b: solve_op(A_func_pt(A), b),
+        [A_val.copy(), b_val.copy()],
+        mode="NUMBA",
     )

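The new comment above is the crux of this hunk: scipy.linalg.solve_triangular never reads the triangle it is told to ignore, so finite-difference perturbations landing there change nothing and the numerical gradient comes out wrong. A minimal standalone sketch (not part of the commit) demonstrating that behaviour:

import numpy as np
from scipy import linalg as scipy_linalg

rng = np.random.default_rng(0)
A = rng.normal(size=(4, 4))
L = scipy_linalg.cholesky(A @ A.T, lower=True)  # lower-triangular factor
b = rng.normal(size=4)

x = scipy_linalg.solve_triangular(L, b, lower=True)

# Perturb a strictly upper-triangular entry: with lower=True the routine
# never reads it, so no error is raised and the solution is unchanged.
L_perturbed = L.copy()
L_perturbed[0, 3] += 1e-3
x_perturbed = scipy_linalg.solve_triangular(L_perturbed, b, lower=True)
np.testing.assert_array_equal(x, x_perturbed)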
@@ -166,7 +188,8 @@ def test_numba_Cholesky_grad(lower, trans):
     L = rng.normal(size=(5, 5)).astype(floatX)
     X = L @ L.T

-    utt.verify_grad(pt.linalg.cholesky, [X])
+    chol_op = partial(pt.linalg.cholesky, lower=lower, trans=trans)
+    utt.verify_grad(chol_op, [X], mode="NUMBA")


 def test_block_diag():
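For context on what this asserts: utt.verify_grad compares the gradient PyTensor derives symbolically against a central finite-difference estimate. A self-contained sketch of that check (illustrative only, using a trivially differentiable cost instead of cholesky):

import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.matrix("x")
cost = (x**2).sum()
cost_fn = pytensor.function([x], cost)
grad_fn = pytensor.function([x], pytensor.grad(cost, x))

x0 = np.random.default_rng(0).normal(size=(3, 3)).astype(pytensor.config.floatX)
eps = 1e-3
fd = np.zeros_like(x0)
for idx in np.ndindex(*x0.shape):
    dx = np.zeros_like(x0)
    dx[idx] = eps
    # central difference approximation of d(cost)/d(x[idx])
    fd[idx] = (cost_fn(x0 + dx) - cost_fn(x0 - dx)) / (2 * eps)

# the analytic gradient of sum(x**2) is 2*x; all three should agree
np.testing.assert_allclose(grad_fn(x0), fd, atol=1e-2)
np.testing.assert_allclose(grad_fn(x0), 2 * x0, rtol=1e-6)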
@@ -319,69 +342,72 @@ def lu_solve(a, b, trans, overwrite_a, overwrite_b):


 @pytest.mark.parametrize(
-    "b_func, b_size",
-    [(pt.matrix, (5, 1)), (pt.matrix, (5, 5)), (pt.vector, (5,))],
+    "b_shape",
+    [(5, 1), (5, 5), (5,)],
     ids=["b_col_vec", "b_matrix", "b_vec"],
 )
 @pytest.mark.parametrize("assume_a", ["gen", "sym", "pos"], ids=str)
-@pytest.mark.parametrize("transposed", [True, False], ids=["trans", "no_trans"])
 @pytest.mark.filterwarnings(
     'ignore:Cannot cache compiled function "numba_funcified_fgraph"'
 )
-def test_solve(b_func, b_size, assume_a, transposed):
+def test_solve(b_shape: tuple[int], assume_a: Literal["gen", "sym", "pos"]):
     A = pt.matrix("A", dtype=floatX)
-    b = b_func("b", dtype=floatX)
+    b = pt.tensor("b", shape=b_shape, dtype=floatX)
+
+    A_val = np.asfortranarray(np.random.normal(size=(5, 5)).astype(floatX))
+    b_val = np.asfortranarray(np.random.normal(size=b_shape).astype(floatX))
+
+    def A_func(x):
+        if assume_a == "pos":
+            x = x.T @ x
+        elif assume_a == "sym":
+            x = (x.T + x) / 2
+
+        return x

     X = pt.linalg.solve(
-        A,
+        A_func(A),
         b,
-        lower=False,
         assume_a=assume_a,
-        transposed=transposed,
-        b_ndim=len(b_size),
+        b_ndim=len(b_shape),
     )
     f = pytensor.function(
         [pytensor.In(A, mutable=True), pytensor.In(b, mutable=True)], X, mode="NUMBA"
     )
+    op = f.maker.fgraph.outputs[0].owner.op

-    A_val = np.random.normal(size=(5, 5)).astype(floatX)
-
-    if assume_a in ["sym", "pos"]:
-        A_val = A_val @ A_val.conj().T
-    A_val = np.asfortranarray(A_val)
-
-    b_val = np.random.normal(size=b_size)
-    b_val = b_val.astype(floatX)
-    b_val = np.asfortranarray(b_val)
+    compare_numba_and_py(f.maker.fgraph, inputs=[A_func(A_val.copy()), b_val.copy()])

+    # Calling this is destructive and will rewrite b_val to be the answer. Store copies of the inputs first.
     A_val_copy = A_val.copy()
     b_val_copy = b_val.copy()

-    X_np = f(A_val, b_val)
-    op = f.maker.fgraph.outputs[0].owner.op
+    X_np = f(A_func(A_val), b_val)

     # overwrite_b is preferred when both inputs can be destroyed
     assert op.destroy_map == {0: [1]}

-    # Test that the result is numerically correct
-    np.testing.assert_allclose(
-        transpose_func(A_val_copy, transposed) @ X_np, b_val_copy, atol=ATOL, rtol=RTOL
-    )
-
-    # Confirm input was destroyed
+    # Confirm inputs were destroyed by checking against the copies
     assert (A_val == A_val_copy).all() == (op.destroy_map.get(0, None) != [0])
     assert (b_val == b_val_copy).all() == (op.destroy_map.get(0, None) != [1])

-    # Test gradients
-    solve = partial(
-        pt.linalg.solve,
-        lower=False,
-        assume_a=assume_a,
-        transposed=transposed,
-        b_ndim=len(b_size),
+    # Confirm b_val is used to store the solution
+    np.testing.assert_allclose(X_np, b_val, atol=ATOL, rtol=RTOL)
+    assert not np.allclose(b_val, b_val_copy)
+
+    # Test that the result is numerically correct. Need to use the unmodified copy
+    np.testing.assert_allclose(
+        A_func(A_val_copy) @ X_np, b_val_copy, atol=ATOL, rtol=RTOL
     )

-    utt.verify_grad(solve, [A_val_copy, b_val_copy], mode="NUMBA")
+    # See the note in tensor/test_slinalg.py::test_solve_correctness for details about the setup here
+    utt.verify_grad(
+        lambda A, b: pt.linalg.solve(
+            A_func(A), b, lower=False, assume_a=assume_a, b_ndim=len(b_shape)
+        ),
+        [A_val_copy, b_val_copy],
+        mode="NUMBA",
+    )


 @pytest.mark.parametrize(
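The A_func mapping introduced in this hunk plays the same role as in test_solve_triangular above: scipy.linalg.solve trusts assume_a, and for "sym" or "pos" it reads only one triangle of A, so a matrix that violates the assumption yields a wrong answer with no error raised. A quick standalone illustration (not from the commit):

import numpy as np
from scipy import linalg as scipy_linalg

rng = np.random.default_rng(1)
A = rng.normal(size=(4, 4))  # generic, not symmetric
b = rng.normal(size=4)

# With a false assumption only one triangle of A is referenced, so the
# returned x does not actually solve A @ x = b.
x_bad = scipy_linalg.solve(A, b, assume_a="sym")
assert not np.allclose(A @ x_bad, b)

# Mapping A into the assumed space first, as A_func does, fixes this.
A_sym = (A + A.T) / 2
x_ok = scipy_linalg.solve(A_sym, b, assume_a="sym")
np.testing.assert_allclose(A_sym @ x_ok, b, atol=1e-8)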

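Last, the destroy_map assertions exercise PyTensor's in-place machinery: declaring inputs with pytensor.In(..., mutable=True) in NUMBA mode lets the solve Op reuse the buffer passed for b to store the solution. A hedged sketch of that behaviour; since whether the overwrite rewrite fires can depend on the PyTensor version, the outcomes are printed rather than asserted:

import numpy as np
import pytensor
import pytensor.tensor as pt

A = pt.matrix("A")
b = pt.vector("b")
X = pt.linalg.solve(A, b, assume_a="gen", b_ndim=1)
f = pytensor.function(
    [pytensor.In(A, mutable=True), pytensor.In(b, mutable=True)], X, mode="NUMBA"
)

floatX = pytensor.config.floatX
A_val = np.asfortranarray(np.random.normal(size=(4, 4)).astype(floatX))
b_val = np.random.normal(size=4).astype(floatX)
b_before = b_val.copy()

x = f(A_val, b_val)
op = f.maker.fgraph.outputs[0].owner.op
print(op.destroy_map)                # {0: [1]} when overwrite_b is preferred
print(np.allclose(x, b_val))         # True if the solution reused b's buffer
print(np.allclose(b_val, b_before))  # False in that case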