Commit 680a8eb

MAINT Samplers should keep the same dtype as provided by theano. (#1253)
* MAINT Allow downcast in metropolis and nuts functions.
* ENH Samplers should keep the same dtype as provided by theano.
* BUG np.array -> array.
1 parent f82971c commit 680a8eb

File tree

4 files changed: +29 −27 lines changed
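The common thread across all four files: wherever a float dtype was hard-coded as 'float64', the code now defers to theano.config.floatX, and the compiled sampler functions accept float64 NumPy inputs by downcasting. For context, a minimal sketch of how floatX is typically configured (standard Theano usage, not part of this commit):

import os

# floatX must be set before theano is first imported, e.g. for
# float32 sampling on a GPU:
os.environ['THEANO_FLAGS'] = 'floatX=float32'

import theano
print(theano.config.floatX)  # 'float32'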

pymc3/distributions/distribution.py

Lines changed: 8 additions & 7 deletions

@@ -1,5 +1,6 @@
 import numpy as np
 import theano.tensor as tt
+import theano
 from theano import function
 
 from ..memoize import memoize
@@ -78,20 +79,20 @@ def TensorType(dtype, shape):
     return tt.TensorType(str(dtype), np.atleast_1d(shape) == 1)
 
 
 class NoDistribution(Distribution):
 
     def __init__(self, shape, dtype, testval=None, defaults=[], transform=None, parent_dist=None, *args, **kwargs):
         super(NoDistribution, self).__init__(shape=shape, dtype=dtype,
                                              testval=testval, defaults=defaults,
                                              *args, **kwargs)
         self.parent_dist = parent_dist
 
 
     def __getattr__(self, name):
         try:
             self.__dict__[name]
         except KeyError:
             return getattr(self.parent_dist, name)
 
     def logp(self, x):
         return 0
 
@@ -102,12 +103,12 @@ def __init__(self, shape=(), dtype='int64', defaults=['mode'], *args, **kwargs):
 
 class Continuous(Distribution):
     """Base class for continuous distributions"""
-    def __init__(self, shape=(), dtype='float64', defaults=['median', 'mean', 'mode'], *args, **kwargs):
+    def __init__(self, shape=(), dtype=theano.config.floatX, defaults=['median', 'mean', 'mode'], *args, **kwargs):
         super(Continuous, self).__init__(shape, dtype, defaults=defaults, *args, **kwargs)
 
 class DensityDist(Distribution):
     """Distribution based on a given log density function."""
-    def __init__(self, logp, shape=(), dtype='float64', testval=0, *args, **kwargs):
+    def __init__(self, logp, shape=(), dtype=theano.config.floatX, testval=0, *args, **kwargs):
         super(DensityDist, self).__init__(shape, dtype, testval, *args, **kwargs)
         self.logp = logp
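With this change, Continuous and DensityDist pick up whatever float dtype Theano is configured with instead of forcing 'float64'. A sketch of the resulting behavior (assuming a pymc3 build from around this commit; the model and variable names are illustrative):

import theano
import pymc3 as pm

with pm.Model():
    x = pm.Normal('x', mu=0., sd=1.)

# The free variable's dtype now follows theano.config.floatX
# rather than a hard-coded 'float64'.
print(x.dtype == theano.config.floatX)  # True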

pymc3/step_methods/metropolis.py

Lines changed: 6 additions & 5 deletions

@@ -121,9 +121,10 @@ def astep(self, q0):
                 q = (q0 + delta).astype(int)
             else:
                 delta[self.discrete] = round(delta[self.discrete], 0).astype(int)
-                q = q0 + delta
+                q = (q0 + delta)
         else:
-            q = q0 + delta
+            q0 = q0.astype(theano.config.floatX)
+            q = (q0 + delta).astype(theano.config.floatX)
 
         q_new = metrop_select(self.delta_logp(q, q0), q, q0)
 
@@ -182,7 +183,7 @@ def tune(scale, acc_rate):
 
 class BinaryMetropolis(ArrayStep):
     """Metropolis-Hastings optimized for binary variables
 
     Parameters
     ----------
     vars : list
@@ -195,7 +196,7 @@ class BinaryMetropolis(ArrayStep):
         The frequency of tuning. Defaults to 100 iterations.
     model : PyMC Model
         Optional model for sampling step. Defaults to None (taken from context).
 
     """
 
     def __init__(self, vars, scaling=1., tune=True, tune_interval=100, model=None):
@@ -294,6 +295,6 @@ def delta_logp(logp, vars, shared):
 
     logp1 = CallableTensor(logp0)(inarray1)
 
-    f = theano.function([inarray1, inarray0], logp1 - logp0)
+    f = theano.function([inarray1, inarray0], logp1 - logp0, allow_input_downcast=True)
     f.trust_input = True
     return f
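The allow_input_downcast=True flag is what lets the compiled delta_logp function accept float64 NumPy arrays when floatX is float32: without it, Theano raises a TypeError rather than silently losing precision. A standalone illustration of the flag (plain Theano, not pymc3 code):

import numpy as np
import theano
import theano.tensor as tt

x = tt.vector('x')  # dtype follows theano.config.floatX
f = theano.function([x], 2 * x, allow_input_downcast=True)

# With floatX='float32' this float64 input would normally be
# rejected; the flag downcasts it to float32 instead.
print(f(np.array([1.0, 2.0], dtype='float64')))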

pymc3/step_methods/nuts.py

Lines changed: 11 additions & 13 deletions

@@ -69,13 +69,12 @@ def __init__(self, vars=None, scaling=None, step_scale=0.25, is_cov=False, state
         if isinstance(scaling, dict):
             scaling = guess_scaling(Point(scaling, model=model), model=model, vars = vars)
 
-
+        scaling = scaling.astype(theano.config.floatX)
 
         n = scaling.shape[0]
 
         self.step_size = step_scale / n**(1/4.)
 
-
         self.potential = quad_potential(scaling, is_cov, as_cov=False)
 
         if state is None:
@@ -92,27 +91,26 @@ def __init__(self, vars=None, scaling=None, step_scale=0.25, is_cov=False, state
         self.u = log(self.step_size*10)
         self.m = 1
 
-
-
         shared = make_shared_replacements(vars, model)
         self.leapfrog1_dE = leapfrog1_dE(model.logpt, vars, shared, self.potential, profile=profile)
 
         super(NUTS, self).__init__(vars, shared, **kwargs)
 
     def astep(self, q0):
+        q0 = q0.astype(theano.config.floatX)
         H = self.leapfrog1_dE #Hamiltonian(self.logp, self.dlogp, self.potential)
         Emax = self.Emax
-        e = self.step_size
+        e = array(self.step_size, dtype=theano.config.floatX)
 
-        p0 = self.potential.random()
+        p0 = self.potential.random().astype(theano.config.floatX)
         u = uniform()
         q = qn = qp = q0
         p = pn = pp = p0
 
         n, s, j = 1, 1, 0
 
         while s == 1:
-            v = bern(.5) * 2 - 1
+            v = array(bern(.5) * 2 - 1, dtype=theano.config.floatX)
 
             if v == -1:
                 qn, pn, _, _, q1, n1, s1, a, na = buildtree(H, qn, pn, u, v, j, e, Emax, q0, p0)
@@ -196,22 +194,22 @@ def leapfrog1_dE(logp, vars, shared, pot, profile):
 
     H = Hamiltonian(logp, dlogp, pot)
 
-    p = tt.dvector('p')
+    p = tt.vector('p')
     p.tag.test_value = q.tag.test_value
 
-    q0 = tt.dvector('q0')
+    q0 = tt.vector('q0')
     q0.tag.test_value = q.tag.test_value
-    p0 = tt.dvector('p0')
+    p0 = tt.vector('p0')
     p0.tag.test_value = p.tag.test_value
 
-    e = tt.dscalar('e')
-    e.tag.test_value = 1
+    e = tt.scalar('e')
+    e.tag.test_value = 1.
 
     q1, p1 = leapfrog(H, q, p, 1, e)
     E = energy(H, q1, p1)
     E0 = energy(H, q0, p0)
     dE = E - E0
 
-    f = theano.function([q, p, e, q0, p0], [q1, p1, dE], profile=profile)
+    f = theano.function([q, p, e, q0, p0], [q1, p1, dE], profile=profile, allow_input_downcast=True)
     f.trust_input = True
     return f
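The dvector→vector and dscalar→scalar swaps are the core of this file's change: the d-prefixed constructors are pinned to float64, while the unprefixed ones take their dtype from theano.config.floatX (which is also why the test value becomes the float literal 1.). A quick check in plain Theano:

import theano.tensor as tt

print(tt.dvector('p').dtype)  # always 'float64'
print(tt.vector('p').dtype)   # follows theano.config.floatX
print(tt.scalar('e').dtype)   # likewise floatX, unlike tt.dscalar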

pymc3/vartypes.py

Lines changed: 4 additions & 2 deletions

@@ -1,3 +1,5 @@
+import theano
+
 __all__ = ['bool_types', 'int_types', 'float_types', 'complex_types', 'continuous_types',
            'discrete_types', 'default_type', 'typefilter']
 
@@ -19,9 +21,9 @@
 discrete_types = bool_types | int_types
 
 default_type = {'discrete': 'int64',
-                'continuous': 'float64'}
+                'continuous': theano.config.floatX}
 
 
 def typefilter(vars, types):
     # Returns variables of type `types` from `vars`
-    return [v for v in vars if v.dtype in types]
+    return [v for v in vars if v.dtype in types]
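After this change the default continuous dtype tracks Theano's configuration, so dtype-based filtering stays consistent under float32. A hypothetical usage sketch (the model attribute in the comment is illustrative):

import theano
from pymc3.vartypes import default_type, continuous_types, typefilter

# The default continuous dtype now mirrors Theano's configuration.
assert default_type['continuous'] == theano.config.floatX

# typefilter selects variables whose dtype is in the given set,
# e.g. the continuous (float) variables of a model:
# float_vars = typefilter(model.vars, continuous_types)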
