Skip to content

Commit 55da81d

Browse files
Junpeng LaoColCarroll
authored andcommitted
Add LogExpM1 transformation (#2601)
* Add softplus transformation Softplus transformation (from non-negative to reals) might be more numerically stable (see Fig. 9 in Kucukelbir et al. 2017). * add test * Change name and implement a more numerically stable logexpm1 See https://github.com/tensorflow/tensorflow/blob/0b0d3c12ace80381f4a44365d30275a9a262609b/tensorflow/python/ops/distributions/util.py#L1009 for the derivation * change default transformation for PositiveContinuous * Revert "change default transformation for PositiveContinuous" This reverts commit 8bc036c. * name change
1 parent 0bb0ad1 commit 55da81d

File tree

2 files changed

+33
-4
lines changed

2 files changed

+33
-4
lines changed

pymc3/distributions/transforms.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from .distribution import draw_values
99
import numpy as np
1010

11-
__all__ = ['transform', 'stick_breaking', 'logodds', 'interval',
11+
__all__ = ['transform', 'stick_breaking', 'logodds', 'interval', 'log_exp_m1',
1212
'lowerbound', 'upperbound', 'log', 'sum_to_1', 't_stick_breaking']
1313

1414

@@ -105,12 +105,31 @@ def jacobian_det(self, x):
105105
log = Log()
106106

107107

108+
class LogExpM1(ElemwiseTransform):
109+
name = "log_exp_m1"
110+
111+
def backward(self, x):
112+
return tt.nnet.softplus(x)
113+
114+
def forward(self, x):
115+
"""Inverse operation of softplus
116+
y = Log(Exp(x) - 1)
117+
= Log(1 - Exp(-x)) + x
118+
"""
119+
return tt.log(1.-tt.exp(-x)) + x
120+
121+
def forward_val(self, x, point=None):
122+
return self.forward(x)
123+
124+
def jacobian_det(self, x):
125+
return -tt.nnet.softplus(-x)
126+
127+
log_exp_m1 = LogExpM1()
128+
129+
108130
class LogOdds(ElemwiseTransform):
109131
name = "logodds"
110132

111-
def __init__(self):
112-
pass
113-
114133
def backward(self, x):
115134
return invlogit(x, 0.0)
116135

pymc3/tests/test_transforms.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,16 @@ def test_log():
104104
close_to_logical(vals > 0, True, tol)
105105

106106

107+
def test_log_exp_m1():
108+
check_transform_identity(tr.log_exp_m1, Rplusbig)
109+
check_jacobian_det(tr.log_exp_m1, Rplusbig, elemwise=True)
110+
check_jacobian_det(tr.log_exp_m1, Vector(Rplusbig, 2),
111+
tt.dvector, [0, 0], elemwise=True)
112+
113+
vals = get_values(tr.log_exp_m1)
114+
close_to_logical(vals > 0, True, tol)
115+
116+
107117
def test_logodds():
108118
check_transform_identity(tr.logodds, Unit)
109119
check_jacobian_det(tr.logodds, Unit, elemwise=True)

0 commit comments

Comments
 (0)