70
70
TanhTransform ,
71
71
TransformValuesMapping ,
72
72
TransformValuesRewrite ,
73
- transformed_variable ,
74
73
)
75
74
from pymc .testing import assert_no_rvs
76
75
@@ -299,15 +298,15 @@ def a_backward_fn_(x):
299
298
exp_log_jac_val = np .linalg .slogdet (jacobian_val )[- 1 ]
300
299
301
300
log_jac_val = log_jac_fn (a_trans_value )
302
- np .testing .assert_almost_equal (exp_log_jac_val , log_jac_val , decimal = 4 )
301
+ np .testing .assert_allclose (exp_log_jac_val , log_jac_val , rtol = 1e-4 , atol = 1e-10 )
303
302
304
303
exp_logprob_val = a_dist .logpdf (a_val ).sum ()
305
304
exp_logprob_val += exp_log_jac_val .sum ()
306
305
exp_logprob_val += b_dist .logpdf (b_val ).sum ()
307
306
308
307
logprob_val = logp_vals_fn (a_trans_value , b_val )
309
308
310
- np .testing .assert_almost_equal (exp_logprob_val , logprob_val , decimal = 4 )
309
+ np .testing .assert_allclose (exp_logprob_val , logprob_val , rtol = 1e-4 , atol = 1e-10 )
311
310
312
311
313
312
@pytest .mark .parametrize ("use_jacobian" , [True , False ])
@@ -322,7 +321,7 @@ def test_simple_transformed_logprob_nojac(use_jacobian):
322
321
)
323
322
tr_logp_combined = pt .sum ([pt .sum (factor ) for factor in tr_logp .values ()])
324
323
325
- assert np .isclose (
324
+ np .testing . assert_allclose (
326
325
tr_logp_combined .eval ({x_vv : np .log (2.5 )}),
327
326
sp .stats .halfnorm (0 , 3 ).logpdf (2.5 ) + (np .log (2.5 ) if use_jacobian else 0.0 ),
328
327
)
@@ -443,7 +442,7 @@ def test_nondefault_transforms():
443
442
exp_logp += sp .stats .norm (loc_val , scale_val ).logpdf (x_val )
444
443
exp_logp += x_val_tr # log log_jac_det
445
444
446
- assert np .isclose (
445
+ np .testing . assert_allclose (
447
446
logp_combined .eval ({loc : loc_val , scale : scale_val_tr , x : x_val_tr }),
448
447
exp_logp ,
449
448
)
@@ -466,7 +465,7 @@ def test_default_transform_multiout():
466
465
)
467
466
logp_combined = pt .sum ([pt .sum (factor ) for factor in logp .values ()])
468
467
469
- assert np .isclose (
468
+ np .testing . assert_allclose (
470
469
logp_combined .eval ({x : 1 }),
471
470
sp .stats .norm (0 , 1 ).logpdf (1 ),
472
471
)
@@ -526,7 +525,7 @@ def test_nondefault_transform_multiout(transform_x, transform_y, multiout_measur
526
525
else :
527
526
expected_logp += np .log (y_vv_test ) + 2 - np .log (y_vv_test )
528
527
529
- np .testing .assert_almost_equal (
528
+ np .testing .assert_allclose (
530
529
logp_combined .eval ({x_vv : x_vv_test , y_vv : y_vv_test }), expected_logp
531
530
)
532
531
@@ -643,19 +642,19 @@ def test_chained_transform():
643
642
x_val = x .eval ()
644
643
645
644
x_val_forward = ch .forward (x_val , * x .owner .inputs ).eval ()
646
- assert np .allclose (
645
+ np .testing . assert_allclose (
647
646
x_val_forward ,
648
647
np .exp (x_val * scale ) + loc ,
649
648
)
650
649
651
650
x_val_backward = ch .backward (x_val_forward , * x .owner .inputs , scale , loc ).eval ()
652
- assert np .allclose (
651
+ np .testing . assert_allclose (
653
652
x_val_backward ,
654
653
x_val ,
655
654
)
656
655
657
656
log_jac_det = ch .log_jac_det (x_val_forward , * x .owner .inputs , scale , loc )
658
- assert np .isclose (
657
+ np .testing . assert_allclose (
659
658
pt .sum (log_jac_det ).eval (),
660
659
np .sum (- np .log (scale ) - np .log (x_val_forward - loc )),
661
660
)
@@ -767,7 +766,7 @@ def test_transformed_rv_and_value():
767
766
768
767
y_test_val = - 5
769
768
770
- assert np .isclose (
769
+ np .testing . assert_allclose (
771
770
logp_fn (y_test_val ),
772
771
sp .stats .halfnorm (0 , 1 ).logpdf (np .exp (y_test_val )) + y_test_val ,
773
772
)
@@ -830,7 +829,7 @@ def test_reciprocal_rv_transform(numerator):
830
829
x_logp_fn = pytensor .function ([x_vv ], logp (x_rv , x_vv ))
831
830
832
831
x_test_val = np .r_ [- 0.5 , 1.5 ]
833
- assert np .allclose (
832
+ np .testing . assert_allclose (
834
833
x_logp_fn (x_test_val ),
835
834
sp .stats .invgamma (shape , scale = scale * numerator ).logpdf (x_test_val ),
836
835
)
@@ -845,7 +844,7 @@ def test_sqr_transform():
845
844
x_logp_fn = pytensor .function ([x_vv ], logp (x_rv , x_vv ))
846
845
847
846
x_test_val = np .r_ [- 0.5 , 0.5 , 1 , 2.5 ]
848
- assert np .allclose (
847
+ np .testing . assert_allclose (
849
848
x_logp_fn (x_test_val ),
850
849
sp .stats .chi2 (df = 1 ).logpdf (x_test_val ),
851
850
)
@@ -860,7 +859,7 @@ def test_sqrt_transform():
860
859
x_logp_fn = pytensor .function ([x_vv ], logp (x_rv , x_vv ))
861
860
862
861
x_test_val = np .r_ [- 2.5 , 0.5 , 1 , 2.5 ]
863
- assert np .allclose (
862
+ np .testing . assert_allclose (
864
863
x_logp_fn (x_test_val ),
865
864
sp .stats .chi (df = 3 ).logpdf (x_test_val ),
866
865
)
@@ -915,7 +914,7 @@ def test_absolute_transform(test_val):
915
914
x_logp_fn = pytensor .function ([x_vv ], logp (x_rv , x_vv ))
916
915
y_logp_fn = pytensor .function ([y_vv ], logp (y_rv , y_vv ))
917
916
918
- assert np .allclose (x_logp_fn (test_val ), y_logp_fn (test_val ))
917
+ np .testing . assert_allclose (x_logp_fn (test_val ), y_logp_fn (test_val ))
919
918
920
919
921
920
def test_negated_rv_transform ():
@@ -925,7 +924,7 @@ def test_negated_rv_transform():
925
924
x_vv = x_rv .clone ()
926
925
x_logp_fn = pytensor .function ([x_vv ], pt .sum (logp (x_rv , x_vv )))
927
926
928
- assert np .isclose (x_logp_fn (- 1.5 ), sp .stats .halfnorm .logpdf (1.5 ))
927
+ np .testing . assert_allclose (x_logp_fn (- 1.5 ), sp .stats .halfnorm .logpdf (1.5 ))
929
928
930
929
931
930
def test_subtracted_rv_transform ():
@@ -936,7 +935,7 @@ def test_subtracted_rv_transform():
936
935
x_vv = x_rv .clone ()
937
936
x_logp_fn = pytensor .function ([x_vv ], pt .sum (logp (x_rv , x_vv )))
938
937
939
- assert np .isclose (x_logp_fn (7.3 ), sp .stats .norm .logpdf (5.0 - 7.3 , 1.0 ))
938
+ np .testing . assert_allclose (x_logp_fn (7.3 ), sp .stats .norm .logpdf (5.0 - 7.3 , 1.0 ))
940
939
941
940
942
941
def test_scan_transform ():
@@ -1012,7 +1011,7 @@ def test_multivariate_transform(shift, scale):
1012
1011
1013
1012
x_vv_test = np .array ([5.0 , 4.9 , - 6.3 ])
1014
1013
scale_mat = scale * np .eye (x_vv_test .shape [0 ])
1015
- np .testing .assert_almost_equal (
1014
+ np .testing .assert_allclose (
1016
1015
logp .eval ({x_vv : x_vv_test }),
1017
1016
sp .stats .multivariate_normal .logpdf (
1018
1017
x_vv_test ,
@@ -1048,7 +1047,7 @@ def test_erf_logp(pt_transform, transform):
1048
1047
expected_logp = logp (base_rv , transform .backward (vv )) + transform .log_jac_det (vv )
1049
1048
1050
1049
vv_test = np .array (0.25 ) # Arbitrary test value
1051
- np .testing .assert_almost_equal (
1050
+ np .testing .assert_allclose (
1052
1051
rv_logp .eval ({vv : vv_test }), np .nan_to_num (expected_logp .eval ({vv : vv_test }), nan = - np .inf )
1053
1052
)
1054
1053
@@ -1088,8 +1087,8 @@ def test_logcdf_measurable_transform():
1088
1087
logcdf_fn = pytensor .function ([value ], logcdf (x , value ))
1089
1088
1090
1089
assert logcdf_fn (0 ) == - np .inf
1091
- np .testing .assert_almost_equal (logcdf_fn (np .exp (0.5 )), np .log (0.5 ))
1092
- np .testing .assert_almost_equal (logcdf_fn (5 ), 0 )
1090
+ np .testing .assert_allclose (logcdf_fn (np .exp (0.5 )), np .log (0.5 ))
1091
+ np .testing .assert_allclose (logcdf_fn (5 ), 0 )
1093
1092
1094
1093
1095
1094
def test_logcdf_measurable_non_injective_fails ():
@@ -1104,9 +1103,9 @@ def test_icdf_measurable_transform():
1104
1103
value = x .type ()
1105
1104
icdf_fn = pytensor .function ([value ], icdf (x , value ))
1106
1105
1107
- np .testing .assert_almost_equal (icdf_fn (1e-16 ), 1 )
1108
- np .testing .assert_almost_equal (icdf_fn (0.5 ), np .exp (0.5 ))
1109
- np .testing .assert_almost_equal (icdf_fn (1 - 1e-16 ), np .e )
1106
+ np .testing .assert_allclose (icdf_fn (1e-16 ), 1 )
1107
+ np .testing .assert_allclose (icdf_fn (0.5 ), np .exp (0.5 ))
1108
+ np .testing .assert_allclose (icdf_fn (1 - 1e-16 ), np .e )
1110
1109
1111
1110
1112
1111
def test_icdf_measurable_non_injective_fails ():
0 commit comments