Skip to content

Commit 5604d9a

Browse files
Refactor rewrite_det_diag_to_prod_diag to use AllocDiag2
1 parent 6857bea commit 5604d9a

File tree

1 file changed

+31
-37
lines changed

1 file changed

+31
-37
lines changed

pytensor/tensor/rewriting/linalg.py

Lines changed: 31 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,16 @@
55
from pytensor import Variable
66
from pytensor.graph import Apply, FunctionGraph
77
from pytensor.graph.rewriting.basic import (
8-
PatternNodeRewriter,
98
copy_stack_trace,
109
node_rewriter,
1110
)
1211
from pytensor.scalar.basic import Mul
13-
from pytensor.tensor.basic import ARange, Eye, TensorVariable, alloc, diagonal
12+
from pytensor.tensor.basic import (
13+
AllocDiag2,
14+
Eye,
15+
TensorVariable,
16+
diagonal,
17+
)
1418
from pytensor.tensor.blas import Dot22
1519
from pytensor.tensor.blockwise import Blockwise
1620
from pytensor.tensor.elemwise import DimShuffle, Elemwise
@@ -41,7 +45,6 @@
4145
solve,
4246
solve_triangular,
4347
)
44-
from pytensor.tensor.subtensor import advanced_set_subtensor
4548

4649

4750
logger = logging.getLogger(__name__)
@@ -420,11 +423,15 @@ def _find_diag_from_eye_mul(potential_mul_input):
420423
@register_canonicalize("shape_unsafe")
421424
@register_stabilize("shape_unsafe")
422425
@node_rewriter([det])
423-
def rewrite_det_diag_from_eye_mul(fgraph, node):
426+
def rewrite_det_diag_to_prod_diag(fgraph, node):
424427
"""
425-
This rewrite takes advantage of the fact that for a diagonal matrix, the determinant value is the product of its diagonal elements.
428+
This rewrite takes advantage of the fact that for a diagonal matrix, the determinant value is the product of its
429+
diagonal elements.
426430
427-
The presence of a diagonal matrix is detected by inspecting the graph. This rewrite can identify diagonal matrices that arise as the result of elementwise multiplication with an identity matrix. Specialized computation is used to make this rewrite as efficient as possible, depending on whether the multiplication was with a scalar, vector or a matrix.
431+
The presence of a diagonal matrix is detected by inspecting the graph. This rewrite can identify diagonal matrices
432+
that arise as the result of elementwise multiplication with an identity matrix. Specialized computation is used to
433+
make this rewrite as efficient as possible, depending on whether the multiplication was with a scalar,
434+
vector or a matrix.
428435
429436
Parameters
430437
----------
@@ -438,53 +445,40 @@ def rewrite_det_diag_from_eye_mul(fgraph, node):
438445
list of Variable, optional
439446
List of optimized variables, or None if no optimization was performed
440447
"""
441-
potential_mul_input = node.inputs[0]
442-
eye_non_eye_inputs = _find_diag_from_eye_mul(potential_mul_input)
443-
if eye_non_eye_inputs is None:
448+
inputs = node.inputs[0]
449+
450+
# Check for use of pt.diag first
451+
if inputs.owner and isinstance(inputs.owner.op, AllocDiag2):
452+
diag_input = inputs.owner.inputs[0]
453+
det_val = diag_input.prod(axis=-1)
454+
return [det_val]
455+
456+
# Check if the input is an elemwise multiply with identity matrix -- this also results in a diagonal matrix
457+
inputs_or_none = _find_diag_from_eye_mul(inputs)
458+
if inputs_or_none is None:
444459
return None
445-
eye_input, non_eye_inputs = eye_non_eye_inputs
460+
eye_input, non_eye_inputs = inputs_or_none
446461

447462
# Dealing with only one other input
448463
if len(non_eye_inputs) != 1:
449464
return None
450465

451-
useful_eye, useful_non_eye = eye_input[0], non_eye_inputs[0]
466+
eye_input, non_eye_input = eye_input[0], non_eye_inputs[0]
452467

453468
# Checking if original x was scalar/vector/matrix
454-
if useful_non_eye.type.broadcastable[-2:] == (True, True):
469+
if non_eye_input.type.broadcastable[-2:] == (True, True):
455470
# For scalar
456-
det_val = useful_non_eye.squeeze(axis=(-1, -2)) ** (useful_eye.shape[0])
457-
elif useful_non_eye.type.broadcastable[-2:] == (False, False):
471+
det_val = non_eye_input.squeeze(axis=(-1, -2)) ** (eye_input.shape[0])
472+
elif non_eye_input.type.broadcastable[-2:] == (False, False):
458473
# For Matrix
459-
det_val = useful_non_eye.diagonal(axis1=-1, axis2=-2).prod(axis=-1)
474+
det_val = non_eye_input.diagonal(axis1=-1, axis2=-2).prod(axis=-1)
460475
else:
461476
# For vector
462-
det_val = useful_non_eye.prod(axis=(-1, -2))
477+
det_val = non_eye_input.prod(axis=(-1, -2))
463478
det_val = det_val.astype(node.outputs[0].type.dtype)
464479
return [det_val]
465480

466481

467-
arange = ARange("int64")
468-
det_diag_from_diag = PatternNodeRewriter(
469-
(
470-
det,
471-
(
472-
advanced_set_subtensor,
473-
(alloc, 0, "sh1", "sh2"),
474-
"x",
475-
(arange, 0, "stop", 1),
476-
(arange, 0, "stop", 1),
477-
),
478-
),
479-
(prod, "x"),
480-
name="det_diag_from_diag",
481-
allow_multiple_clients=True,
482-
)
483-
register_canonicalize(det_diag_from_diag)
484-
register_stabilize(det_diag_from_diag)
485-
register_specialize(det_diag_from_diag)
486-
487-
488482
@register_canonicalize
489483
@register_stabilize
490484
@register_specialize

0 commit comments

Comments
 (0)