Skip to content

Commit ec38857

Browse files
committed
adds a rewrite for det(kronecker) instead of slogdet
1 parent a377c22 commit ec38857

File tree

2 files changed

+114
-58
lines changed

2 files changed

+114
-58
lines changed

pytensor/tensor/rewriting/linalg.py

Lines changed: 78 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
inv,
3333
kron,
3434
pinv,
35-
slogdet,
3635
svd,
3736
)
3837
from pytensor.tensor.rewriting.basic import (
@@ -781,43 +780,43 @@ def rewrite_det_blockdiag(fgraph, node):
781780
return [prod(det_sub_matrices)]
782781

783782

784-
@register_canonicalize
785-
@register_stabilize
786-
@node_rewriter([slogdet])
787-
def rewrite_slogdet_blockdiag(fgraph, node):
788-
"""
789-
This rewrite simplifies the slogdet of a blockdiagonal matrix by extracting the individual sub matrices and returning the sign and logdet values computed using those
790-
791-
slogdet(block_diag(a,b,c,....)) = prod(sign(a), sign(b), sign(c),...), sum(logdet(a), logdet(b), logdet(c),....)
792-
793-
Parameters
794-
----------
795-
fgraph: FunctionGraph
796-
Function graph being optimized
797-
node: Apply
798-
Node of the function graph to be optimized
799-
800-
Returns
801-
-------
802-
list of Variable, optional
803-
List of optimized variables, or None if no optimization was performed
804-
"""
805-
# Check for inner block_diag operation
806-
potential_block_diag = node.inputs[0].owner
807-
if not (
808-
potential_block_diag
809-
and isinstance(potential_block_diag.op, Blockwise)
810-
and isinstance(potential_block_diag.op.core_op, BlockDiagonal)
811-
):
812-
return None
813-
814-
# Find the composing sub_matrices
815-
sub_matrices = potential_block_diag.inputs
816-
sign_sub_matrices, logdet_sub_matrices = zip(
817-
*[slogdet(sub_matrices[i]) for i in range(len(sub_matrices))]
818-
)
819-
820-
return [prod(sign_sub_matrices), sum(logdet_sub_matrices)]
783+
# @register_canonicalize
784+
# @register_stabilize
785+
# @node_rewriter([slogdet])
786+
# def rewrite_slogdet_blockdiag(fgraph, node):
787+
# """
788+
# This rewrite simplifies the slogdet of a blockdiagonal matrix by extracting the individual sub matrices and returning the sign and logdet values computed using those
789+
790+
# slogdet(block_diag(a,b,c,....)) = prod(sign(a), sign(b), sign(c),...), sum(logdet(a), logdet(b), logdet(c),....)
791+
792+
# Parameters
793+
# ----------
794+
# fgraph: FunctionGraph
795+
# Function graph being optimized
796+
# node: Apply
797+
# Node of the function graph to be optimized
798+
799+
# Returns
800+
# -------
801+
# list of Variable, optional
802+
# List of optimized variables, or None if no optimization was performed
803+
# """
804+
# # Check for inner block_diag operation
805+
# potential_block_diag = node.inputs[0].owner
806+
# if not (
807+
# potential_block_diag
808+
# and isinstance(potential_block_diag.op, Blockwise)
809+
# and isinstance(potential_block_diag.op.core_op, BlockDiagonal)
810+
# ):
811+
# return None
812+
813+
# # Find the composing sub_matrices
814+
# sub_matrices = potential_block_diag.inputs
815+
# sign_sub_matrices, logdet_sub_matrices = zip(
816+
# *[slogdet(sub_matrices[i]) for i in range(len(sub_matrices))]
817+
# )
818+
819+
# return [prod(sign_sub_matrices), sum(logdet_sub_matrices)]
821820

822821

823822
@register_canonicalize
@@ -854,12 +853,47 @@ def rewrite_diag_kronecker(fgraph, node):
854853
return [outer_prod_as_vector]
855854

856855

856+
# @register_canonicalize
857+
# @register_stabilize
858+
# @node_rewriter([slogdet])
859+
# def rewrite_slogdet_kronecker(fgraph, node):
860+
# """
861+
# This rewrite simplifies the slogdet of a kronecker-structured matrix by extracting the individual sub matrices and returning the sign and logdet values computed using those
862+
863+
# Parameters
864+
# ----------
865+
# fgraph: FunctionGraph
866+
# Function graph being optimized
867+
# node: Apply
868+
# Node of the function graph to be optimized
869+
870+
# Returns
871+
# -------
872+
# list of Variable, optional
873+
# List of optimized variables, or None if no optimization was performed
874+
# """
875+
# # Check for inner kron operation
876+
# potential_kron = node.inputs[0].owner
877+
# if not (potential_kron and isinstance(potential_kron.op, KroneckerProduct)):
878+
# return None
879+
880+
# # Find the matrices
881+
# a, b = potential_kron.inputs
882+
# signs, logdets = zip(*[slogdet(a), slogdet(b)])
883+
# sizes = [a.shape[-1], b.shape[-1]]
884+
# prod_sizes = prod(sizes, no_zeros_in_input=True)
885+
# signs_final = [signs[i] ** (prod_sizes / sizes[i]) for i in range(2)]
886+
# logdet_final = [logdets[i] * prod_sizes / sizes[i] for i in range(2)]
887+
888+
# return [prod(signs_final, no_zeros_in_input=True), sum(logdet_final)]
889+
890+
857891
@register_canonicalize
858892
@register_stabilize
859-
@node_rewriter([slogdet])
860-
def rewrite_slogdet_kronecker(fgraph, node):
893+
@node_rewriter([det])
894+
def rewrite_det_kronecker(fgraph, node):
861895
"""
862-
This rewrite simplifies the slogdet of a kronecker-structured matrix by extracting the individual sub matrices and returning the sign and logdet values computed using those
896+
This rewrite simplifies the determinant of a kronecker-structured matrix by extracting the individual sub matrices and returning the det values computed using those
863897
864898
Parameters
865899
----------
@@ -880,13 +914,12 @@ def rewrite_slogdet_kronecker(fgraph, node):
880914

881915
# Find the matrices
882916
a, b = potential_kron.inputs
883-
signs, logdets = zip(*[slogdet(a), slogdet(b)])
917+
dets = [det(a), det(b)]
884918
sizes = [a.shape[-1], b.shape[-1]]
885919
prod_sizes = prod(sizes, no_zeros_in_input=True)
886-
signs_final = [signs[i] ** (prod_sizes / sizes[i]) for i in range(2)]
887-
logdet_final = [logdets[i] * prod_sizes / sizes[i] for i in range(2)]
920+
det_final = prod([dets[i] ** (prod_sizes / sizes[i]) for i in range(2)])
888921

889-
return [prod(signs_final, no_zeros_in_input=True), sum(logdet_final)]
922+
return [det_final]
890923

891924

892925
@register_canonicalize

tests/tensor/rewriting/test_linalg.py

Lines changed: 36 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -776,11 +776,40 @@ def test_diag_kronecker_rewrite():
776776
)
777777

778778

779-
def test_slogdet_kronecker_rewrite():
779+
# def test_slogdet_kronecker_rewrite():
780+
# a, b = pt.dmatrices("a", "b")
781+
# kron_prod = pt.linalg.kron(a, b)
782+
# sign_output, logdet_output = pt.linalg.slogdet(kron_prod)
783+
# f_rewritten = function([kron_prod], [sign_output, logdet_output], mode="FAST_RUN")
784+
785+
# # Rewrite Test
786+
# nodes = f_rewritten.maker.fgraph.apply_nodes
787+
# assert not any(isinstance(node.op, KroneckerProduct) for node in nodes)
788+
789+
# # Value Test
790+
# a_test, b_test = np.random.rand(2, 20, 20)
791+
# kron_prod_test = np.kron(a_test, b_test)
792+
# sign_output_test, logdet_output_test = np.linalg.slogdet(kron_prod_test)
793+
# rewritten_sign_val, rewritten_logdet_val = f_rewritten(kron_prod_test)
794+
# assert_allclose(
795+
# sign_output_test,
796+
# rewritten_sign_val,
797+
# atol=1e-3 if config.floatX == "float32" else 1e-8,
798+
# rtol=1e-3 if config.floatX == "float32" else 1e-8,
799+
# )
800+
# assert_allclose(
801+
# logdet_output_test,
802+
# rewritten_logdet_val,
803+
# atol=1e-3 if config.floatX == "float32" else 1e-8,
804+
# rtol=1e-3 if config.floatX == "float32" else 1e-8,
805+
# )
806+
807+
808+
def test_det_kronecker_rewrite():
780809
a, b = pt.dmatrices("a", "b")
781810
kron_prod = pt.linalg.kron(a, b)
782-
sign_output, logdet_output = pt.linalg.slogdet(kron_prod)
783-
f_rewritten = function([kron_prod], [sign_output, logdet_output], mode="FAST_RUN")
811+
det_output = pt.linalg.det(kron_prod)
812+
f_rewritten = function([kron_prod], [det_output], mode="FAST_RUN")
784813

785814
# Rewrite Test
786815
nodes = f_rewritten.maker.fgraph.apply_nodes
@@ -789,17 +818,11 @@ def test_slogdet_kronecker_rewrite():
789818
# Value Test
790819
a_test, b_test = np.random.rand(2, 20, 20)
791820
kron_prod_test = np.kron(a_test, b_test)
792-
sign_output_test, logdet_output_test = np.linalg.slogdet(kron_prod_test)
793-
rewritten_sign_val, rewritten_logdet_val = f_rewritten(kron_prod_test)
821+
det_output_test = np.linalg.det(kron_prod_test)
822+
rewritten_det_val = f_rewritten(kron_prod_test)
794823
assert_allclose(
795-
sign_output_test,
796-
rewritten_sign_val,
797-
atol=1e-3 if config.floatX == "float32" else 1e-8,
798-
rtol=1e-3 if config.floatX == "float32" else 1e-8,
799-
)
800-
assert_allclose(
801-
logdet_output_test,
802-
rewritten_logdet_val,
824+
det_output_test,
825+
rewritten_det_val,
803826
atol=1e-3 if config.floatX == "float32" else 1e-8,
804827
rtol=1e-3 if config.floatX == "float32" else 1e-8,
805828
)

0 commit comments

Comments
 (0)