Skip to content

Commit 748c61c

Browse files
author
Junpeng Lao
authored
Merge pull request #2328 from junpenglao/fix_#2258
Fix #2258
2 parents 44b7d12 + 2ba776c commit 748c61c

File tree

4 files changed

+84
-17
lines changed

4 files changed

+84
-17
lines changed

.travis.yml

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,14 @@ install:
1111
- pip install coveralls pylint
1212

1313
env:
14-
- PYTHON_VERSION=2.7 FLOATX='float32' TESTCMD="--durations=10 --ignore=pymc3/tests/test_examples.py --cov-append --ignore=pymc3/tests/test_distributions_random.py --ignore=pymc3/tests/test_variational_inference.py --ignore=pymc3/tests/test_shared.py --ignore=pymc3/tests/test_smc.py --ignore=pymc3/tests/test_updates.py --ignore=pymc3/tests/test_posteriors.py"
15-
- PYTHON_VERSION=2.7 FLOATX='float32' RUN_PYLINT="true" TESTCMD="--durations=10 --cov-append pymc3/tests/test_distributions_random.py pymc3/tests/test_shared.py pymc3/tests/test_smc.py"
14+
- PYTHON_VERSION=2.7 FLOATX='float32' TESTCMD="--durations=10 --ignore=pymc3/tests/test_examples.py --cov-append --ignore=pymc3/tests/test_distributions_random.py --ignore=pymc3/tests/test_variational_inference.py --ignore=pymc3/tests/test_shared.py --ignore=pymc3/tests/test_smc.py --ignore=pymc3/tests/test_updates.py --ignore=pymc3/tests/test_posteriors.py --ignore=pymc3/tests/test_sampling.py"
15+
- PYTHON_VERSION=2.7 FLOATX='float32' RUN_PYLINT="true" TESTCMD="--durations=10 --cov-append pymc3/tests/test_distributions_random.py pymc3/tests/test_shared.py pymc3/tests/test_smc.py pymc3/tests/test_sampling.py"
1616
- PYTHON_VERSION=2.7 FLOATX='float32' TESTCMD="--durations=10 --cov-append pymc3/tests/test_examples.py pymc3/tests/test_variational_inference.py pymc3/tests/test_updates.py pymc3/tests/test_posteriors.py"
17-
- PYTHON_VERSION=2.7 FLOATX='float64' TESTCMD="--durations=10 --ignore=pymc3/tests/test_examples.py --cov-append --ignore=pymc3/tests/test_distributions_random.py --ignore=pymc3/tests/test_variational_inference.py --ignore=pymc3/tests/test_shared.py --ignore=pymc3/tests/test_smc.py --ignore=pymc3/tests/test_updates.py --ignore=pymc3/tests/test_posteriors.py"
18-
- PYTHON_VERSION=2.7 FLOATX='float64' RUN_PYLINT="true" TESTCMD="--durations=10 --cov-append pymc3/tests/test_distributions_random.py pymc3/tests/test_shared.py pymc3/tests/test_smc.py"
17+
- PYTHON_VERSION=2.7 FLOATX='float64' TESTCMD="--durations=10 --ignore=pymc3/tests/test_examples.py --cov-append --ignore=pymc3/tests/test_distributions_random.py --ignore=pymc3/tests/test_variational_inference.py --ignore=pymc3/tests/test_shared.py --ignore=pymc3/tests/test_smc.py --ignore=pymc3/tests/test_updates.py --ignore=pymc3/tests/test_posteriors.py --ignore=pymc3/tests/test_sampling.py"
18+
- PYTHON_VERSION=2.7 FLOATX='float64' RUN_PYLINT="true" TESTCMD="--durations=10 --cov-append pymc3/tests/test_distributions_random.py pymc3/tests/test_shared.py pymc3/tests/test_smc.py pymc3/tests/test_sampling.py"
1919
- PYTHON_VERSION=2.7 FLOATX='float64' TESTCMD="--durations=10 --cov-append pymc3/tests/test_examples.py pymc3/tests/test_variational_inference.py pymc3/tests/test_updates.py pymc3/tests/test_posteriors.py"
20-
- PYTHON_VERSION=3.6 FLOATX='float64' TESTCMD="--durations=10 --cov-append --ignore=pymc3/tests/test_examples.py --ignore=pymc3/tests/test_distributions_random.py --ignore=pymc3/tests/test_variational_inference.py --ignore=pymc3/tests/test_shared.py --ignore=pymc3/tests/test_smc.py --ignore=pymc3/tests/test_updates.py --ignore=pymc3/tests/test_posteriors.py"
21-
- PYTHON_VERSION=3.6 FLOATX='float64' TESTCMD="--durations=10 --cov-append pymc3/tests/test_distributions_random.py pymc3/tests/test_shared.py pymc3/tests/test_smc.py"
20+
- PYTHON_VERSION=3.6 FLOATX='float64' TESTCMD="--durations=10 --cov-append --ignore=pymc3/tests/test_examples.py --ignore=pymc3/tests/test_distributions_random.py --ignore=pymc3/tests/test_variational_inference.py --ignore=pymc3/tests/test_shared.py --ignore=pymc3/tests/test_smc.py --ignore=pymc3/tests/test_updates.py --ignore=pymc3/tests/test_posteriors.py --ignore=pymc3/tests/test_sampling.py"
21+
- PYTHON_VERSION=3.6 FLOATX='float64' TESTCMD="--durations=10 --cov-append pymc3/tests/test_distributions_random.py pymc3/tests/test_shared.py pymc3/tests/test_smc.py pymc3/tests/test_sampling.py"
2222
- PYTHON_VERSION=3.6 FLOATX='float64' TESTCMD="--durations=10 --cov-append pymc3/tests/test_examples.py pymc3/tests/test_variational_inference.py pymc3/tests/test_updates.py pymc3/tests/test_posteriors.py"
2323
script:
2424
- . ./scripts/test.sh $TESTCMD

