49
49
50
50
51
51
logger = logging .getLogger (__name__ )
52
+ ALL_INVERSE_OPS = (MatrixInverse , MatrixPinv )
52
53
53
54
54
55
def is_matrix_transpose (x : TensorVariable ) -> bool :
@@ -593,11 +594,11 @@ def rewrite_inv_inv(fgraph, node):
593
594
list of Variable, optional
594
595
List of optimized variables, or None if no optimization was performed
595
596
"""
596
- valid_inverses = (MatrixInverse , MatrixPinv )
597
+ ALL_INVERSE_OPS = (MatrixInverse , MatrixPinv )
597
598
# Check if its a valid inverse operation (either inv/pinv)
598
599
# In case the outer operation is an inverse, it directly goes to the next step of finding inner operation
599
600
# If the outer operation is not a valid inverse, we do not apply this rewrite
600
- if not isinstance (node .op .core_op , valid_inverses ):
601
+ if not isinstance (node .op .core_op , ALL_INVERSE_OPS ):
601
602
return None
602
603
603
604
potential_inner_inv = node .inputs [0 ].owner
@@ -608,7 +609,7 @@ def rewrite_inv_inv(fgraph, node):
608
609
if not (
609
610
potential_inner_inv
610
611
and isinstance (potential_inner_inv .op , Blockwise )
611
- and isinstance (potential_inner_inv .op .core_op , valid_inverses )
612
+ and isinstance (potential_inner_inv .op .core_op , ALL_INVERSE_OPS )
612
613
):
613
614
return None
614
615
return [potential_inner_inv .inputs [0 ]]
@@ -632,20 +633,19 @@ def rewrite_inv_eye_to_eye(fgraph, node):
632
633
list of Variable, optional
633
634
List of optimized variables, or None if no optimization was performed
634
635
"""
635
- valid_inverses = (MatrixInverse , MatrixPinv )
636
636
core_op = node .op .core_op
637
- if not (isinstance (core_op , valid_inverses )):
637
+ if not (isinstance (core_op , ALL_INVERSE_OPS )):
638
638
return None
639
639
640
640
# Check whether input to inverse is Eye and the 1's are on main diagonal
641
- eye_check = node .inputs [0 ]
641
+ potential_eye = node .inputs [0 ]
642
642
if not (
643
- eye_check .owner
644
- and isinstance (eye_check .owner .op , Eye )
645
- and getattr (eye_check .owner .inputs [- 1 ], "data" , - 1 ).item () == 0
643
+ potential_eye .owner
644
+ and isinstance (potential_eye .owner .op , Eye )
645
+ and getattr (potential_eye .owner .inputs [- 1 ], "data" , - 1 ).item () == 0
646
646
):
647
647
return None
648
- return [eye_check ]
648
+ return [potential_eye ]
649
649
650
650
651
651
@register_canonicalize
@@ -668,9 +668,8 @@ def rewrite_inv_diag_to_diag_reciprocal(fgraph, node):
668
668
list of Variable, optional
669
669
List of optimized variables, or None if no optimization was performed
670
670
"""
671
- valid_inverses = (MatrixInverse , MatrixPinv )
672
671
core_op = node .op .core_op
673
- if not (isinstance (core_op , valid_inverses )):
672
+ if not (isinstance (core_op , ALL_INVERSE_OPS )):
674
673
return None
675
674
676
675
inputs = node .inputs [0 ]
@@ -681,9 +680,8 @@ def rewrite_inv_diag_to_diag_reciprocal(fgraph, node):
681
680
and AllocDiag .is_offset_zero (inputs .owner )
682
681
):
683
682
inv_input = inputs .owner .inputs [0 ]
684
- if inv_input .type .ndim == 1 :
685
- inv_val = pt .diag (1 / inv_input )
686
- return [inv_val ]
683
+ inv_val = pt .diag (1 / inv_input )
684
+ return [inv_val ]
687
685
688
686
# Check if the input is an elemwise multiply with identity matrix -- this also results in a diagonal matrix
689
687
inputs_or_none = _find_diag_from_eye_mul (inputs )
@@ -700,8 +698,7 @@ def rewrite_inv_diag_to_diag_reciprocal(fgraph, node):
700
698
701
699
# For a matrix, we have to first extract the diagonal (non-zero values) and then only use those
702
700
if non_eye_input .type .broadcastable [- 2 :] == (False , False ):
703
- # For Matrix
704
- return [eye_input / non_eye_input .diagonal (axis1 = - 1 , axis2 = - 2 )]
705
- else :
706
- # For Vector or Scalar
707
- return [eye_input / non_eye_input ]
701
+ non_eye_diag = non_eye_input .diagonal (axis1 = - 1 , axis2 = - 2 )
702
+ non_eye_input = pt .shape_padaxis (non_eye_diag , - 2 )
703
+
704
+ return [eye_input / non_eye_input ]
0 commit comments