Skip to content

Commit 8c16f65

Browse files
committed
parametrized test
1 parent b6f6e5c commit 8c16f65

File tree

1 file changed

+19
-56
lines changed

1 file changed

+19
-56
lines changed

tests/tensor/rewriting/test_linalg.py

Lines changed: 19 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -547,59 +547,22 @@ def test_svd_uv_merge():
547547
assert svd_counter == 1
548548

549549

550-
def test_inv_inv_rewrite():
551-
x = pt.matrix("a")
552-
inv_ops = [pt.linalg.inv, pt.linalg.pinv]
553-
solve_ops = [pt.linalg.solve, pt.linalg.solve_triangular]
554-
valid_inverses = (MatrixInverse, MatrixPinv)
555-
valid_solves = (Solve, SolveTriangular)
556-
all_valid = (MatrixInverse, MatrixPinv, Solve, SolveTriangular)
557-
# inv(inv)
558-
for inv_op in inv_ops:
559-
for inv_op_2 in inv_ops:
560-
ii_x = inv_op(inv_op_2(x))
561-
f_rewritten = function([x], ii_x, mode="FAST_RUN")
562-
nodes = f_rewritten.maker.fgraph.apply_nodes
563-
564-
assert not any(isinstance(node.op, valid_inverses) for node in nodes)
565-
566-
x_testing = np.random.rand(10, 10)
567-
np.testing.assert_allclose(f_rewritten(x_testing), x_testing)
568-
# solve(solve)
569-
b_eye = pt.eye(10)
570-
for solve_op in solve_ops:
571-
for solve_op_2 in solve_ops:
572-
ss_x = solve_op(solve_op_2(x, b_eye), b_eye)
573-
with pytensor.config.change_flags(optimizer_verbose=True):
574-
f_rewritten = function([x], ss_x, mode="FAST_RUN")
575-
nodes = f_rewritten.maker.fgraph.apply_nodes
576-
577-
assert not any(isinstance(node.op, valid_solves) for node in nodes)
578-
579-
x_testing = np.random.rand(10, 10)
580-
np.testing.assert_allclose(f_rewritten(x_testing), x_testing)
581-
582-
# inv(solve)
583-
for inv_op in inv_ops:
584-
for solve_op in solve_ops:
585-
is_x = inv_op(solve_op(x, b_eye))
586-
with pytensor.config.change_flags(optimizer_verbose=True):
587-
f_rewritten = function([x], is_x, mode="FAST_RUN")
588-
nodes = f_rewritten.maker.fgraph.apply_nodes
589-
assert not any(isinstance(node.op, all_valid) for node in nodes)
590-
591-
x_testing = np.random.rand(10, 10)
592-
np.testing.assert_allclose(f_rewritten(x_testing), x_testing)
593-
594-
# solve(inv)
595-
for solve_op in solve_ops:
596-
for inv_op in inv_ops:
597-
si_x = solve_op(inv_op(x), b_eye)
598-
with pytensor.config.change_flags(optimizer_verbose=True):
599-
f_rewritten = function([x], si_x, mode="FAST_RUN")
600-
nodes = f_rewritten.maker.fgraph.apply_nodes
601-
602-
assert not any(isinstance(node.op, all_valid) for node in nodes)
603-
604-
x_testing = np.random.rand(10, 10)
605-
np.testing.assert_allclose(f_rewritten(x_testing), x_testing)
550+
@pytest.mark.parametrize("inv_op_1", ["inv", "pinv", "solve", "solve_triangular"])
551+
@pytest.mark.parametrize("inv_op_2", ["inv", "pinv", "solve", "solve_triangular"])
552+
def test_inv_inv_rewrite(inv_op_1, inv_op_2):
553+
def get_pt_function(x, op_name):
554+
if "solve" in op_name:
555+
return getattr(pt.linalg, op_name)(x, pt.eye(x.shape[0]))
556+
return getattr(pt.linalg, op_name)(x)
557+
558+
x = pt.matrix("x")
559+
op1 = get_pt_function(x, inv_op_1)
560+
op2 = get_pt_function(op1, inv_op_2)
561+
f_rewritten = function([x], op2, mode="FAST_RUN")
562+
nodes = f_rewritten.maker.fgraph.apply_nodes
563+
564+
valid_inverses = (MatrixInverse, MatrixPinv, Solve, SolveTriangular)
565+
566+
assert not any(isinstance(node.op, valid_inverses) for node in nodes)
567+
x_testing = np.random.rand(10, 10)
568+
np.testing.assert_allclose(f_rewritten(x_testing), x_testing)

0 commit comments

Comments
 (0)