Skip to content

Commit 12694dd

Browse files
ColCarrollJunpeng Lao
authored andcommitted
Allow prior sampling from DensityDist (#3045)
* Allow prior sampling from DensityDist * Missed a file * Fix randomly failing test * More tolerance
1 parent 7077455 commit 12694dd

File tree

4 files changed

+17
-7
lines changed

4 files changed

+17
-7
lines changed

pymc3/distributions/distribution.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from theano import function
77
import theano
88
from ..memoize import memoize
9-
from ..model import Model, get_named_nodes_and_relations, FreeRV, ObservedRV
9+
from ..model import Model, get_named_nodes_and_relations, FreeRV, ObservedRV, MultiObservedRV
1010
from ..vartypes import string_types
1111

1212
__all__ = ['DensityDist', 'Distribution', 'Continuous', 'Discrete',
@@ -375,7 +375,7 @@ def _draw_value(param, point=None, givens=None, size=None):
375375
return param.value
376376
elif isinstance(param, tt.sharedvar.SharedVariable):
377377
return param.get_value()
378-
elif isinstance(param, tt.TensorVariable):
378+
elif isinstance(param, (tt.TensorVariable, MultiObservedRV)):
379379
if point and hasattr(param, 'model') and param.name in point:
380380
return point[param.name]
381381
elif hasattr(param, 'random') and param.random is not None:
@@ -404,8 +404,7 @@ def _draw_value(param, point=None, givens=None, size=None):
404404
return np.array([func(*v) for v in zip(*values)])
405405
else:
406406
return func(*values)
407-
else:
408-
raise ValueError('Unexpected type in draw_value: %s' % type(param))
407+
raise ValueError('Unexpected type in draw_value: %s' % type(param))
409408

410409

411410
def to_tuple(shape):

pymc3/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def get_named_nodes_and_relations(graph):
112112

113113
def _get_named_nodes_and_relations(graph, parent, leaf_nodes,
114114
node_parents, node_children):
115-
if graph.owner is None: # Leaf node
115+
if getattr(graph, 'owner', None) is None: # Leaf node
116116
if graph.name is not None: # Named leaf node
117117
leaf_nodes.update({graph.name: graph})
118118
if parent is not None: # Is None for the root node

pymc3/tests/test_sampling.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -339,7 +339,7 @@ def test_exec_nuts_init(method):
339339
assert isinstance(start[0], dict)
340340
assert 'a' in start[0] and 'b_log__' in start[0]
341341

342-
class TestSampleGenerative(SeededTest):
342+
class TestSamplePriorPredictive(SeededTest):
343343
def test_ignores_observed(self):
344344
observed = np.random.normal(10, 1, size=200)
345345
with pm.Model():
@@ -421,3 +421,14 @@ def test_shared(self):
421421
gen2 = pm.sample_prior_predictive(draws)
422422

423423
assert gen2['y'].shape == (draws, n2)
424+
425+
def test_density_dist(self):
426+
427+
obs = np.random.normal(-1, 0.1, size=10)
428+
with pm.Model():
429+
mu = pm.Normal('mu', 0, 1)
430+
sd = pm.Gamma('sd', 1, 2)
431+
a = pm.DensityDist('a', pm.Normal.dist(mu, sd).logp, random=pm.Normal.dist(mu, sd).random, observed=obs)
432+
prior = pm.sample_prior_predictive()
433+
434+
npt.assert_almost_equal(prior['a'].mean(), 0, decimal=1)

pymc3/tests/test_smc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def test_sample_n_core(self, n_jobs, stage):
5959

6060
x = mtrace.get_values('X')
6161
mu1d = np.abs(x).mean(axis=0)
62-
np.testing.assert_allclose(self.muref, mu1d, rtol=0., atol=0.03)
62+
np.testing.assert_allclose(self.muref, mu1d, rtol=0., atol=0.06)
6363
# Scenario IV Ching, J. & Chen, Y. 2007
6464
#assert np.round(np.log(self.ATMIP_test.marginal_likelihood)) == -12.0
6565

0 commit comments

Comments
 (0)