Skip to content

Commit e62671c

Browse files
committed
Add random method to GaussianRandomWalk
1 parent 6c17578 commit e62671c

File tree

2 files changed

+59
-13
lines changed

2 files changed

+59
-13
lines changed

pymc3/distributions/timeseries.py

Lines changed: 50 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1+
from scipy import stats
12
import theano.tensor as tt
23
from theano import scan
34

45
from pymc3.util import get_variable_name
56
from .continuous import get_tau_sigma, Normal, Flat
67
from . import multivariate
78
from . import distribution
9+
from .shape_utils import to_tuple
810

911

1012
__all__ = [
@@ -166,17 +168,22 @@ def logp(self, value):
166168

167169

168170
class GaussianRandomWalk(distribution.Continuous):
169-
R"""
170-
Random Walk with Normal innovations
171+
R"""Random Walk with Normal innovations
171172
172173
Parameters
173174
----------
174175
mu: tensor
175176
innovation drift, defaults to 0.0
177+
For vector valued mu, first dimension must match shape of the random walk, and
178+
the first element will be discarded (since there is no innovation in the first timestep)
176179
sigma : tensor
177180
sigma > 0, innovation standard deviation (only required if tau is not specified)
181+
For vector valued sigma, first dimension must match shape of the random walk, and
182+
the first element will be discarded (since there is no innovation in the first timestep)
178183
tau : tensor
179184
tau > 0, innovation precision (only required if sigma is not specified)
185+
For vector valued tau, first dimension must match shape of the random walk, and
186+
the first element will be discarded (since there is no innovation in the first timestep)
180187
init : distribution
181188
distribution for initial value (Defaults to Flat())
182189
"""
@@ -187,9 +194,14 @@ def __init__(self, tau=None, init=Flat.dist(), sigma=None, mu=0.,
187194
if sd is not None:
188195
sigma = sd
189196
tau, sigma = get_tau_sigma(tau=tau, sigma=sigma)
190-
self.tau = tau = tt.as_tensor_variable(tau)
191-
self.sigma = self.sd = sigma = tt.as_tensor_variable(sigma)
192-
self.mu = mu = tt.as_tensor_variable(mu)
197+
self.tau = tt.as_tensor_variable(tau)
198+
sigma = tt.as_tensor_variable(sigma)
199+
if sigma.ndim > 0:
200+
sigma = sigma[:-1]
201+
self.sigma = self.sd = sigma
202+
self.mu = tt.as_tensor_variable(mu)
203+
if self.mu.ndim > 0:
204+
self.mu = self.mu[:-1]
193205
self.init = init
194206
self.mean = tt.as_tensor_variable(0.)
195207

@@ -206,15 +218,41 @@ def logp(self, x):
206218
-------
207219
TensorVariable
208220
"""
209-
sigma = self.sigma
210-
mu = self.mu
211-
init = self.init
221+
if x.ndim > 0:
222+
x_im1 = x[:-1]
223+
x_i = x[1:]
212224

213-
x_im1 = x[:-1]
214-
x_i = x[1:]
225+
sigma = self.sigma
226+
mu = self.mu
227+
228+
innov_like = Normal.dist(mu=x_im1 + mu, sigma=sigma).logp(x_i)
229+
return self.init.logp(x[0]) + tt.sum(innov_like)
230+
return self.init.logp(x)
231+
232+
def random(self, point=None, size=None):
233+
"""Draw random values from GaussianRandomWalk.
215234
216-
innov_like = Normal.dist(mu=x_im1 + mu, sigma=sigma).logp(x_i)
217-
return init.logp(x[0]) + tt.sum(innov_like)
235+
Parameters
236+
----------
237+
point : dict, optional
238+
Dict of variable values on which random values are to be
239+
conditioned (uses default point if not specified).
240+
size : int, optional
241+
Desired size of random sample (returns one sample if not
242+
specified).
243+
244+
Returns
245+
-------
246+
array
247+
"""
248+
sigma, mu = distribution.draw_values([self.sigma, self.mu], point=point, size=size)
249+
return distribution.generate_samples(self._random, sigma=sigma, mu=mu, size=size,
250+
dist_shape=self.shape)
251+
252+
def _random(self, sigma, mu, size):
253+
"""Implement a Gaussian random walk as a cumulative sum of normals."""
254+
rv = stats.norm(mu, sigma)
255+
return rv.rvs(size).cumsum(axis=0)
218256

219257
def _repr_latex_(self, name=None, dist=None):
220258
if dist is None:

pymc3/tests/test_distributions_random.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,14 @@ def test_different_shapes_and_sample_sizes(self, shape):
251251
assert e == a
252252

253253

254+
class TestGaussianRandomWalk(BaseTestCases.BaseTestCase):
255+
distribution = pm.GaussianRandomWalk
256+
params = {'mu': 1., 'sigma': 1.}
257+
258+
@pytest.mark.xfail(reason="Supporting this makes a nasty API")
259+
def test_broadcast_shape(self):
260+
super().test_broadcast_shape()
261+
254262
class TestNormal(BaseTestCases.BaseTestCase):
255263
distribution = pm.Normal
256264
params = {'mu': 0., 'tau': 1.}
@@ -1006,7 +1014,7 @@ def test_density_dist_without_random_not_sampleable(self):
10061014
normal_dist = pm.Normal.dist(mu, 1)
10071015
pm.DensityDist('density_dist', normal_dist.logp, observed=np.random.randn(100))
10081016
trace = pm.sample(100)
1009-
1017+
10101018
samples = 500
10111019
with pytest.raises(ValueError):
10121020
pm.sample_posterior_predictive(trace, samples=samples, model=model, size=100)

0 commit comments

Comments
 (0)