Skip to content

Commit e2f87e8

Browse files
Expand test coverage for Solve
1 parent 6c6b607 commit e2f87e8

File tree

1 file changed

+44
-16
lines changed

1 file changed

+44
-16
lines changed

tests/tensor/test_slinalg.py

Lines changed: 44 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,22 @@ def test_solve_raises_on_invalid_A():
214214
Solve(assume_a="test", b_ndim=2)
215215

216216

217+
solve_test_cases = [
218+
("gen", False, False),
219+
("gen", False, True),
220+
("sym", False, False),
221+
("sym", True, False),
222+
("sym", True, True),
223+
("pos", False, False),
224+
("pos", True, False),
225+
("pos", True, True),
226+
]
227+
solve_test_ids = [
228+
f'{assume_a}_{"lower" if lower else "upper"}_{"A^T" if transposed else "A"}'
229+
for assume_a, lower, transposed in solve_test_cases
230+
]
231+
232+
217233
class TestSolve(utt.InferShapeTester):
218234
@pytest.mark.parametrize("b_shape", [(5, 1), (5,)])
219235
def test_infer_shape(self, b_shape):
@@ -235,16 +251,26 @@ def test_infer_shape(self, b_shape):
235251
@pytest.mark.parametrize(
236252
"b_size", [(5, 1), (5, 5), (5,)], ids=["b_col_vec", "b_matrix", "b_vec"]
237253
)
238-
@pytest.mark.parametrize("assume_a", ["gen", "sym", "pos"], ids=str)
239-
def test_solve_correctness(self, b_size: tuple[int], assume_a: str):
254+
@pytest.mark.parametrize(
255+
"assume_a, lower, transposed", solve_test_cases, ids=solve_test_ids
256+
)
257+
def test_solve_correctness(
258+
self, b_size: tuple[int], assume_a: str, lower: bool, transposed: bool
259+
):
240260
rng = np.random.default_rng(utt.fetch_seed())
241261
A = pt.tensor("A", shape=(5, 5))
242262
b = pt.tensor("b", shape=b_size)
243263

244264
A_val = rng.normal(size=(5, 5)).astype(config.floatX)
245265
b_val = rng.normal(size=b_size).astype(config.floatX)
246266

247-
solve_op = functools.partial(solve, assume_a=assume_a, b_ndim=len(b_size))
267+
solve_op = functools.partial(
268+
solve,
269+
assume_a=assume_a,
270+
lower=lower,
271+
transposed=transposed,
272+
b_ndim=len(b_size),
273+
)
248274

249275
def A_func(x):
250276
if assume_a == "pos":
@@ -254,6 +280,11 @@ def A_func(x):
254280
else:
255281
return x
256282

283+
def T(x):
284+
if transposed:
285+
return x.T
286+
return x
287+
257288
solve_input_val = A_func(A_val)
258289

259290
y = solve_op(A_func(A), b)
@@ -264,30 +295,27 @@ def A_func(x):
264295
RTOL = 1e-8 if config.floatX.endswith("64") else 1e-4
265296

266297
np.testing.assert_allclose(
267-
scipy.linalg.solve(solve_input_val, b_val, assume_a=assume_a),
298+
scipy.linalg.solve(
299+
solve_input_val,
300+
b_val,
301+
assume_a=assume_a,
302+
transposed=transposed,
303+
lower=lower,
304+
),
268305
X_np,
269306
atol=ATOL,
270307
rtol=RTOL,
271308
)
272309

273-
np.testing.assert_allclose(A_func(A_val) @ X_np, b_val, atol=ATOL, rtol=RTOL)
310+
np.testing.assert_allclose(T(A_func(A_val)) @ X_np, b_val, atol=ATOL, rtol=RTOL)
274311

275312
@pytest.mark.parametrize(
276313
"b_size", [(5, 1), (5, 5), (5,)], ids=["b_col_vec", "b_matrix", "b_vec"]
277314
)
278315
@pytest.mark.parametrize(
279316
"assume_a, lower, transposed",
280-
[
281-
("gen", False, False),
282-
("gen", False, True),
283-
("sym", False, False),
284-
("sym", True, False),
285-
("sym", True, True),
286-
("pos", False, False),
287-
("pos", True, False),
288-
("pos", True, True),
289-
],
290-
ids=str,
317+
solve_test_cases,
318+
ids=solve_test_ids,
291319
)
292320
@pytest.mark.skipif(
293321
config.floatX == "float32", reason="Gradients not numerically stable in float32"

0 commit comments

Comments
 (0)