Skip to content

Commit 883f4ef

Browse files
ferrineColCarroll
authored andcommitted
raise an error when model has discrete vars (#2917)
* raise an error when model has discrete vars * unicode char
1 parent e97f5ab commit 883f4ef

File tree

2 files changed

+17
-1
lines changed

2 files changed

+17
-1
lines changed

pymc3/tests/test_variational_inference.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
)
2222
from pymc3.variational import flows
2323
from pymc3.variational.opvi import Approximation, Group
24-
24+
from pymc3.variational import opvi
2525
from . import models
2626
from .helpers import not_raises
2727

@@ -835,6 +835,19 @@ def test_sample_replacements(binomial_model_inference):
835835
assert sampled.shape[0] == 101
836836

837837

838+
def test_discrete_not_allowed():
839+
mu_true = np.array([-2, 0, 2])
840+
z_true = np.random.randint(len(mu_true), size=100)
841+
y = np.random.normal(mu_true[z_true], np.ones_like(z_true))
842+
843+
with pm.Model():
844+
mu = pm.Normal('mu', mu=0, sd=10, shape=3)
845+
z = pm.Categorical('z', p=tt.ones(3) / 3, shape=len(y))
846+
pm.Normal('y_obs', mu=mu[z], sd=1., observed=y)
847+
with pytest.raises(opvi.ParametrizationError):
848+
pm.fit(n=1) # fails
849+
850+
838851
def test_var_replacement():
839852
X_mean = pm.floatX(np.linspace(0, 10, 10))
840853
y = pm.floatX(np.random.normal(X_mean*4, .05))

pymc3/variational/opvi.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -894,6 +894,9 @@ def __init_group__(self, group):
894894
self.replacements = dict()
895895
self.group = [get_transformed(var) for var in self.group]
896896
for var in self.group:
897+
if isinstance(var.distribution, pm.Discrete):
898+
raise ParametrizationError('Discrete variables are not supported by VI: {}'
899+
.format(var))
897900
begin = self.ddim
898901
if self.batched:
899902
if var.ndim < 1:

0 commit comments

Comments
 (0)