Skip to content

Commit 34a8bd4

Browse files
committed
Fix stickbreaking transform dtype
1 parent 4f54823 commit 34a8bd4

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

pymc3/distributions/transforms.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -209,15 +209,15 @@ def forward(self, x_):
209209
z = x0 / s
210210
Km1 = x.shape[0] - 1
211211
k = tt.arange(Km1)[(slice(None), ) + (None, ) * (x.ndim - 1)]
212-
eq_share = logit(1. / (Km1 + 1 - k)) # - tt.log(Km1 - k)
212+
eq_share = logit(1. / (Km1 + 1 - k).astype(str(x_.dtype)))
213213
y = logit(z) - eq_share
214214
return y.T
215215

216216
def backward(self, y_):
217217
y = y_.T
218218
Km1 = y.shape[0]
219219
k = tt.arange(Km1)[(slice(None), ) + (None, ) * (y.ndim - 1)]
220-
eq_share = logit(1. / (Km1 + 1 - k)) # - tt.log(Km1 - k)
220+
eq_share = logit(1. / (Km1 + 1 - k).astype(str(y_.dtype)))
221221
z = invlogit(y + eq_share, self.eps)
222222
yl = tt.concatenate([z, tt.ones(y[:1].shape)])
223223
yu = tt.concatenate([tt.ones(y[:1].shape), 1 - z])
@@ -229,7 +229,7 @@ def jacobian_det(self, y_):
229229
y = y_.T
230230
Km1 = y.shape[0]
231231
k = tt.arange(Km1)[(slice(None), ) + (None, ) * (y.ndim - 1)]
232-
eq_share = logit(1. / (Km1 + 1 - k)) # -tt.log(Km1 - k)
232+
eq_share = logit(1. / (Km1 + 1 - k).astype(str(y_.dtype)))
233233
yl = y + eq_share
234234
yu = tt.concatenate([tt.ones(y[:1].shape), 1 - invlogit(yl, self.eps)])
235235
S = tt.extra_ops.cumprod(yu, 0)

0 commit comments

Comments
 (0)