Skip to content

Commit 7327f9d

Browse files
Junpeng Laotwiecki
authored andcommitted
Fix Binomial GLM
In glm module binomial familiy is actually Bernoulli. This PR fix the Binomial likelihood, with the flexibility of specifying the `n`. Default n set to 1 for backward compatibility. Also added test.
1 parent f6cb5ff commit 7327f9d

File tree

3 files changed

+20
-3
lines changed

3 files changed

+20
-3
lines changed

RELEASE-NOTES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
- Plots of discrete distributions in the docstrings
1717
- Add logitnormal distribution
1818
- Densityplot: add support for discrete variables
19+
- Fix the Binomial likelihood in `.glm.families.Binomial`, with the flexibility of specifying the `n`.
1920

2021
### Fixes
2122

pymc3/glm/families.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import numbers
2+
import numpy as np
23
from copy import copy
34

45
import theano.tensor as tt
@@ -51,7 +52,7 @@ def _get_priors(self, model=None, name=''):
5152
model = modelcontext(model)
5253
priors = {}
5354
for key, val in self.priors.items():
54-
if isinstance(val, numbers.Number):
55+
if isinstance(val, (numbers.Number, np.ndarray, np.generic)):
5556
priors[key] = val
5657
else:
5758
priors[key] = model.Var('{}{}'.format(name, key), val)
@@ -99,8 +100,9 @@ class Normal(Family):
99100

100101
class Binomial(Family):
101102
link = logit
102-
likelihood = pm_dists.Bernoulli
103+
likelihood = pm_dists.Binomial
103104
parent = 'p'
105+
priors = {'n': 1}
104106

105107

106108
class Poisson(Family):

pymc3/tests/test_glm.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,13 @@ def setup_class(cls):
2626

2727
x_logistic, y_logistic = generate_data(cls.intercept, cls.slope, size=3000)
2828
y_logistic = 1 / (1 + np.exp(-y_logistic))
29-
bern_trials = [np.random.binomial(1, i) for i in y_logistic]
29+
bern_trials = np.random.binomial(1, y_logistic)
3030
cls.data_logistic = dict(x=x_logistic, y=bern_trials)
3131

32+
n_trials = np.random.randint(1, 20, size=y_logistic.shape)
33+
binom_trials = np.random.binomial(n_trials, y_logistic)
34+
cls.data_logistic2 = dict(x=x_logistic, y=binom_trials, n=n_trials)
35+
3236
def test_linear_component(self):
3337
with Model() as model:
3438
lm = LinearComponent.from_formula('y ~ x', self.data_linear)
@@ -65,6 +69,16 @@ def test_glm_link_func(self):
6569
assert round(abs(np.mean(trace['Intercept'])-self.intercept), 1) == 0
6670
assert round(abs(np.mean(trace['x'])-self.slope), 1) == 0
6771

72+
def test_glm_link_func2(self):
73+
with Model() as model:
74+
GLM.from_formula('y ~ x', self.data_logistic2,
75+
family=families.Binomial(priors={'n': self.data_logistic2['n']}))
76+
trace = sample(1000, progressbar=False,
77+
random_seed=self.random_seed)
78+
79+
assert round(abs(np.mean(trace['Intercept'])-self.intercept), 1) == 0
80+
assert round(abs(np.mean(trace['x'])-self.slope), 1) == 0
81+
6882
def test_more_than_one_glm_is_ok(self):
6983
with Model():
7084
GLM.from_formula('y ~ x', self.data_logistic,

0 commit comments

Comments
 (0)