pymc3/distributions/transforms.py

Lines changed: 50 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import theano.tensor as tt
22

33
from ..model import FreeRV
4-
from ..theanof import gradient
4+
from ..theanof import gradient, floatX
55
from . import distribution
66
from ..math import logit, invlogit
7+
from .distribution import draw_values
78
import numpy as np
89

910
__all__ = ['transform', 'stick_breaking', 'logodds', 'interval',
@@ -22,6 +23,9 @@ class Transform(object):
2223
def forward(self, x):
2324
raise NotImplementedError
2425

26+
def forward_val(self, x, point):
27+
raise NotImplementedError
28+
2529
def backward(self, z):
2630
raise NotImplementedError
2731

@@ -55,6 +59,7 @@ def __init__(self, dist, transform, *args, **kwargs):
5559
arguments to Distribution"""
5660
forward = transform.forward
5761
testval = forward(dist.default())
62+
forward_val = transform.forward_val
5863

5964
self.dist = dist
6065
self.transform_used = transform
@@ -85,6 +90,9 @@ def backward(self, x):
8590

8691
def forward(self, x):
8792
return tt.log(x)
93+
94+
def forward_val(self, x, point=None):
95+
return self.forward(x)
8896

8997
def jacobian_det(self, x):
9098
return x
@@ -103,6 +111,9 @@ def backward(self, x):
103111

104112
def forward(self, x):
105113
return logit(x)
114+
115+
def forward_val(self, x, point=None):
116+
return self.forward(x)
106117

107118
logodds = LogOdds()
108119

@@ -125,6 +136,14 @@ def forward(self, x):
125136
a, b = self.a, self.b
126137
return tt.log(x - a) - tt.log(b - x)
127138

139+
def forward_val(self, x, point=None):
140+
# 2017-06-19
141+
# the `self.a-0.` below is important for the testval to propagates
142+
# For an explanation see pull/2328#issuecomment-309303811
143+
a, b = draw_values([self.a-0., self.b-0.],
144+
point=point)
145+
return floatX(tt.log(x - a) - tt.log(b - x))
146+
128147
def jacobian_det(self, x):
129148
s = tt.nnet.softplus(-x)
130149
return tt.log(self.b - self.a) - 2 * s - x
@@ -147,8 +166,15 @@ def backward(self, x):
147166

148167
def forward(self, x):
149168
a = self.a
150-
r = tt.log(x - a)
151-
return r
169+
return tt.log(x - a)
170+
171+
def forward_val(self, x, point=None):
172+
# 2017-06-19
173+
# the `self.a-0.` below is important for the testval to propagates
174+
# For an explanation see pull/2328#issuecomment-309303811
175+
a = draw_values([self.a-0.],
176+
point=point)[0]
177+
return floatX(tt.log(x - a))
152178

153179
def jacobian_det(self, x):
154180
return x
@@ -171,8 +197,15 @@ def backward(self, x):
171197

172198
def forward(self, x):
173199
b = self.b
174-
r = tt.log(b - x)
175-
return r
200+
return tt.log(b - x)
201+
202+
def forward_val(self, x, point=None):
203+
# 2017-06-19
204+
# the `self.b-0.` below is important for the testval to propagates
205+
# For an explanation see pull/2328#issuecomment-309303811
206+
b = draw_values([self.b-0.],
207+
point=point)[0]
208+
return floatX(tt.log(b - x))
176209

177210
def jacobian_det(self, x):
178211
return x
@@ -191,6 +224,9 @@ def backward(self, y):
191224
def forward(self, x):
192225
return x[:-1]
193226

227+
def forward_val(self, x, point=None):
228+
return self.forward(x)
229+
194230
def jacobian_det(self, x):
195231
return 0
196232

@@ -224,6 +260,9 @@ def forward(self, x_):
224260
y = logit(z) - eq_share
225261
return y.T
226262

263+
def forward_val(self, x, point=None):
264+
return self.forward(x)
265+
227266
def backward(self, y_):
228267
y = y_.T
229268
Km1 = y.shape[0]
@@ -262,6 +301,9 @@ def backward(self, y):
262301
def forward(self, x):
263302
return tt.as_tensor_variable(x)
264303

304+
def forward_val(self, x, point=None):
305+
return self.forward(x)
306+
265307
def jacobian_det(self, x):
266308
return 0
267309

@@ -280,5 +322,8 @@ def backward(self, x):
280322
def forward(self, y):
281323
return tt.advanced_set_subtensor1(y, tt.log(y[self.diag_idxs]), self.diag_idxs)
282324

325+
def forward_val(self, x, point=None):
326+
return self.forward(x)
327+
283328
def jacobian_det(self, y):
284329
return tt.sum(y[self.diag_idxs])

pymc3/sampling.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -478,12 +478,14 @@ def _update_start_vals(a, b, model):
478478
"""Update a with b, without overwriting existing keys. Values specified for
479479
transformed variables on the original scale are also transformed and inserted.
480480
"""
481-
for name in a:
482-
for tname in b:
483-
if is_transformed_name(tname) and get_untransformed_name(tname) == name:
484-
transform_func = [d.transformation for d in model.deterministics if d.name == name]
485-
if transform_func:
486-
b[tname] = transform_func[0].forward(a[name]).eval()
481+
if model is not None:
482+
for free_RV in model.free_RVs:
483+
tname = free_RV.name
484+
for name in a:
485+
if is_transformed_name(tname) and get_untransformed_name(tname) == name:
486+
transform_func = [d.transformation for d in model.deterministics if d.name == name]
487+
if transform_func:
488+
b[tname] = transform_func[0].forward_val(a[name], point=b).eval()
487489

488490
a.update({k: v for k, v in b.items() if k not in a})
489491

pymc3/tests/test_sampling.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,26 @@ def test_soft_update_transformed(self):
150150
pm.sampling._update_start_vals(start, test_point, model)
151151
assert_almost_equal(np.exp(start['a_log__']), start['a'])
152152

153+
def test_soft_update_parent(self):
154+
with pm.Model() as model:
155+
a = pm.Uniform('a', lower=0., upper=1.)
156+
b = pm.Uniform('b', lower=2., upper=3.)
157+
pm.Uniform('lower', lower=a, upper=3.)
158+
pm.Uniform('upper', lower=0., upper=b)
159+
pm.Uniform('interv', lower=a, upper=b)
160+
161+
start = {'a': .3, 'b': 2.1, 'lower': 1.4, 'upper': 1.4, 'interv':1.4}
162+
test_point = {'lower_interval__': -0.3746934494414109,
163+
'upper_interval__': 0.693147180559945,
164+
'interv_interval__': 0.4519851237430569}
165+
pm.sampling._update_start_vals(start, model.test_point, model)
166+
assert_almost_equal(start['lower_interval__'],
167+
test_point['lower_interval__'])
168+
assert_almost_equal(start['upper_interval__'],
169+
test_point['upper_interval__'])
170+
assert_almost_equal(start['interv_interval__'],
171+
test_point['interv_interval__'])
172+
153173

154174
@pytest.mark.xfail(condition=(theano.config.floatX == "float32"), reason="Fails on float32")
155175
class TestNamedSampling(SeededTest):

0 commit comments

Comments
 (0)