Skip to content

Commit bb3bda2

Browse files
committed
Use numpy.testing.assert_allclose in test_transforms.py
1 parent accabdf commit bb3bda2

File tree

1 file changed

+23
-24
lines changed

1 file changed

+23
-24
lines changed

tests/logprob/test_transforms.py

Lines changed: 23 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,6 @@
7070
TanhTransform,
7171
TransformValuesMapping,
7272
TransformValuesRewrite,
73-
transformed_variable,
7473
)
7574
from pymc.testing import assert_no_rvs
7675

@@ -299,15 +298,15 @@ def a_backward_fn_(x):
299298
exp_log_jac_val = np.linalg.slogdet(jacobian_val)[-1]
300299

301300
log_jac_val = log_jac_fn(a_trans_value)
302-
np.testing.assert_almost_equal(exp_log_jac_val, log_jac_val, decimal=4)
301+
np.testing.assert_allclose(exp_log_jac_val, log_jac_val, rtol=1e-4, atol=1e-10)
303302

304303
exp_logprob_val = a_dist.logpdf(a_val).sum()
305304
exp_logprob_val += exp_log_jac_val.sum()
306305
exp_logprob_val += b_dist.logpdf(b_val).sum()
307306

308307
logprob_val = logp_vals_fn(a_trans_value, b_val)
309308

310-
np.testing.assert_almost_equal(exp_logprob_val, logprob_val, decimal=4)
309+
np.testing.assert_allclose(exp_logprob_val, logprob_val, rtol=1e-4, atol=1e-10)
311310

312311

313312
@pytest.mark.parametrize("use_jacobian", [True, False])
@@ -322,7 +321,7 @@ def test_simple_transformed_logprob_nojac(use_jacobian):
322321
)
323322
tr_logp_combined = pt.sum([pt.sum(factor) for factor in tr_logp.values()])
324323

