Skip to content

Commit 3ca8af6

Browse files
authored
Merge pull request #1229 from taku-y/sample_vp
Hide transformed vars in sample_vp().
2 parents ce550a6 + 034caff commit 3ca8af6

File tree

2 files changed

+30
-2
lines changed

2 files changed

+30
-2
lines changed

pymc3/tests/test_advi.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import numpy as np
2+
import pymc3 as pm
23
from pymc3 import Model, Normal, DiscreteUniform, Poisson, switch, Exponential
34
from pymc3.theanof import inputvars
45
from pymc3.variational.advi import variational_gradient_estimate, advi, advi_minibatch, sample_vp
@@ -202,3 +203,22 @@ def create_minibatch(data):
202203

203204
np.testing.assert_allclose(np.mean(trace['mu']), mu_post, rtol=0.4)
204205
np.testing.assert_allclose(np.std(trace['mu']), np.sqrt(1. / d), rtol=0.4)
206+
207+
def test_sample_vp():
208+
n_samples = 100
209+
210+
rng = np.random.RandomState(0)
211+
xs = rng.binomial(n=1, p=0.2, size=n_samples)
212+
213+
with pm.Model() as model:
214+
p = pm.Beta('p', alpha=1, beta=1)
215+
pm.Binomial('xs', n=1, p=p, observed=xs)
216+
v_params = advi(n=1000)
217+
218+
with model:
219+
trace = sample_vp(v_params, hide_transformed=True)
220+
assert(set(trace.varnames) == set('p'))
221+
222+
with model:
223+
trace = sample_vp(v_params, hide_transformed=False)
224+
assert(set(trace.varnames) == set(('p', 'p_logodds_')))

pymc3/variational/advi.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,8 @@ def adagrad(grad, param, learning_rate, epsilon, n):
325325
tt.sqrt(accu_sum + epsilon))
326326
return updates
327327

328-
def sample_vp(vparams, draws=1000, model=None, random_seed=20090425):
328+
def sample_vp(vparams, draws=1000, model=None, random_seed=20090425,
329+
hide_transformed=True):
329330
"""Draw samples from variational posterior.
330331
331332
Parameters
@@ -338,6 +339,8 @@ def sample_vp(vparams, draws=1000, model=None, random_seed=20090425):
338339
Probabilistic model.
339340
random_seed : int
340341
Seed of random number generator.
342+
hide_transformed : bool
343+
If False, transformed variables are also sampled. Default is True.
341344
342345
Returns
343346
-------
@@ -366,8 +369,13 @@ def sample_vp(vparams, draws=1000, model=None, random_seed=20090425):
366369
samples = theano.clone(vars, updates)
367370
f = theano.function([], samples)
368371

372+
# Random variables which will be sampled
373+
vars_sampled = [v for v in model.unobserved_RVs if not str(v).endswith('_')] \
374+
if hide_transformed else \
375+
[v for v in model.unobserved_RVs]
376+
369377
varnames = [str(var) for var in model.unobserved_RVs]
370-
trace = NDArray(model=model, vars=model.unobserved_RVs)
378+
trace = NDArray(model=model, vars=vars_sampled)
371379
trace.setup(draws=draws, chain=0)
372380

373381
for i in range(draws):

0 commit comments

Comments
 (0)