Skip to content

Commit cc4e034

Browse files
author
Junpeng Lao
authored
implement forward_val in transform as a numpy function (#2920)
* implement `forward_val` in transform as a numpy function it is faster to do forward transformation in numpy, could be useful for sampling from prior. * add test * change transform useage in SMC * rename test
1 parent 33d0ab7 commit cc4e034

File tree

4 files changed

+46
-31
lines changed

4 files changed

+46
-31
lines changed

pymc3/distributions/transforms.py

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from ..math import logit, invlogit
88
from .distribution import draw_values
99
import numpy as np
10+
from scipy.special import logit as nplogit
1011

1112
__all__ = ['transform', 'stick_breaking', 'logodds', 'interval', 'log_exp_m1',
1213
'lowerbound', 'upperbound', 'log', 'sum_to_1', 't_stick_breaking']
@@ -97,7 +98,7 @@ def forward(self, x):
9798
return tt.log(x)
9899

99100
def forward_val(self, x, point=None):
100-
return self.forward(x)
101+
return np.log(x)
101102

102103
def jacobian_det(self, x):
103104
return x
@@ -119,7 +120,7 @@ def forward(self, x):
119120
return tt.log(1.-tt.exp(-x)) + x
120121

121122
def forward_val(self, x, point=None):
122-
return self.forward(x)
123+
return np.log(1.-np.exp(-x)) + x
123124

124125
def jacobian_det(self, x):
125126
return -tt.nnet.softplus(-x)
@@ -137,7 +138,7 @@ def forward(self, x):
137138
return logit(x)
138139

139140
def forward_val(self, x, point=None):
140-
return self.forward(x)
141+
return nplogit(x)
141142

142143
logodds = LogOdds()
143144

@@ -166,7 +167,7 @@ def forward_val(self, x, point=None):
166167
# For an explanation see pull/2328#issuecomment-309303811
167168
a, b = draw_values([self.a-0., self.b-0.],
168169
point=point)
169-
return floatX(tt.log(x - a) - tt.log(b - x))
170+
return floatX(np.log(x - a) - np.log(b - x))
170171

171172
def jacobian_det(self, x):
172173
s = tt.nnet.softplus(-x)
@@ -198,7 +199,7 @@ def forward_val(self, x, point=None):
198199
# For an explanation see pull/2328#issuecomment-309303811
199200
a = draw_values([self.a-0.],
200201
point=point)[0]
201-
return floatX(tt.log(x - a))
202+
return floatX(np.log(x - a))
202203

203204
def jacobian_det(self, x):
204205
return x
@@ -229,7 +230,7 @@ def forward_val(self, x, point=None):
229230
# For an explanation see pull/2328#issuecomment-309303811
230231
b = draw_values([self.b-0.],
231232
point=point)[0]
232-
return floatX(tt.log(b - x))
233+
return floatX(np.log(b - x))
233234

234235
def jacobian_det(self, x):
235236
return x
@@ -249,7 +250,7 @@ def forward(self, x):
249250
return x[:-1]
250251

251252
def forward_val(self, x, point=None):
252-
return self.forward(x)
253+
return x[:-1]
253254

254255
def jacobian_det(self, x):
255256
return 0
@@ -284,8 +285,17 @@ def forward(self, x_):
284285
y = logit(z) - eq_share
285286
return floatX(y.T)
286287

287-
def forward_val(self, x, point=None):
288-
return self.forward(x)
288+
def forward_val(self, x_, point=None):
289+
x = x_.T
290+
# reverse cumsum
291+
x0 = x[:-1]
292+
s = np.cumsum(x0[::-1], 0)[::-1] + x[-1]
293+
z = x0 / s
294+
Km1 = x.shape[0] - 1
295+
k = np.arange(Km1)[(slice(None),) + (None,) * (x.ndim - 1)]
296+
eq_share = nplogit(1. / (Km1 + 1 - k).astype(str(x_.dtype)))
297+
y = nplogit(z) - eq_share
298+
return floatX(y.T)
289299

290300
def backward(self, y_):
291301
y = y_.T
@@ -326,7 +336,7 @@ def forward(self, x):
326336
return tt.as_tensor_variable(x)
327337

328338
def forward_val(self, x, point=None):
329-
return self.forward(x)
339+
return x
330340

331341
def jacobian_det(self, x):
332342
return 0
@@ -346,8 +356,9 @@ def backward(self, x):
346356
def forward(self, y):
347357
return tt.advanced_set_subtensor1(y, tt.log(y[self.diag_idxs]), self.diag_idxs)
348358

349-
def forward_val(self, x, point=None):
350-
return self.forward(x)
359+
def forward_val(self, y, point=None):
360+
y[self.diag_idxs] = np.log(y[self.diag_idxs])
361+
return y
351362

352363
def jacobian_det(self, y):
353364
return tt.sum(y[self.diag_idxs])

pymc3/step_methods/smc.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -186,9 +186,9 @@ def __init__(self, vars=None, out_vars=None, samples=1000, n_chains=100, n_steps
186186
init_rnd = {}
187187
for v in vars:
188188
if pm.util.is_transformed_name(v.name):
189-
trans = v.distribution.transform_used.forward
190-
rnd = trans(v.distribution.dist.random(size=self.n_chains, point=start))
191-
init_rnd[v.name] = rnd.eval()
189+
trans = v.distribution.transform_used.forward_val
190+
init_rnd[v.name] = trans(v.distribution.dist.random(
191+
size=self.n_chains, point=start))
192192
else:
193193
init_rnd[v.name] = v.random(size=self.n_chains, point=start)
194194

pymc3/tests/test_transforms.py

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -13,17 +13,21 @@
1313
tol = 1e-7 if theano.config.floatX == 'flaot64' else 1e-6
1414

1515

16-
def check_transform_identity(transform, domain, constructor=tt.dscalar, test=0):
16+
def check_transform(transform, domain, constructor=tt.dscalar, test=0):
1717
x = constructor('x')
1818
x.tag.test_value = test
19+
# test forward and forward_val
20+
forward_f = theano.function([x], transform.forward(x))
21+
# test transform identity
1922
identity_f = theano.function([x], transform.backward(transform.forward(x)))
2023

2124
for val in domain.vals:
2225
close_to(val, identity_f(val), tol)
26+
close_to(transform.forward_val(val), forward_f(val), tol)
2327

2428

25-
def check_vector_transform_identity(transform, domain):
26-
return check_transform_identity(transform, domain, tt.dvector, test=np.array([0, 0]))
29+
def check_vector_transform(transform, domain):
30+
return check_transform(transform, domain, tt.dvector, test=np.array([0, 0]))
2731

2832

2933
def get_values(transform, domain=R, constructor=tt.dscalar, test=0):
@@ -35,9 +39,9 @@ def get_values(transform, domain=R, constructor=tt.dscalar, test=0):
3539

3640

3741
def test_simplex():
38-
check_vector_transform_identity(tr.stick_breaking, Simplex(2))
39-
check_vector_transform_identity(tr.stick_breaking, Simplex(4))
40-
check_transform_identity(tr.stick_breaking, MultiSimplex(
42+
check_vector_transform(tr.stick_breaking, Simplex(2))
43+
check_vector_transform(tr.stick_breaking, Simplex(4))
44+
check_transform(tr.stick_breaking, MultiSimplex(
4145
3, 2), constructor=tt.dmatrix, test=np.zeros((2, 2)))
4246

4347

@@ -56,8 +60,8 @@ def test_simplex_jacobian_det():
5660

5761

5862
def test_sum_to_1():
59-
check_vector_transform_identity(tr.sum_to_1, Simplex(2))
60-
check_vector_transform_identity(tr.sum_to_1, Simplex(4))
63+
check_vector_transform(tr.sum_to_1, Simplex(2))
64+
check_vector_transform(tr.sum_to_1, Simplex(4))
6165

6266

6367
def test_sum_to_1_jacobian_det():
@@ -95,7 +99,7 @@ def check_jacobian_det(transform, domain,
9599

96100

97101
def test_log():
98-
check_transform_identity(tr.log, Rplusbig)
102+
check_transform(tr.log, Rplusbig)
99103
check_jacobian_det(tr.log, Rplusbig, elemwise=True)
100104
check_jacobian_det(tr.log, Vector(Rplusbig, 2),
101105
tt.dvector, [0, 0], elemwise=True)
@@ -105,7 +109,7 @@ def test_log():
105109

106110

107111
def test_log_exp_m1():
108-
check_transform_identity(tr.log_exp_m1, Rplusbig)
112+
check_transform(tr.log_exp_m1, Rplusbig)
109113
check_jacobian_det(tr.log_exp_m1, Rplusbig, elemwise=True)
110114
check_jacobian_det(tr.log_exp_m1, Vector(Rplusbig, 2),
111115
tt.dvector, [0, 0], elemwise=True)
@@ -115,7 +119,7 @@ def test_log_exp_m1():
115119

116120

117121
def test_logodds():
118-
check_transform_identity(tr.logodds, Unit)
122+
check_transform(tr.logodds, Unit)
119123
check_jacobian_det(tr.logodds, Unit, elemwise=True)
120124
check_jacobian_det(tr.logodds, Vector(Unit, 2),
121125
tt.dvector, [.5, .5], elemwise=True)
@@ -127,7 +131,7 @@ def test_logodds():
127131

128132
def test_lowerbound():
129133
trans = tr.lowerbound(0.0)
130-
check_transform_identity(trans, Rplusbig)
134+
check_transform(trans, Rplusbig)
131135
check_jacobian_det(trans, Rplusbig, elemwise=True)
132136
check_jacobian_det(trans, Vector(Rplusbig, 2),
133137
tt.dvector, [0, 0], elemwise=True)
@@ -138,7 +142,7 @@ def test_lowerbound():
138142

139143
def test_upperbound():
140144
trans = tr.upperbound(0.0)
141-
check_transform_identity(trans, Rminusbig)
145+
check_transform(trans, Rminusbig)
142146
check_jacobian_det(trans, Rminusbig, elemwise=True)
143147
check_jacobian_det(trans, Vector(Rminusbig, 2),
144148
tt.dvector, [-1, -1], elemwise=True)
@@ -151,7 +155,7 @@ def test_interval():
151155
for a, b in [(-4, 5.5), (.1, .7), (-10, 4.3)]:
152156
domain = Unit * np.float64(b - a) + np.float64(a)
153157
trans = tr.interval(a, b)
154-
check_transform_identity(trans, domain)
158+
check_transform(trans, domain)
155159
check_jacobian_det(trans, domain, elemwise=True)
156160

157161
vals = get_values(trans)
@@ -161,7 +165,7 @@ def test_interval():
161165

162166
def test_circular():
163167
trans = tr.circular
164-
check_transform_identity(trans, Circ)
168+
check_transform(trans, Circ)
165169
check_jacobian_det(trans, Circ)
166170

167171
vals = get_values(trans)

pymc3/util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ def update_start_vals(a, b, model):
143143
d.transformation for d in model.deterministics if d.name == name]
144144
if transform_func:
145145
b[tname] = transform_func[0].forward_val(
146-
a[name], point=b).eval()
146+
a[name], point=b)
147147

148148
a.update({k: v for k, v in b.items() if k not in a})
149149

0 commit comments

Comments
 (0)