325-
assert np.isclose(
324+
np.testing.assert_allclose(
326325
tr_logp_combined.eval({x_vv: np.log(2.5)}),
327326
sp.stats.halfnorm(0, 3).logpdf(2.5) + (np.log(2.5) if use_jacobian else 0.0),
328327
)
@@ -443,7 +442,7 @@ def test_nondefault_transforms():
443442
exp_logp += sp.stats.norm(loc_val, scale_val).logpdf(x_val)
444443
exp_logp += x_val_tr # log log_jac_det
445444

446-
assert np.isclose(
445+
np.testing.assert_allclose(
447446
logp_combined.eval({loc: loc_val, scale: scale_val_tr, x: x_val_tr}),
448447
exp_logp,
449448
)
@@ -466,7 +465,7 @@ def test_default_transform_multiout():
466465
)
467466
logp_combined = pt.sum([pt.sum(factor) for factor in logp.values()])
468467

469-
assert np.isclose(
468+
np.testing.assert_allclose(
470469
logp_combined.eval({x: 1}),
471470
sp.stats.norm(0, 1).logpdf(1),
472471
)
@@ -526,7 +525,7 @@ def test_nondefault_transform_multiout(transform_x, transform_y, multiout_measur
526525
else:
527526
expected_logp += np.log(y_vv_test) + 2 - np.log(y_vv_test)
528527

529-
np.testing.assert_almost_equal(
528+
np.testing.assert_allclose(
530529
logp_combined.eval({x_vv: x_vv_test, y_vv: y_vv_test}), expected_logp
531530
)
532531

@@ -643,19 +642,19 @@ def test_chained_transform():
643642
x_val = x.eval()
644643

645644
x_val_forward = ch.forward(x_val, *x.owner.inputs).eval()
646-
assert np.allclose(
645+
np.testing.assert_allclose(
647646
x_val_forward,
648647
np.exp(x_val * scale) + loc,
649648
)
650649

651650
x_val_backward = ch.backward(x_val_forward, *x.owner.inputs, scale, loc).eval()
652-
assert np.allclose(
651+
np.testing.assert_allclose(
653652
x_val_backward,
654653
x_val,
655654
)
656655

657656
log_jac_det = ch.log_jac_det(x_val_forward, *x.owner.inputs, scale, loc)
658-
assert np.isclose(
657+
np.testing.assert_allclose(
659658
pt.sum(log_jac_det).eval(),
660659
np.sum(-np.log(scale) - np.log(x_val_forward - loc)),
661660
)
@@ -767,7 +766,7 @@ def test_transformed_rv_and_value():
767766

768767
y_test_val = -5
769768

770-
assert np.isclose(
769+
np.testing.assert_allclose(
771770
logp_fn(y_test_val),
772771
sp.stats.halfnorm(0, 1).logpdf(np.exp(y_test_val)) + y_test_val,
773772
)
@@ -830,7 +829,7 @@ def test_reciprocal_rv_transform(numerator):
830829
x_logp_fn = pytensor.function([x_vv], logp(x_rv, x_vv))
831830

832831
x_test_val = np.r_[-0.5, 1.5]
833-
assert np.allclose(
832+
np.testing.assert_allclose(
834833
x_logp_fn(x_test_val),
835834
sp.stats.invgamma(shape, scale=scale * numerator).logpdf(x_test_val),
836835
)
@@ -845,7 +844,7 @@ def test_sqr_transform():
845844
x_logp_fn = pytensor.function([x_vv], logp(x_rv, x_vv))
846845

847846
x_test_val = np.r_[-0.5, 0.5, 1, 2.5]
848-
assert np.allclose(
847+
np.testing.assert_allclose(
849848
x_logp_fn(x_test_val),
850849
sp.stats.chi2(df=1).logpdf(x_test_val),
851850
)
@@ -860,7 +859,7 @@ def test_sqrt_transform():
860859
x_logp_fn = pytensor.function([x_vv], logp(x_rv, x_vv))
861860

862861
x_test_val = np.r_[-2.5, 0.5, 1, 2.5]
863-
assert np.allclose(
862+
np.testing.assert_allclose(
864863
x_logp_fn(x_test_val),
865864
sp.stats.chi(df=3).logpdf(x_test_val),
866865
)
@@ -915,7 +914,7 @@ def test_absolute_transform(test_val):
915914
x_logp_fn = pytensor.function([x_vv], logp(x_rv, x_vv))
916915
y_logp_fn = pytensor.function([y_vv], logp(y_rv, y_vv))
917916

918-
assert np.allclose(x_logp_fn(test_val), y_logp_fn(test_val))
917+
np.testing.assert_allclose(x_logp_fn(test_val), y_logp_fn(test_val))
919918

920919

921920
def test_negated_rv_transform():
@@ -925,7 +924,7 @@ def test_negated_rv_transform():
925924
x_vv = x_rv.clone()
926925
x_logp_fn = pytensor.function([x_vv], pt.sum(logp(x_rv, x_vv)))
927926

928-
assert np.isclose(x_logp_fn(-1.5), sp.stats.halfnorm.logpdf(1.5))
927+
np.testing.assert_allclose(x_logp_fn(-1.5), sp.stats.halfnorm.logpdf(1.5))
929928

930929

931930
def test_subtracted_rv_transform():
@@ -936,7 +935,7 @@ def test_subtracted_rv_transform():
936935
x_vv = x_rv.clone()
937936
x_logp_fn = pytensor.function([x_vv], pt.sum(logp(x_rv, x_vv)))
938937

939-
assert np.isclose(x_logp_fn(7.3), sp.stats.norm.logpdf(5.0 - 7.3, 1.0))
938+
np.testing.assert_allclose(x_logp_fn(7.3), sp.stats.norm.logpdf(5.0 - 7.3, 1.0))
940939

941940

942941
def test_scan_transform():
@@ -1012,7 +1011,7 @@ def test_multivariate_transform(shift, scale):
10121011

10131012
x_vv_test = np.array([5.0, 4.9, -6.3])
10141013
scale_mat = scale * np.eye(x_vv_test.shape[0])
1015-
np.testing.assert_almost_equal(
1014+
np.testing.assert_allclose(
10161015
logp.eval({x_vv: x_vv_test}),
10171016
sp.stats.multivariate_normal.logpdf(
10181017
x_vv_test,
@@ -1048,7 +1047,7 @@ def test_erf_logp(pt_transform, transform):
10481047
expected_logp = logp(base_rv, transform.backward(vv)) + transform.log_jac_det(vv)
10491048

10501049
vv_test = np.array(0.25) # Arbitrary test value
1051-
np.testing.assert_almost_equal(
1050+
np.testing.assert_allclose(
10521051
rv_logp.eval({vv: vv_test}), np.nan_to_num(expected_logp.eval({vv: vv_test}), nan=-np.inf)
10531052
)
10541053

@@ -1088,8 +1087,8 @@ def test_logcdf_measurable_transform():
10881087
logcdf_fn = pytensor.function([value], logcdf(x, value))
10891088

10901089
assert logcdf_fn(0) == -np.inf
1091-
np.testing.assert_almost_equal(logcdf_fn(np.exp(0.5)), np.log(0.5))
1092-
np.testing.assert_almost_equal(logcdf_fn(5), 0)
1090+
np.testing.assert_allclose(logcdf_fn(np.exp(0.5)), np.log(0.5))
1091+
np.testing.assert_allclose(logcdf_fn(5), 0)
10931092

10941093

10951094
def test_logcdf_measurable_non_injective_fails():
@@ -1104,9 +1103,9 @@ def test_icdf_measurable_transform():
11041103
value = x.type()
11051104
icdf_fn = pytensor.function([value], icdf(x, value))
11061105

1107-
np.testing.assert_almost_equal(icdf_fn(1e-16), 1)
1108-
np.testing.assert_almost_equal(icdf_fn(0.5), np.exp(0.5))
1109-
np.testing.assert_almost_equal(icdf_fn(1 - 1e-16), np.e)
1106+
np.testing.assert_allclose(icdf_fn(1e-16), 1)
1107+
np.testing.assert_allclose(icdf_fn(0.5), np.exp(0.5))
1108+
np.testing.assert_allclose(icdf_fn(1 - 1e-16), np.e)
11101109

11111110

11121111
def test_icdf_measurable_non_injective_fails():

0 commit comments

Comments
 (0)