@@ -781,7 +781,7 @@ def test_det_kronecker_rewrite():
781781 a , b = pt .dmatrices ("a" , "b" )
782782 kron_prod = pt .linalg .kron (a , b )
783783 det_output = pt .linalg .det (kron_prod )
784- f_rewritten = function ([kron_prod ], [det_output ], mode = "FAST_RUN" )
784+ f_rewritten = function ([a , b ], [det_output ], mode = "FAST_RUN" )
785785
786786 # Rewrite Test
787787 nodes = f_rewritten .maker .fgraph .apply_nodes
@@ -791,7 +791,7 @@ def test_det_kronecker_rewrite():
791791 a_test , b_test = np .random .rand (2 , 20 , 20 )
792792 kron_prod_test = np .kron (a_test , b_test )
793793 det_output_test = np .linalg .det (kron_prod_test )
794- rewritten_det_val = f_rewritten (kron_prod_test )
794+ rewritten_det_val = f_rewritten (a_test , b_test )
795795 assert_allclose (
796796 det_output_test ,
797797 rewritten_det_val ,
@@ -800,6 +800,35 @@ def test_det_kronecker_rewrite():
800800 )
801801
802802
803+ def test_slogdet_kronecker_rewrite ():
804+ a , b = pt .dmatrices ("a" , "b" )
805+ kron_prod = pt .linalg .kron (a , b )
806+ sign_output , logdet_output = pt .linalg .slogdet (kron_prod )
807+ f_rewritten = function ([a , b ], [sign_output , logdet_output ], mode = "FAST_RUN" )
808+
809+ # Rewrite Test
810+ nodes = f_rewritten .maker .fgraph .apply_nodes
811+ assert not any (isinstance (node .op , KroneckerProduct ) for node in nodes )
812+
813+ # Value Test
814+ a_test , b_test = np .random .rand (2 , 20 , 20 )
815+ kron_prod_test = np .kron (a_test , b_test )
816+ sign_output_test , logdet_output_test = np .linalg .slogdet (kron_prod_test )
817+ rewritten_sign_val , rewritten_logdet_val = f_rewritten (a_test , b_test )
818+ assert_allclose (
819+ sign_output_test ,
820+ rewritten_sign_val ,
821+ atol = 1e-3 if config .floatX == "float32" else 1e-8 ,
822+ rtol = 1e-3 if config .floatX == "float32" else 1e-8 ,
823+ )
824+ assert_allclose (
825+ logdet_output_test ,
826+ rewritten_logdet_val ,
827+ atol = 1e-3 if config .floatX == "float32" else 1e-8 ,
828+ rtol = 1e-3 if config .floatX == "float32" else 1e-8 ,
829+ )
830+
831+
803832def test_cholesky_eye_rewrite ():
804833 x = pt .eye (10 )
805834 L = pt .linalg .cholesky (x )
@@ -904,20 +933,60 @@ def test_rewrite_cholesky_diag_to_sqrt_diag_not_applied():
904933
905934
906935def test_slogdet_specialisation ():
907- x = pt .dmatrix ("x" )
908- det_x = pt .linalg .det (x )
909- log_abs_det_x = pt .log (pt .abs (det_x ))
910- sign_det_x = pt .sign (det_x )
936+ x , a = pt .dmatrix ("x" ), np .random .rand (20 , 20 )
937+ det_x , det_a = pt .linalg .det (x ), np .linalg .det (a )
938+ log_abs_det_x , log_abs_det_a = pt .log (pt .abs (det_x )), np .log (np .abs (det_a ))
939+ log_det_x , log_det_a = pt .log (det_x ), np .log (det_a )
940+ sign_det_x , sign_det_a = pt .sign (det_x ), np .sign (det_a )
911941 exp_det_x = pt .exp (det_x )
942+ # REWRITE TESTS
912943 # sign(det(x))
913944 f = function ([x ], [sign_det_x ], mode = "FAST_RUN" )
914945 nodes = f .maker .fgraph .apply_nodes
915- assert any (isinstance (node .op , SLogDet ) for node in nodes )
946+ assert len ([node for node in nodes if isinstance (node .op , SLogDet )]) == 1
947+ assert not any (isinstance (node .op , Det ) for node in nodes )
948+ rw_sign_det_a = f (a )
949+ assert_allclose (
950+ sign_det_a ,
951+ rw_sign_det_a ,
952+ atol = 1e-3 if config .floatX == "float32" else 1e-8 ,
953+ rtol = 1e-3 if config .floatX == "float32" else 1e-8 ,
954+ )
916955 # log(abs(det(x)))
917956 f = function ([x ], [log_abs_det_x ], mode = "FAST_RUN" )
918957 nodes = f .maker .fgraph .apply_nodes
919- assert any (isinstance (node .op , SLogDet ) for node in nodes )
958+ assert len ([node for node in nodes if isinstance (node .op , SLogDet )]) == 1
959+ assert not any (isinstance (node .op , Det ) for node in nodes )
960+ rw_log_abs_det_a = f (a )
961+ assert_allclose (
962+ log_abs_det_a ,
963+ rw_log_abs_det_a ,
964+ atol = 1e-3 if config .floatX == "float32" else 1e-8 ,
965+ rtol = 1e-3 if config .floatX == "float32" else 1e-8 ,
966+ )
967+ # log(det(x))
968+ f = function ([x ], [log_det_x ], mode = "FAST_RUN" )
969+ nodes = f .maker .fgraph .apply_nodes
970+ assert len ([node for node in nodes if isinstance (node .op , SLogDet )]) == 1
971+ assert not any (isinstance (node .op , Det ) for node in nodes )
972+ rw_log_det_a = f (a )
973+ assert_allclose (
974+ log_det_a ,
975+ rw_log_det_a ,
976+ atol = 1e-3 if config .floatX == "float32" else 1e-8 ,
977+ rtol = 1e-3 if config .floatX == "float32" else 1e-8 ,
978+ )
979+ # more than 1 valid function
980+ f = function ([x ], [sign_det_x , log_abs_det_x ], mode = "FAST_RUN" )
981+ nodes = f .maker .fgraph .apply_nodes
982+ assert len ([node for node in nodes if isinstance (node .op , SLogDet )]) == 1
983+ assert not any (isinstance (node .op , Det ) for node in nodes )
920984 # other functions (rewrite shouldnt be applied to these)
985+ # only invalid functions
921986 f = function ([x ], [exp_det_x ], mode = "FAST_RUN" )
922987 nodes = f .maker .fgraph .apply_nodes
923988 assert not any (isinstance (node .op , SLogDet ) for node in nodes )
989+ # invalid + valid function
990+ f = function ([x ], [exp_det_x , sign_det_x ], mode = "FAST_RUN" )
991+ nodes = f .maker .fgraph .apply_nodes
992+ assert not any (isinstance (node .op , SLogDet ) for node in nodes )
0 commit comments