Skip to content

Commit 3359a2a

Browse files
committed
paramterized tests and added batch case
1 parent 1db9999 commit 3359a2a

File tree

2 files changed

+32
-31
lines changed

2 files changed

+32
-31
lines changed

pytensor/tensor/rewriting/linalg.py

Lines changed: 17 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949

5050

5151
logger = logging.getLogger(__name__)
52+
ALL_INVERSE_OPS = (MatrixInverse, MatrixPinv)
5253

5354

5455
def is_matrix_transpose(x: TensorVariable) -> bool:
@@ -593,11 +594,11 @@ def rewrite_inv_inv(fgraph, node):
593594
list of Variable, optional
594595
List of optimized variables, or None if no optimization was performed
595596
"""
596-
valid_inverses = (MatrixInverse, MatrixPinv)
597+
ALL_INVERSE_OPS = (MatrixInverse, MatrixPinv)
597598
# Check if its a valid inverse operation (either inv/pinv)
598599
# In case the outer operation is an inverse, it directly goes to the next step of finding inner operation
599600
# If the outer operation is not a valid inverse, we do not apply this rewrite
600-
if not isinstance(node.op.core_op, valid_inverses):
601+
if not isinstance(node.op.core_op, ALL_INVERSE_OPS):
601602
return None
602603

603604
potential_inner_inv = node.inputs[0].owner
@@ -608,7 +609,7 @@ def rewrite_inv_inv(fgraph, node):
608609
if not (
609610
potential_inner_inv
610611
and isinstance(potential_inner_inv.op, Blockwise)
611-
and isinstance(potential_inner_inv.op.core_op, valid_inverses)
612+
and isinstance(potential_inner_inv.op.core_op, ALL_INVERSE_OPS)
612613
):
613614
return None
614615
return [potential_inner_inv.inputs[0]]
@@ -632,20 +633,19 @@ def rewrite_inv_eye_to_eye(fgraph, node):
632633
list of Variable, optional
633634
List of optimized variables, or None if no optimization was performed
634635
"""
635-
valid_inverses = (MatrixInverse, MatrixPinv)
636636
core_op = node.op.core_op
637-
if not (isinstance(core_op, valid_inverses)):
637+
if not (isinstance(core_op, ALL_INVERSE_OPS)):
638638
return None
639639

640640
# Check whether input to inverse is Eye and the 1's are on main diagonal
641-
eye_check = node.inputs[0]
641+
potential_eye = node.inputs[0]
642642
if not (
643-
eye_check.owner
644-
and isinstance(eye_check.owner.op, Eye)
645-
and getattr(eye_check.owner.inputs[-1], "data", -1).item() == 0
643+
potential_eye.owner
644+
and isinstance(potential_eye.owner.op, Eye)
645+
and getattr(potential_eye.owner.inputs[-1], "data", -1).item() == 0
646646
):
647647
return None
648-
return [eye_check]
648+
return [potential_eye]
649649

650650

