@@ -939,6 +939,7 @@ def test_slogdet_specialization():
939939 log_det_x , log_det_a = pt .log (det_x ), np .log (det_a )
940940 sign_det_x , sign_det_a = pt .sign (det_x ), np .sign (det_a )
941941 exp_det_x = pt .exp (det_x )
942+
942943 # REWRITE TESTS
943944 # sign(det(x))
944945 f = function ([x ], [sign_det_x ], mode = "FAST_RUN" )
@@ -952,6 +953,7 @@ def test_slogdet_specialization():
952953 atol = 1e-3 if config .floatX == "float32" else 1e-8 ,
953954 rtol = 1e-3 if config .floatX == "float32" else 1e-8 ,
954955 )
956+
955957 # log(abs(det(x)))
956958 f = function ([x ], [log_abs_det_x ], mode = "FAST_RUN" )
957959 nodes = f .maker .fgraph .apply_nodes
@@ -964,6 +966,7 @@ def test_slogdet_specialization():
964966 atol = 1e-3 if config .floatX == "float32" else 1e-8 ,
965967 rtol = 1e-3 if config .floatX == "float32" else 1e-8 ,
966968 )
969+
967970 # log(det(x))
968971 f = function ([x ], [log_det_x ], mode = "FAST_RUN" )
969972 nodes = f .maker .fgraph .apply_nodes
@@ -976,17 +979,20 @@ def test_slogdet_specialization():
976979 atol = 1e-3 if config .floatX == "float32" else 1e-8 ,
977980 rtol = 1e-3 if config .floatX == "float32" else 1e-8 ,
978981 )
979- # more than 1 valid function
982+
983+ # More than 1 valid function
980984 f = function ([x ], [sign_det_x , log_abs_det_x ], mode = "FAST_RUN" )
981985 nodes = f .maker .fgraph .apply_nodes
982986 assert len ([node for node in nodes if isinstance (node .op , SLogDet )]) == 1
983987 assert not any (isinstance (node .op , Det ) for node in nodes )
984- # other functions (rewrite shouldnt be applied to these)
985- # only invalid functions
988+
989+ # Other functions (rewrite shouldnt be applied to these)
990+ # Only invalid functions
986991 f = function ([x ], [exp_det_x ], mode = "FAST_RUN" )
987992 nodes = f .maker .fgraph .apply_nodes
988993 assert not any (isinstance (node .op , SLogDet ) for node in nodes )
989- # invalid + valid function
994+
995+ # Invalid + Valid function
990996 f = function ([x ], [exp_det_x , sign_det_x ], mode = "FAST_RUN" )
991997 nodes = f .maker .fgraph .apply_nodes
992998 assert not any (isinstance (node .op , SLogDet ) for node in nodes )
0 commit comments