@@ -1,9 +1,11 @@
 import re
 from functools import partial
+from typing import Literal

 import numpy as np
 import pytest
 from numpy.testing import assert_allclose
+from scipy import linalg as scipy_linalg

 import pytensor
 import pytensor.tensor as pt
@@ -31,59 +33,79 @@ def transpose_func(x, trans):


 @pytest.mark.parametrize(
-    "b_func, b_size",
-    [(pt.matrix, (5, 1)), (pt.matrix, (5, 5)), (pt.vector, (5,))],
+    "b_shape",
+    [(5, 1), (5, 5), (5,)],
     ids=["b_col_vec", "b_matrix", "b_vec"],
 )
 @pytest.mark.parametrize("lower", [True, False], ids=["lower=True", "lower=False"])
 @pytest.mark.parametrize("trans", [0, 1, 2], ids=["trans=N", "trans=C", "trans=T"])
 @pytest.mark.parametrize(
     "unit_diag", [True, False], ids=["unit_diag=True", "unit_diag=False"]
 )
-@pytest.mark.parametrize("complex", [True, False], ids=["complex", "real"])
+@pytest.mark.parametrize("is_complex", [True, False], ids=["complex", "real"])
 @pytest.mark.filterwarnings(
     'ignore:Cannot cache compiled function "numba_funcified_fgraph"'
 )
-def test_solve_triangular(b_func, b_size, lower, trans, unit_diag, complex):
-    if complex:
+def test_solve_triangular(b_shape: tuple[int, ...], lower, trans, unit_diag, is_complex):
+    if is_complex:
         # TODO: Complex raises ValueError: To change to a dtype of a different size, the last axis must be contiguous,
         #  why?
         pytest.skip("Complex inputs currently not supported to solve_triangular")

     complex_dtype = "complex64" if floatX.endswith("32") else "complex128"
-    dtype = complex_dtype if complex else floatX
+    dtype = complex_dtype if is_complex else floatX

     A = pt.matrix("A", dtype=dtype)
-    b = b_func("b", dtype=dtype)
+    b = pt.tensor("b", shape=b_shape, dtype=dtype)
+
+    def A_func(x):
+        x = x @ x.conj().T
+        x_tri = scipy_linalg.cholesky(x, lower=lower).astype(dtype)
+
+        if unit_diag:
+            x_tri[np.diag_indices_from(x_tri)] = 1.0

-    X = pt.linalg.solve_triangular(
-        A, b, lower=lower, trans=trans, unit_diagonal=unit_diag
+        return x_tri.astype(dtype)
+
+    solve_op = partial(
+        pt.linalg.solve_triangular, lower=lower, trans=trans, unit_diagonal=unit_diag
     )
+
+    X = solve_op(A, b)
     f = pytensor.function([A, b], X, mode="NUMBA")

     A_val = np.random.normal(size=(5, 5))
-    b_val = np.random.normal(size=b_size)
+    b_val = np.random.normal(size=b_shape)

-    if complex:
+    if is_complex:
         A_val = A_val + np.random.normal(size=(5, 5)) * 1j
-        b_val = b_val + np.random.normal(size=b_size) * 1j
-    A_sym = A_val @ A_val.conj().T
+        b_val = b_val + np.random.normal(size=b_shape) * 1j

-    A_tri = np.linalg.cholesky(A_sym).astype(dtype)
-    if unit_diag:
-        adj_mat = np.ones((5, 5))
-        adj_mat[np.diag_indices(5)] = 1 / np.diagonal(A_tri)
-        A_tri = A_tri * adj_mat
+    X_np = f(A_func(A_val.copy()), b_val.copy())

-    A_tri = A_tri.astype(dtype)
-    b_val = b_val.astype(dtype)
+    test_input = transpose_func(A_func(A_val.copy()), trans)
+    np.testing.assert_allclose(test_input @ X_np, b_val, atol=ATOL, rtol=RTOL)

-    if not lower:
-        A_tri = A_tri.T
+    compare_numba_and_py(f.maker.fgraph, [A_func(A_val.copy()), b_val.copy()])

-    X_np = f(A_tri, b_val)
-    np.testing.assert_allclose(
-        transpose_func(A_tri, trans) @ X_np, b_val, atol=ATOL, rtol=RTOL
+    # utt.verify_grad uses small perturbations of the input matrix to compute the finite-difference gradient. When
+    # a non-triangular matrix is passed to scipy.linalg.solve_triangular, no error is raised, but the result will be
+    # wrong, and so will the gradients. It is therefore necessary to add a mapping from the space of all
+    # matrices to the space of triangular matrices, and test the gradient of that entire graph.
+    def A_func_pt(x):
+        x = x @ x.conj().T
+        x_tri = pt.linalg.cholesky(x, lower=lower).astype(dtype)
+
+        if unit_diag:
+            n = A_val.shape[0]
+            x_tri = x_tri[np.diag_indices(n)].set(1.0)
+
+        return transpose_func(x_tri.astype(dtype), trans)
+
+    utt.verify_grad(
+        lambda A, b: solve_op(A_func_pt(A), b),
+        [A_val.copy(), b_val.copy()],
+        mode="NUMBA",
     )

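The mapping through a Cholesky factor above is needed because scipy.linalg.solve_triangular never checks that its input is actually triangular; it only reads the triangle selected by `lower`. A minimal standalone sketch of that behaviour (plain NumPy/SciPy, separate from the test itself):

    import numpy as np
    from scipy import linalg

    rng = np.random.default_rng(0)
    A = rng.normal(size=(4, 4))  # dense, NOT triangular
    b = rng.normal(size=(4,))

    # No error is raised even though A is not triangular: the LAPACK routine
    # underneath only reads the lower triangle when lower=True.
    x = linalg.solve_triangular(A, b, lower=True)

    # x solves tril(A) @ x = b, not A @ x = b, which is why finite-difference
    # perturbations must be mapped back into the triangular space.
    np.testing.assert_allclose(np.tril(A) @ x, b)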
@@ -166,7 +188,8 @@ def test_numba_Cholesky_grad(lower, trans):
     L = rng.normal(size=(5, 5)).astype(floatX)
     X = L @ L.T

-    utt.verify_grad(pt.linalg.cholesky, [X])
+    chol_op = partial(pt.linalg.cholesky, lower=lower, trans=trans)
+    utt.verify_grad(chol_op, [X], mode="NUMBA")


 def test_block_diag():
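For reference, utt.verify_grad compares the symbolic gradient against a finite-difference approximation. A rough sketch of the idea for a scalar-valued function; the helper below is illustrative only, not PyTensor's actual implementation:

    import numpy as np

    def finite_diff_grad(f, x, eps=1e-6):
        """Central-difference gradient of a scalar-valued f at x (hypothetical helper)."""
        g = np.zeros_like(x)
        for idx in np.ndindex(*x.shape):
            dx = np.zeros_like(x)
            dx[idx] = eps
            g[idx] = (f(x + dx) - f(x - dx)) / (2 * eps)
        return g

    # e.g. for f(X) = np.sum(np.linalg.cholesky(X)), the entries of
    # finite_diff_grad(f, X) should match the symbolic gradient.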
@@ -319,69 +342,72 @@ def lu_solve(a, b, trans, overwrite_a, overwrite_b):


 @pytest.mark.parametrize(
-    "b_func, b_size",
-    [(pt.matrix, (5, 1)), (pt.matrix, (5, 5)), (pt.vector, (5,))],
+    "b_shape",
+    [(5, 1), (5, 5), (5,)],
     ids=["b_col_vec", "b_matrix", "b_vec"],
 )
 @pytest.mark.parametrize("assume_a", ["gen", "sym", "pos"], ids=str)
-@pytest.mark.parametrize("transposed", [True, False], ids=["trans", "no_trans"])
 @pytest.mark.filterwarnings(
     'ignore:Cannot cache compiled function "numba_funcified_fgraph"'
 )
-def test_solve(b_func, b_size, assume_a, transposed):
+def test_solve(b_shape: tuple[int, ...], assume_a: Literal["gen", "sym", "pos"]):
     A = pt.matrix("A", dtype=floatX)
-    b = b_func("b", dtype=floatX)
+    b = pt.tensor("b", shape=b_shape, dtype=floatX)
+
+    A_val = np.asfortranarray(np.random.normal(size=(5, 5)).astype(floatX))
+    b_val = np.asfortranarray(np.random.normal(size=b_shape).astype(floatX))
+
+    def A_func(x):
+        if assume_a == "pos":
+            x = x.T @ x
+        elif assume_a == "sym":
+            x = (x.T + x) / 2
+
+        return x

     X = pt.linalg.solve(
-        A,
+        A_func(A),
         b,
-        lower=False,
         assume_a=assume_a,
-        transposed=transposed,
-        b_ndim=len(b_size),
+        b_ndim=len(b_shape),
     )
     f = pytensor.function(
         [pytensor.In(A, mutable=True), pytensor.In(b, mutable=True)], X, mode="NUMBA"
     )
+    op = f.maker.fgraph.outputs[0].owner.op

-    A_val = np.random.normal(size=(5, 5)).astype(floatX)
-
-    if assume_a in ["sym", "pos"]:
-        A_val = A_val @ A_val.conj().T
-    A_val = np.asfortranarray(A_val)
-
-    b_val = np.random.normal(size=b_size)
-    b_val = b_val.astype(floatX)
-    b_val = np.asfortranarray(b_val)
+    compare_numba_and_py(f.maker.fgraph, inputs=[A_func(A_val.copy()), b_val.copy()])

+    # Calling f is destructive and will overwrite b_val with the answer, so store copies of the inputs first.
     A_val_copy = A_val.copy()
     b_val_copy = b_val.copy()

-    X_np = f(A_val, b_val)
-    op = f.maker.fgraph.outputs[0].owner.op
+    X_np = f(A_func(A_val), b_val)

     # overwrite_b is preferred when both inputs can be destroyed
     assert op.destroy_map == {0: [1]}

-    # Test that the result is numerically correct
-    np.testing.assert_allclose(
-        transpose_func(A_val_copy, transposed) @ X_np, b_val_copy, atol=ATOL, rtol=RTOL
-    )
-
-    # Confirm input was destroyed
+    # Confirm inputs were destroyed by checking against the copies
     assert (A_val == A_val_copy).all() == (op.destroy_map.get(0, None) != [0])
     assert (b_val == b_val_copy).all() == (op.destroy_map.get(0, None) != [1])

-    # Test gradients
-    solve = partial(
-        pt.linalg.solve,
-        lower=False,
-        assume_a=assume_a,
-        transposed=transposed,
-        b_ndim=len(b_size),
+    # Confirm b_val was used to store the solution
+    np.testing.assert_allclose(X_np, b_val, atol=ATOL, rtol=RTOL)
+    assert not np.allclose(b_val, b_val_copy)
+
+    # Test that the result is numerically correct. Need to use the unmodified copies
+    np.testing.assert_allclose(
+        A_func(A_val_copy) @ X_np, b_val_copy, atol=ATOL, rtol=RTOL
     )

-    utt.verify_grad(solve, [A_val_copy, b_val_copy], mode="NUMBA")
+    # See the note in tensor/test_slinalg.py::test_solve_correctness for details about the setup here
+    utt.verify_grad(
+        lambda A, b: pt.linalg.solve(
+            A_func(A), b, lower=False, assume_a=assume_a, b_ndim=len(b_shape)
+        ),
+        [A_val_copy, b_val_copy],
+        mode="NUMBA",
+    )


 @pytest.mark.parametrize(
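The destroy_map assertions above mirror SciPy's own overwrite flags. A small standalone sketch of the behaviour being tested, using plain SciPy and independent of PyTensor:

    import numpy as np
    from scipy import linalg

    rng = np.random.default_rng(0)
    A = np.asfortranarray(rng.normal(size=(3, 3)))
    b = np.asfortranarray(rng.normal(size=(3, 1)))
    b_copy = b.copy()

    # With overwrite_b=True, LAPACK may reuse b's buffer to store the solution,
    # so b's original contents cannot be relied on after the call.
    x = linalg.solve(A, b, overwrite_b=True)
    np.testing.assert_allclose(A @ x, b_copy)  # x still solves A @ x = b_copy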