Skip to content

Commit 85e0f9d

Browse files
committed
REF Make all_continuous() a general purpose function.
1 parent 4d53a36 commit 85e0f9d

File tree

2 files changed

+12
-1
lines changed

2 files changed

+12
-1
lines changed

pymc3/model.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -703,6 +703,16 @@ def as_iterargs(data):
703703
else:
704704
return [data]
705705

706+
707+
def all_continuous(vars):
708+
"""Check that vars not include discrete variables, excepting ObservedRVs.
709+
"""
710+
vars_ = [var for var in vars if not isinstance(var, pm.model.ObservedRV)]
711+
if any([var.dtype in pm.discrete_types for var in vars_]):
712+
return False
713+
else:
714+
return True
715+
706716
# theano stuff
707717
theano.config.warn.sum_div_dimshuffle_bug = False
708718
theano.config.compute_test_value = 'raise'

pymc3/variational/advi.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,8 @@ def advi(vars=None, start=None, model=None, n=5000, accurate_elbo=False,
110110
vars = model.vars
111111
vars = pm.inputvars(vars)
112112

113-
check_discrete_rvs(vars)
113+
if not pm.model.all_continuous(vars):
114+
raise ValueError('Model should not include discrete RVs for ADVI.')
114115

115116
n_mcsamples = 100 if accurate_elbo else 1
116117

0 commit comments

Comments
 (0)