Skip to content

Commit ba7923b

Browse files
authored
Fix dirichlet shape (#1241)
* BUG Reverse shape order of Dirichlet, thanks to @taku-y.
* DOC Various fixes to example notebooks. Mainly rename ElemwiseCategoricalStep to ElemwiseCategorical.
* MAINT Remove MyDirichlet.
* Fix inconsistency in variable shapes.
* Fix test function for stickbreaking transform.
1 parent ed06f89 commit ba7923b

File tree

6 files changed

+223
-240
lines changed

6 files changed

+223
-240
lines changed

docs/source/notebooks/dp_mix.ipynb

Lines changed: 72 additions & 80 deletions
Large diffs are not rendered by default.

docs/source/notebooks/gaussian-mixture-model-advi.ipynb

Lines changed: 73 additions & 81 deletions
Large diffs are not rendered by default.

docs/source/notebooks/gaussian_mixture_model.ipynb

Lines changed: 40 additions & 43 deletions
Large diffs are not rendered by default.

pymc3/distributions/multivariate.py

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,9 @@ class MvStudentT(Continuous):
8080
Multivariate Student T log-likelihood.
8181
8282
.. math::
83-
83+
8484
f(\mathbf{x}| \nu,\mu,\Sigma) = \frac{\Gamma\left[(\nu+p)/2\right]}{\Gamma(\nu/2)\nu^{p/2}\pi^{p/2}\left|{\Sigma}\right|^{1/2}\left[1+\frac{1}{\nu}({\mathbf x}-{\mu})^T{\Sigma}^{-1}({\mathbf x}-{\mu})\right]^{(\nu+p)/2}}
85-
85+
8686
======== ==========================
8787
Support :math:`x \in \mathbb{R}^k`
8888
Mean :math:`\mu` if :math:`\nu > 1` else undefined
@@ -103,37 +103,35 @@ def __init__(self, nu, Sigma, mu=None, *args, **kwargs):
103103
self.nu = nu
104104
self.mu = np.zeros(Sigma.shape[0]) if mu is None else mu
105105
self.Sigma = Sigma
106-
106+
107107
self.mean = self.median = self.mode = self.mu = mu
108-
108+
109109
def random(self, point=None, size=None):
110110
chi2 = np.random.chisquare
111111
mvn = np.random.multivariate_normal
112-
112+
113113
nu, S, mu = draw_values([self.nu, self.Sigma, self.mu], point=point)
114-
114+
115115
return (np.sqrt(nu) * (mvn(np.zeros(len(S)), S, size).T
116116
/ chi2(nu, size))).T + mu
117-
118-
def logp(self, value):
119-
117+
118+
def logp(self, value):
119+
120120
S = self.Sigma
121121
nu = self.nu
122122
mu = self.mu
123123

124124
d = S.shape[0]
125125
n = value.shape[0]
126-
126+
127127
X = value - mu
128-
128+
129129
Q = X.dot(matrix_inverse(S)).dot(X.T).sum()
130130
log_det = tt.log(det(S))
131131
log_pdf = gammaln((nu + d)/2.) - 0.5 * (d*tt.log(np.pi*nu) + log_det) - gammaln(nu/2.)
132132
log_pdf -= 0.5*(nu + d)*tt.log(1 + Q/nu)
133-
133+
134134
return log_pdf
135-
136-
137135

138136

139137
class Dirichlet(Continuous):
@@ -167,7 +165,7 @@ class Dirichlet(Continuous):
167165
"""
168166
def __init__(self, a, transform=transforms.stick_breaking,
169167
*args, **kwargs):
170-
shape = a.shape[0]
168+
shape = a.shape[-1]
171169
kwargs.setdefault("shape", shape)
172170
super(Dirichlet, self).__init__(transform=transform, *args, **kwargs)
173171

@@ -195,8 +193,8 @@ def logp(self, value):
195193
a = self.a
196194

197195
# only defined for sum(value) == 1
198-
return bound(tt.sum(logpow(value, a - 1) - gammaln(a), axis=0)
199-
+ gammaln(tt.sum(a, axis=0)),
196+
return bound(tt.sum(logpow(value, a - 1) - gammaln(a), axis=-1)
197+
+ gammaln(tt.sum(a, axis=-1)),
200198
tt.all(value >= 0), tt.all(value <= 1),
201199
k > 1, tt.all(a > 0))
202200

pymc3/distributions/transforms.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,11 @@ def __init__(self, dist, transform, *args, **kwargs):
6262
v.shape.tag.test_value, v.dtype,
6363
testval, dist.defaults, *args, **kwargs)
6464

65+
if transform.name == 'stickbreaking':
66+
b = np.hstack(((np.atleast_1d(self.shape) == 1)[:-1], False))
67+
self.type = tt.TensorType(v.dtype, b) # force the last dim not broadcastable
68+
69+
6570
def logp(self, x):
6671
return (self.dist.logp(self.transform_used.backward(x)) +
6772
self.transform_used.jacobian_det(x))
@@ -177,46 +182,46 @@ def jacobian_det(self, x):
177182

178183
sum_to_1 = SumTo1()
179184

180-
181185
class StickBreaking(Transform):
182186
"""Transforms K dimensional simplex space (values in [0,1] and sum to 1) to K - 1 vector of real values.
183-
184187
Primarily borrowed from the STAN implementation.
185188
"""
186189

187190
name = "stickbreaking"
188191

189-
def forward(self, x):
192+
def forward(self, x_):
193+
x = x_.T
190194
# reverse cumsum
191195
x0 = x[:-1]
192196
s = tt.extra_ops.cumsum(x0[::-1], 0)[::-1] + x[-1]
193197
z = x0/s
194198
Km1 = x.shape[0] - 1
195199
k = tt.arange(Km1)[(slice(None), ) + (None, ) * (x.ndim - 1)]
196-
eq_share = - tt.log(Km1 - k) # logit(1./(Km1 + 1 - k))
200+
eq_share = logit(1./(Km1 + 1 - k)) # - tt.log(Km1 - k)
197201
y = logit(z) - eq_share
198-
return y
202+
return y.T
199203

200-
def backward(self, y):
204+
def backward(self, y_):
205+
y = y_.T
201206
Km1 = y.shape[0]
202207
k = tt.arange(Km1)[(slice(None), ) + (None, ) * (y.ndim - 1)]
203-
eq_share = - tt.log(Km1 - k) # logit(1./(Km1 + 1 - k))
208+
eq_share = logit(1./(Km1 + 1 - k)) #- tt.log(Km1 - k)
204209
z = inverse_logit(y + eq_share)
205210
yl = tt.concatenate([z, tt.ones(y[:1].shape)])
206211
yu = tt.concatenate([tt.ones(y[:1].shape), 1-z])
207212
S = tt.extra_ops.cumprod(yu, 0)
208213
x = S * yl
209-
return x
214+
return x.T
210215

211-
def jacobian_det(self, y):
216+
def jacobian_det(self, y_):
217+
y = y_.T
212218
Km1 = y.shape[0]
213219
k = tt.arange(Km1)[(slice(None), ) + (None, ) * (y.ndim - 1)]
214-
eq_share = -tt.log(Km1 - k) # logit(1./(Km1 + 1 - k))
220+
eq_share = logit(1./(Km1 + 1 - k)) # -tt.log(Km1 - k)
215221
yl = y + eq_share
216222
yu = tt.concatenate([tt.ones(y[:1].shape), 1-inverse_logit(yl)])
217223
S = tt.extra_ops.cumprod(yu, 0)
218-
return tt.sum(tt.log(S[:-1]) - tt.log1p(tt.exp(yl)) - tt.log1p(tt.exp(-yl)),
219-
0)
224+
return tt.sum(tt.log(S[:-1]) - tt.log1p(tt.exp(yl)) - tt.log1p(tt.exp(-yl)), 0).T
220225

221226
stick_breaking = StickBreaking()
222227

@@ -229,7 +234,7 @@ class Circular(ElemwiseTransform):
229234
def backward(self, y):
230235
return tt.arctan2(tt.sin(y), tt.cos(y))
231236

232-
def forward(self, x):
237+
def forward(self, x):
233238
return x
234239

235240
circular = Circular()

pymc3/tests/test_distributions.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -104,10 +104,9 @@ def __init__(self, n):
104104

105105
class MultiSimplex(object):
106106
def __init__(self, n_dependent, n_independent):
107-
transposed_vals = list(itertools.product(list(simplex_values(n_dependent)), repeat=n_independent))
108-
self.vals = list(np.transpose(transposed_vals, (0, 2, 1)))
107+
self.vals = [np.vstack(v) for v in list(itertools.product(list(simplex_values(n_dependent)), repeat=n_independent))]
109108

110-
self.shape = (n_dependent, n_independent)
109+
self.shape = (n_independent, n_dependent)
111110
self.dtype = Unit.dtype
112111
return
113112

@@ -450,13 +449,13 @@ def check_lkj(x, n, p, lp):
450449
lp, decimal=6, err_msg=str(pt))
451450

452451
def betafn(a):
453-
return scipy.special.gammaln(a).sum(0) - scipy.special.gammaln(a.sum(0))
452+
return scipy.special.gammaln(a).sum(-1) - scipy.special.gammaln(a.sum(-1))
454453

455454
def logpow(v, p):
456455
return np.choose(v==0, [p * np.log(v), 0])
457456

458457
def dirichlet_logpdf(value, a):
459-
return (-betafn(a) + logpow(value, a-1).sum(0)).sum()
458+
return (-betafn(a) + logpow(value, a-1).sum(-1)).sum()
460459

461460
def test_dirichlet():
462461
for n in [2,3]:
@@ -471,7 +470,7 @@ def check_dirichlet(n):
471470

472471
def check_dirichlet2D(ndep, nind):
473472
pymc3_matches_scipy(
474-
Dirichlet, MultiSimplex(ndep, nind), {'a': Vector(Vector(Rplus, nind), ndep) },
473+
Dirichlet, MultiSimplex(ndep, nind), {'a': Vector(Vector(Rplus, ndep), nind) },
475474
dirichlet_logpdf
476475
)
477476

0 commit comments

Comments
 (0)