Commit bfe075d

Improve solve test
1 parent c5f7bec commit bfe075d

1 file changed (+28, -19 lines)

tests/tensor/test_slinalg.py

@@ -238,37 +238,46 @@ def test_infer_shape(self, b_shape):
     "b_size", [(5, 1), (5, 5), (5,)], ids=["b_col_vec", "b_matrix", "b_vec"]
 )
 @pytest.mark.parametrize("assume_a", ["gen", "sym", "pos"], ids=str)
-@pytest.mark.parametrize("transposed", [True, False], ids=["trans", "no_trans"])
-def test_solve_correctness(b_size: tuple[int], assume_a: str, transposed: bool):
+def test_solve_correctness(b_size: tuple[int], assume_a: str):
     rng = np.random.default_rng(utt.fetch_seed())
     A = pt.tensor("A", shape=(5, 5))
     b = pt.tensor("b", shape=b_size)
 
-    solve_op = functools.partial(
-        solve, assume_a=assume_a, transposed=transposed, b_ndim=len(b_size)
-    )
-    y = solve_op(
-        A,
-        b,
-    )
-    solve_func = pytensor.function([A, b], y)
-
-    b_val = rng.normal(size=b_size).astype(config.floatX)
     A_val = rng.normal(size=(5, 5)).astype(config.floatX)
+    b_val = rng.normal(size=b_size).astype(config.floatX)
+
+    solve_op = functools.partial(solve, assume_a=assume_a, b_ndim=len(b_size))
 
-    if assume_a == "sym":
-        A_val = (A_val + A_val.T) / 2
-    elif assume_a == "pos":
-        A_val = A_val @ A_val.T
+    def A_func(x):
+        if assume_a == "pos":
+            return x @ x.T
+        elif assume_a == "sym":
+            return (x + x.T) / 2
+        else:
+            return x
+
+    solve_input_val = A_func(A_val)
+
+    y = solve_op(A_func(A), b)
+    solve_func = pytensor.function([A, b], y)
+    X_np = solve_func(A_val.copy(), b_val.copy())
 
     np.testing.assert_allclose(
-        scipy.linalg.solve(A_val, b_val, assume_a=assume_a, transposed=transposed),
-        solve_func(A_val, b_val),
+        scipy.linalg.solve(solve_input_val, b_val, assume_a=assume_a),
+        X_np,
     )
 
+    np.testing.assert_allclose(A_func(A_val) @ X_np, b_val, atol=1e-6)
+
     eps = 2e-8 if config.floatX == "float64" else None
 
-    utt.verify_grad(solve_op, [A_val, b_val], 3, rng, eps=eps)
+    # To correctly check the gradients, we need to include a transformation from the space of unconstrained matrices
+    # (A) to a valid input matrix for the given solver. This is done by the A_func function. If this isn't included,
+    # the random perturbations used by verify_grad will result in invalid input matrices, and
+    # LAPACK will silently do the wrong thing, making the gradients wrong.
+    utt.verify_grad(
+        lambda A, b: solve_op(A_func(A), b), [A_val, b_val], 3, rng, eps=eps
+    )
 
 
 class TestSolveTriangular(utt.InferShapeTester):
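The reasoning in the new comment has a concrete SciPy-level counterpart: with assume_a="sym" (and similarly "pos"), scipy.linalg.solve hands only one triangle of the matrix to LAPACK, so the asymmetric entrywise perturbations that verify_grad applies to a raw A_val are partly invisible to the solver unless they are first mapped through A_func. Below is a minimal standalone sketch (not part of this commit), assuming SciPy's documented default lower=False so that only the upper triangle is read; the seed, matrix values, and perturbed entry are illustrative only.

import numpy as np
import scipy.linalg

rng = np.random.default_rng(0)
x = rng.normal(size=(5, 5))
A_sym = (x + x.T) / 2  # a valid "sym" input, built the same way as A_func's "sym" branch
b = rng.normal(size=5)

# Perturb a single below-diagonal entry, mimicking an entrywise finite-difference step.
E = np.zeros((5, 5))
E[3, 1] = 1e-4

x0 = scipy.linalg.solve(A_sym, b, assume_a="sym")
x1 = scipy.linalg.solve(A_sym + E, b, assume_a="sym")
print(np.allclose(x0, x1))  # True: the lower-triangle perturbation never reaches LAPACK,
                            # so the numeric gradient for that entry would come out as zero.

# Mapping the perturbed matrix through the "sym" branch of A_func keeps it visible:
x2 = scipy.linalg.solve(A_sym + (E + E.T) / 2, b, assume_a="sym")
print(np.max(np.abs(x2 - x0)))  # small but nonzero: the solution responds to the perturbation

This is what the lambda passed to utt.verify_grad encodes: perturbations happen in the unconstrained space and are pushed through A_func before reaching the solver.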
