Skip to content

Commit 6e4f04b

Browse files
committed
minor changes
1 parent 60e34fe commit 6e4f04b

File tree

3 files changed

+41
-17
lines changed

3 files changed

+41
-17
lines changed

pytensor/tensor/nlinalg.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,10 @@ def __str__(self):
266266
return "SLogDet"
267267

268268

269-
def slogdet(x):
269+
def slogdet(x: ptb.TensorVariable) -> tuple[ptb.TensorVariable, ptb.TensorVariable]:
270+
"""
271+
This function simplfies the slogdet operation into 2 separate operations using directly the det op : sign(det_val) and log(abs(det_val))
272+
"""
270273
det_val = det(x)
271274
return ptm.sign(det_val), ptm.log(ptm.abs(det_val))
272275

pytensor/tensor/rewriting/linalg.py

Lines changed: 36 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -956,6 +956,21 @@ def jax_bilinaer_lyapunov_to_direct(fgraph: FunctionGraph, node: Apply):
956956
@register_specialize
957957
@node_rewriter([det])
958958
def slogdet_specialization(fgraph, node):
959+
"""
960+
This rewrite targets specific operations related to slogdet i.e sign(det), log(det) and log(abs(det)) and rewrites them using the SLogDet operation.
961+
962+
Parameters
963+
----------
964+
fgraph: FunctionGraph
965+
Function graph being optimized
966+
node: Apply
967+
Node of the function graph to be optimized
968+
969+
Returns
970+
-------
971+
dictionary of Variables, optional
972+
Dictionary of nodes and what they should be replaced with, or None if no optimization was performed
973+
"""
959974
dummy_replacements = {}
960975
for client, _ in fgraph.clients[node.outputs[0]]:
961976
# Check for sign(det)
@@ -964,13 +979,16 @@ def slogdet_specialization(fgraph, node):
964979

965980
# Check for log(abs(det))
966981
elif isinstance(client.op, Elemwise) and isinstance(client.op.scalar_op, Abs):
982+
potential_log = None
967983
for client_2, _ in fgraph.clients[client.outputs[0]]:
968984
if isinstance(client_2.op, Elemwise) and isinstance(
969985
client_2.op.scalar_op, Log
970986
):
971-
dummy_replacements[
972-
fgraph.clients[client.outputs[0]][0][0].outputs[0]
973-
] = "log_abs_det"
987+
potential_log = client_2
988+
if potential_log:
989+
dummy_replacements[potential_log.outputs[0]] = "log_abs_det"
990+
else:
991+
return None
974992

975993
# Check for log(det)
976994
elif isinstance(client.op, Elemwise) and isinstance(client.op.scalar_op, Log):
@@ -980,15 +998,18 @@ def slogdet_specialization(fgraph, node):
980998
else:
981999
return None
9821000

983-
[x] = node.inputs
984-
sign_det_x, log_abs_det_x = SLogDet()(x)
985-
log_det_x = pt.where(pt.eq(sign_det_x, -1), np.nan, log_abs_det_x)
986-
slogdet_specialization_map = {
987-
"sign": sign_det_x,
988-
"log_abs_det": log_abs_det_x,
989-
"log_det": log_det_x,
990-
}
991-
replacements = {
992-
k: slogdet_specialization_map[v] for k, v in dummy_replacements.items()
993-
}
994-
return replacements or None
1001+
if not dummy_replacements:
1002+
return None
1003+
else:
1004+
[x] = node.inputs
1005+
sign_det_x, log_abs_det_x = SLogDet()(x)
1006+
log_det_x = pt.where(pt.eq(sign_det_x, -1), np.nan, log_abs_det_x)
1007+
slogdet_specialization_map = {
1008+
"sign": sign_det_x,
1009+
"log_abs_det": log_abs_det_x,
1010+
"log_det": log_det_x,
1011+
}
1012+
replacements = {
1013+
k: slogdet_specialization_map[v] for k, v in dummy_replacements.items()
1014+
}
1015+
return replacements

tests/tensor/rewriting/test_linalg.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -932,7 +932,7 @@ def test_rewrite_cholesky_diag_to_sqrt_diag_not_applied():
932932
assert any(isinstance(node.op, Cholesky) for node in nodes)
933933

934934

935-
def test_slogdet_specialisation():
935+
def test_slogdet_specialization():
936936
x, a = pt.dmatrix("x"), np.random.rand(20, 20)
937937
det_x, det_a = pt.linalg.det(x), np.linalg.det(a)
938938
log_abs_det_x, log_abs_det_a = pt.log(pt.abs(det_x)), np.log(np.abs(det_a))

0 commit comments

Comments
 (0)