Skip to content

Commit 0ca78be

Browse files
maedoctwiecki
authored andcommitted
ENH add SDE class (#1269)
1 parent fc71f68 commit 0ca78be

File tree

1 file changed

+28
-1
lines changed

1 file changed

+28
-1
lines changed

pymc3/distributions/timeseries.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from .continuous import Normal, Flat
55
from .distribution import Continuous
66

7-
__all__ = ['AR1', 'GaussianRandomWalk', 'GARCH11']
7+
__all__ = ['AR1', 'GaussianRandomWalk', 'GARCH11', 'EulerMaruyama']
88

99

1010
class AR1(Continuous):
@@ -124,3 +124,30 @@ def logp(self, x):
124124
vol = self._get_volatility(x[:-1])
125125
return (Normal.dist(0., sd=self.initial_vol).logp(x[0]) +
126126
tt.sum(Normal.dist(0, sd=vol).logp(x[1:])))
127+
128+
129+
class EulerMaruyama(Continuous):
130+
"""
131+
Stochastic differential equation discretized with the Euler-Maruyama method.
132+
133+
Parameters
134+
----------
135+
dt : float
136+
time step of discretization
137+
sde_fn : callable
138+
function returning the drift and diffusion coefficients of SDE
139+
sde_pars : tuple
140+
parameters of the SDE, passed as *args to sde_fn
141+
"""
142+
def __init__(self, dt, sde_fn, sde_pars, *args, **kwds):
143+
super(EulerMaruyama, self).__init__(*args, **kwds)
144+
self.dt = dt
145+
self.sde_fn = sde_fn
146+
self.sde_pars = sde_pars
147+
148+
def logp(self, x):
149+
xt = x[:-1]
150+
f, g = self.sde_fn(x[:-1], *self.sde_pars)
151+
mu = xt + self.dt * f
152+
sd = tt.sqrt(self.dt) * g
153+
return tt.sum(Normal.dist(mu=mu, sd=sd).logp(x[1:]))

0 commit comments

Comments
 (0)