Skip to content

Commit 563a27a

Browse files
committed
update Levy SSM
1 parent c258ee3 commit 563a27a

File tree

1 file changed

+23
-56
lines changed

1 file changed

+23
-56
lines changed

examples/levy-ssm/script.jl

Lines changed: 23 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -85,59 +85,6 @@ function meancov(
8585
end
8686
end
8787

88-
# Gamma Process
89-
C = 1.0
90-
β = 1.0
91-
process = GammaProcess(C, β)
92-
93-
# Normal Mean-Variance representation
94-
μw = 0.0
95-
σw = 1.0
96-
nvm = NormalMeanVariance(μw, σw)
97-
98-
# Levy SSM with Langevin dynamics
99-
# ```math
100-
# dx_{t} = A x_{t} dt + L dW_{t}
101-
# ```
102-
# ```math
103-
# y_{t} = H x_{t} + ϵ{t}
104-
# ```
105-
θ = -0.5
106-
A = [
107-
0.0 1.0
108-
0.0 θ
109-
]
110-
L = [0.0; 1.0]
111-
σe = 1.0
112-
H = [1.0, 0]
113-
dyn = LangevinDynamics(A, L, θ, H, σe)
114-
115-
# Simulation parameters
116-
N = 200
117-
ts = range(0, 100; length=N)
118-
119-
rng = Random.MersenneTwister(seed)
120-
X = zeros(Float64, (N, 2))
121-
Y = zeros(Float64, N)
122-
for (i, t) in enumerate(ts)
123-
if i > 1
124-
s = ts[i - 1]
125-
dt = t - s
126-
path = simulate(rng, process, dt, s, t)
127-
μ, Σ = meancov(t, dyn, path, nvm)
128-
X[i, :] .= rand(rng, MultivariateNormal(exp(dyn, dt) * X[i - 1, :] + μ, Σ))
129-
end
130-
131-
let H = dyn.H, σe = dyn.σe
132-
Y[i] = transpose(H) * X[i, :] + rand(rng, Normal(0, σe))
133-
end
134-
end
135-
136-
# NOTE: doesn't match 1:1, but I think that's okay
137-
rng = Random.MersenneTwister(seed)
138-
_, x, y = sample(rng, levyssm, N)
139-
140-
# TODO: this can surely be optimized
14188
struct LevyLangevin{T} <: LatentDynamics{T,Vector{T}}
14289
dt::T
14390
dyn::LangevinDynamics{T}
@@ -165,7 +112,11 @@ function SSMProblems.distribution(proc::LinearGaussianObservation, step::Int, st
165112
return Normal(transpose(proc.H) * state, proc.R)
166113
end
167114

168-
function LevyModel(dt, A, L, θ, H, σe, C, β, ϵ, μw, σw)
115+
function LevyModel(dt, θ, σe, C, β, μw, σw; ϵ=1e-10)
116+
A = [0.0 1.0; 0.0 θ]
117+
L = [0.0; 1.0]
118+
H = [1.0, 0]
119+
169120
dyn = LevyLangevin(
170121
dt,
171122
LangevinDynamics(A, L, θ, H, σe),
@@ -177,8 +128,24 @@ function LevyModel(dt, A, L, θ, H, σe, C, β, ϵ, μw, σw)
177128
return StateSpaceModel(dyn, obs)
178129
end
179130

180-
levyssm = LevyModel(0.5025125628140756, A, L, θ, H, σe, C, β, ϵ, μw, σw);
131+
# Levy SSM with Langevin dynamics
132+
# ```math
133+
# dx_{t} = A x_{t} dt + L dW_{t}
134+
# ```
135+
# ```math
136+
# y_{t} = H x_{t} + ϵ{t}
137+
# ```
138+
139+
# Simulation parameters
140+
N = 200
141+
ts = range(0, 100; length=N)
142+
levyssm = LevyModel(step(ts), θ, 1.0, 1.0, 1.0, 0.0, 1.0);
143+
144+
# Simulate data
145+
rng = Random.MersenneTwister(1234);
146+
_, X, Y = sample(rng, levyssm, N);
181147

148+
# Run sampler
182149
pg = AdvancedPS.PGAS(50);
183150
chains = sample(rng, levyssm(Y), pg, 100);
184151

@@ -221,7 +188,7 @@ for d in 1:2
221188
fillalpha=0.2,
222189
title="Marginal State Trajectories (X$d)",
223190
)
224-
plot!(p, ts, X[:, d]; color=:dodgerblue, label="True Trajectory")
191+
plot!(p, ts, getindex.(X, d); color=:dodgerblue, label="True Trajectory")
225192
push!(ps, p)
226193
end
227194
plot(ps...; layout=(2, 1), size=(600, 600))

0 commit comments

Comments
 (0)