651651
@register_canonicalize
@@ -668,9 +668,8 @@ def rewrite_inv_diag_to_diag_reciprocal(fgraph, node):
668668
list of Variable, optional
669669
List of optimized variables, or None if no optimization was performed
670670
"""
671-
valid_inverses = (MatrixInverse, MatrixPinv)
672671
core_op = node.op.core_op
673-
if not (isinstance(core_op, valid_inverses)):
672+
if not (isinstance(core_op, ALL_INVERSE_OPS)):
674673
return None
675674

676675
inputs = node.inputs[0]
@@ -681,9 +680,8 @@ def rewrite_inv_diag_to_diag_reciprocal(fgraph, node):
681680
and AllocDiag.is_offset_zero(inputs.owner)
682681
):
683682
inv_input = inputs.owner.inputs[0]
684-
if inv_input.type.ndim == 1:
685-
inv_val = pt.diag(1 / inv_input)
686-
return [inv_val]
683+
inv_val = pt.diag(1 / inv_input)
684+
return [inv_val]
687685

688686
# Check if the input is an elemwise multiply with identity matrix -- this also results in a diagonal matrix
689687
inputs_or_none = _find_diag_from_eye_mul(inputs)
@@ -700,8 +698,7 @@ def rewrite_inv_diag_to_diag_reciprocal(fgraph, node):
700698

701699
# For a matrix, we have to first extract the diagonal (non-zero values) and then only use those
702700
if non_eye_input.type.broadcastable[-2:] == (False, False):
703-
# For Matrix
704-
return [eye_input / non_eye_input.diagonal(axis1=-1, axis2=-2)]
705-
else:
706-
# For Vector or Scalar
707-
return [eye_input / non_eye_input]
701+
non_eye_diag = non_eye_input.diagonal(axis1=-1, axis2=-2)
702+
non_eye_input = pt.shape_padaxis(non_eye_diag, -2)
703+
704+
return [eye_input / non_eye_input]

tests/tensor/rewriting/test_linalg.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -560,22 +560,24 @@ def test_svd_uv_merge():
560560
assert svd_counter == 1
561561

562562

563+
def get_pt_function(x, op_name):
564+
return getattr(pt.linalg, op_name)(x)
565+
566+
563567
@pytest.mark.parametrize("inv_op_1", ["inv", "pinv"])
564568
@pytest.mark.parametrize("inv_op_2", ["inv", "pinv"])
565569
def test_inv_inv_rewrite(inv_op_1, inv_op_2):
566-
def get_pt_function(x, op_name):
567-
return getattr(pt.linalg, op_name)(x)
568-
569570
x = pt.matrix("x")
570571
op1 = get_pt_function(x, inv_op_1)
571572
op2 = get_pt_function(op1, inv_op_2)
572573
rewritten_out = rewrite_graph(op2)
573574
assert rewritten_out == x
574575

575576

576-
def test_inv_eye_to_eye():
577+
@pytest.mark.parametrize("inv_op", ["inv", "pinv"])
578+
def test_inv_eye_to_eye(inv_op):
577579
x = pt.eye(10)
578-
x_inv = pt.linalg.inv(x)
580+
x_inv = get_pt_function(x, inv_op)
579581
f_rewritten = function([], x_inv, mode="FAST_RUN")
580582
nodes = f_rewritten.maker.fgraph.apply_nodes
581583

@@ -598,15 +600,16 @@ def test_inv_eye_to_eye():
598600

599601
@pytest.mark.parametrize(
600602
"shape",
601-
[(), (7,), (7, 7)],
602-
ids=["scalar", "vector", "matrix"],
603+
[(), (7,), (7, 7), (5, 7, 7)],
604+
ids=["scalar", "vector", "matrix", "batched"],
603605
)
604-
def test_inv_diag_from_eye_mul(shape):
606+
@pytest.mark.parametrize("inv_op", ["inv", "pinv"])
607+
def test_inv_diag_from_eye_mul(shape, inv_op):
605608
# Initializing x based on scalar/vector/matrix
606609
x = pt.tensor("x", shape=shape)
607610
x_diag = pt.eye(7) * x
608611
# Calculating inverse using pt.linalg.inv
609-
x_inv = pt.linalg.inv(x_diag)
612+
x_inv = get_pt_function(x_diag, inv_op)
610613

611614
# REWRITE TEST
612615
f_rewritten = function([x], x_inv, mode="FAST_RUN")
@@ -634,10 +637,11 @@ def test_inv_diag_from_eye_mul(shape):
634637
)
635638

636639

637-
def test_inv_diag_from_diag():
640+
@pytest.mark.parametrize("inv_op", ["inv", "pinv"])
641+
def test_inv_diag_from_diag(inv_op):
638642
x = pt.dvector("x")
639643
x_diag = pt.diag(x)
640-
x_inv = pt.linalg.inv(x_diag)
644+
x_inv = get_pt_function(x_diag, inv_op)
641645

642646
# REWRITE TEST
643647
f_rewritten = function([x], x_inv, mode="FAST_RUN")

0 commit comments

Comments
 (0)