Skip to content

Commit 15b41f0

Browse files
committed
Fix probability of cosh transform
1 parent 745b444 commit 15b41f0

File tree

3 files changed

+36
-4
lines changed

3 files changed

+36
-4
lines changed

pymc/logprob/transforms.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -679,7 +679,15 @@ def forward(self, value, *inputs):
679679
return pt.cosh(value)
680680

681681
def backward(self, value, *inputs):
682-
return pt.arccosh(value)
682+
back_value = pt.arccosh(value)
683+
return (-back_value, back_value)
684+
685+
def log_jac_det(self, value, *inputs):
686+
return pt.switch(
687+
value < 1,
688+
np.nan,
689+
-pt.log(pt.sqrt(value**2 - 1)),
690+
)
683691

684692

685693
class TanhTransform(RVTransform):

tests/distributions/test_transform.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,9 @@ def check_jacobian_det(
9797
rv_inputs = rv_var.owner.inputs if rv_var.owner else []
9898

9999
x = transform.backward(y, *rv_inputs)
100+
# Assume non-injective transforms are symmetric around the origin
101+
if isinstance(x, tuple):
102+
x = x[-1]
100103
if make_comparable:
101104
x = make_comparable(x)
102105

tests/logprob/test_transforms.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -991,14 +991,13 @@ def test_absolute_rv_transform(test_val):
991991
(pt.erfc, ErfcTransform()),
992992
(pt.erfcx, ErfcxTransform()),
993993
(pt.sinh, SinhTransform()),
994-
(pt.cosh, CoshTransform()),
995994
(pt.tanh, TanhTransform()),
996995
(pt.arcsinh, ArcsinhTransform()),
997996
(pt.arccosh, ArccoshTransform()),
998997
(pt.arctanh, ArctanhTransform()),
999998
],
1000999
)
1001-
def test_extra_rv_transforms(pt_transform, transform):
1000+
def test_extra_bijective_rv_transforms(pt_transform, transform):
10021001
base_rv = pt.random.normal(
10031002
0.5, 1, name="base_rv"
10041003
) # Something not centered around 0 is usually better
@@ -1011,7 +1010,29 @@ def test_extra_rv_transforms(pt_transform, transform):
10111010

10121011
vv_test = np.array(0.25) # Arbitrary test value
10131012
np.testing.assert_allclose(
1014-
rv_logp.eval({vv: vv_test}), np.nan_to_num(expected_logp.eval({vv: vv_test}), nan=-np.inf)
1013+
rv_logp.eval({vv: vv_test}),
1014+
np.nan_to_num(expected_logp.eval({vv: vv_test}), nan=-np.inf),
1015+
)
1016+
1017+
1018+
def test_cosh_rv_transform():
1019+
# Something not centered around 0 is usually better
1020+
base_rv = pt.random.normal(0.5, 1, size=(2,), name="base_rv")
1021+
rv = pt.cosh(base_rv)
1022+
1023+
vv = rv.clone()
1024+
rv_logp = logp(rv, vv)
1025+
1026+
transform = CoshTransform()
1027+
[back_neg, back_pos] = transform.backward(vv)
1028+
expected_logp = pt.logaddexp(
1029+
logp(base_rv, back_neg), logp(base_rv, back_pos)
1030+
) + transform.log_jac_det(vv)
1031+
1032+
vv_test = np.array([0.25, 1.5])
1033+
np.testing.assert_allclose(
1034+
rv_logp.eval({vv: vv_test}),
1035+
np.nan_to_num(expected_logp.eval({vv: vv_test}), nan=-np.inf),
10151036
)
10161037

10171038

0 commit comments

Comments
 (0)