@@ -422,33 +422,33 @@ def _find_diag_from_eye_mul(potential_mul_input):
422
422
)
423
423
)
424
424
]
425
+
425
426
if not eye_input :
426
427
return None
427
428
428
429
eye_input = eye_input [0 ]
430
+ # If eye_input is an Eye Op (it's not wrapped in a DimShuffle), check it doesn't have an offset
431
+ if isinstance (eye_input .owner .op , Eye ) and (
432
+ not Eye .is_offset_zero (eye_input .owner )
433
+ or eye_input .broadcastable [- 2 :] != (False , False )
434
+ ):
435
+ return None
429
436
430
- # If this multiplication came from a batched operation, it will be wrapped in a DimShuffle
437
+ # Otherwise, an Eye was found but it is wrapped in a DimShuffle (i.e. there was some broadcasting going on).
438
+ # We have to look inside DimShuffle to decide if the rewrite can be applied
431
439
if isinstance (eye_input .owner .op , DimShuffle ) and (
432
440
eye_input .owner .op .is_left_expand_dims
433
441
or eye_input .owner .op .is_right_expand_dims
434
442
):
435
443
inner_eye = eye_input .owner .inputs [0 ]
436
- if not isinstance (inner_eye .owner .op , Eye ):
437
- return None
438
- # Check if 1's are being put on the main diagonal only (k = 0)
439
- # and if the identity matrix is degenerate (column or row matrix)
440
- if not (
441
- Eye .is_offset_zero (inner_eye .owner )
442
- and inner_eye .broadcastable [- 1 :] != (False , False )
444
+ # We can only rewrite when the Eye is on the main diagonal (the offset is zero) and the identity isn't
445
+ # degenerate
446
+ if not Eye .is_offset_zero (inner_eye .owner ) or inner_eye .broadcastable [- 2 :] != (
447
+ False ,
448
+ False ,
443
449
):
444
450
return None
445
451
446
- elif not (
447
- Eye .is_offset_zero (eye_input .owner )
448
- and eye_input .broadcastable [- 1 :] != (False , False )
449
- ):
450
- return None
451
-
452
452
# Get all non Eye inputs (scalars/matrices/vectors)
453
453
non_eye_inputs = list (set (inputs_to_mul ) - {eye_input })
454
454
return eye_input , non_eye_inputs
@@ -493,7 +493,6 @@ def rewrite_det_diag_to_prod_diag(fgraph, node):
493
493
494
494
# Check if the input is an elemwise multiply with identity matrix -- this also results in a diagonal matrix
495
495
inputs_or_none = _find_diag_from_eye_mul (inputs )
496
-
497
496
if inputs_or_none is None :
498
497
return None
499
498
0 commit comments