@@ -956,6 +956,21 @@ def jax_bilinaer_lyapunov_to_direct(fgraph: FunctionGraph, node: Apply):
956956@register_specialize
957957@node_rewriter ([det ])
958958def slogdet_specialization (fgraph , node ):
959+ """
960+ This rewrite targets specific operations related to slogdet i.e sign(det), log(det) and log(abs(det)) and rewrites them using the SLogDet operation.
961+
962+ Parameters
963+ ----------
964+ fgraph: FunctionGraph
965+ Function graph being optimized
966+ node: Apply
967+ Node of the function graph to be optimized
968+
969+ Returns
970+ -------
971+ dictionary of Variables, optional
972+ Dictionary of nodes and what they should be replaced with, or None if no optimization was performed
973+ """
959974 dummy_replacements = {}
960975 for client , _ in fgraph .clients [node .outputs [0 ]]:
961976 # Check for sign(det)
@@ -964,13 +979,16 @@ def slogdet_specialization(fgraph, node):
964979
965980 # Check for log(abs(det))
966981 elif isinstance (client .op , Elemwise ) and isinstance (client .op .scalar_op , Abs ):
982+ potential_log = None
967983 for client_2 , _ in fgraph .clients [client .outputs [0 ]]:
968984 if isinstance (client_2 .op , Elemwise ) and isinstance (
969985 client_2 .op .scalar_op , Log
970986 ):
971- dummy_replacements [
972- fgraph .clients [client .outputs [0 ]][0 ][0 ].outputs [0 ]
973- ] = "log_abs_det"
987+ potential_log = client_2
988+ if potential_log :
989+ dummy_replacements [potential_log .outputs [0 ]] = "log_abs_det"
990+ else :
991+ return None
974992
975993 # Check for log(det)
976994 elif isinstance (client .op , Elemwise ) and isinstance (client .op .scalar_op , Log ):
@@ -980,15 +998,18 @@ def slogdet_specialization(fgraph, node):
980998 else :
981999 return None
9821000
983- [x ] = node .inputs
984- sign_det_x , log_abs_det_x = SLogDet ()(x )
985- log_det_x = pt .where (pt .eq (sign_det_x , - 1 ), np .nan , log_abs_det_x )
986- slogdet_specialization_map = {
987- "sign" : sign_det_x ,
988- "log_abs_det" : log_abs_det_x ,
989- "log_det" : log_det_x ,
990- }
991- replacements = {
992- k : slogdet_specialization_map [v ] for k , v in dummy_replacements .items ()
993- }
994- return replacements or None
1001+ if not dummy_replacements :
1002+ return None
1003+ else :
1004+ [x ] = node .inputs
1005+ sign_det_x , log_abs_det_x = SLogDet ()(x )
1006+ log_det_x = pt .where (pt .eq (sign_det_x , - 1 ), np .nan , log_abs_det_x )
1007+ slogdet_specialization_map = {
1008+ "sign" : sign_det_x ,
1009+ "log_abs_det" : log_abs_det_x ,
1010+ "log_det" : log_det_x ,
1011+ }
1012+ replacements = {
1013+ k : slogdet_specialization_map [v ] for k , v in dummy_replacements .items ()
1014+ }
1015+ return replacements
0 commit comments