Skip to content

Commit a04e243

Browse files
committed
fixed Levy SSM
1 parent 1c253aa commit a04e243

File tree

1 file changed

+97
-126
lines changed

1 file changed

+97
-126
lines changed

examples/levy-ssm/script.jl

Lines changed: 97 additions & 126 deletions
Original file line numberDiff line numberDiff line change
@@ -1,54 +1,35 @@
11
# # Levy-SSM latent state inference
2-
using AdvancedPS: SSMProblems
3-
using AdvancedPS
42
using Random
53
using Plots
64
using Distributions
75
using AdvancedPS
86
using LinearAlgebra
97
using SSMProblems
108

11-
struct GammaProcess
12-
C::Float64
13-
β::Float64
14-
tol::Float64
9+
struct GammaProcess{T}
10+
C::T
11+
β::T
12+
tol::T
13+
GammaProcess(C::T, β::T; ϵ::T=1e-10) where {T<:Real} = new{T}(C, β, ϵ)
1514
end
1615

1716
struct GammaPath{T}
1817
jumps::Vector{T}
1918
times::Vector{T}
2019
end
2120

22-
struct LangevinDynamics{T}
23-
A::Matrix{T}
24-
L::Vector{T}
25-
θ::T
26-
H::Vector{T}
27-
σe::T
28-
end
29-
30-
struct NormalMeanVariance{T}
31-
μ::T
32-
σ::T
33-
end
34-
3521
function simulate(
36-
rng::AbstractRNG,
37-
process::GammaProcess,
38-
rate::Float64,
39-
start::Float64,
40-
finish::Float64,
41-
t0::Float64=0.0,
42-
)
22+
rng::AbstractRNG, process::GammaProcess{T}, rate::T, start::T, finish::T, t0::T=zero(T)
23+
) where {T<:Real}
4324
let β = process.β, C = process.C, tolerance = process.tol
44-
jumps = Float64[]
25+
jumps = T[]
4526
last_jump = Inf
4627
t = t0
4728
truncated = last_jump < tolerance
4829
while !truncated
49-
t += rand(rng, Exponential(1.0 / rate))
50-
xi = 1.0 /* (exp(t / C) - 1))
51-
prob = (1.0 + β * xi) * exp(-β * xi)
30+
t += rand(rng, Exponential(one(T) / rate))
31+
xi = one(T) /* (exp(t / C) - one(T)))
32+
prob = (one(T) + β * xi) * exp(-β * xi)
5233
if rand(rng) < prob
5334
push!(jumps, xi)
5435
last_jump = xi
@@ -60,26 +41,67 @@ function simulate(
6041
end
6142
end
6243

63-
function integral(times::Array{Float64}, path::GammaPath)
44+
function integral(times::Array{<:Real}, path::GammaPath)
6445
let jumps = path.jumps, jump_times = path.times
6546
return [sum(jumps[jump_times .<= t]) for t in times]
6647
end
6748
end
6849

50+
struct LangevinDynamics{T}
51+
A::Matrix{T}
52+
L::Vector{T}
53+
θ::T
54+
H::Vector{T}
55+
σe::T
56+
end
57+
58+
struct NormalMeanVariance{T}
59+
μ::T
60+
σ::T
61+
end
62+
63+
f(dt, θ) = exp* dt)
64+
function Base.exp(dyn::LangevinDynamics{T}, dt::T) where {T<:Real}
65+
let θ = dyn.θ
66+
f_val = f(dt, θ)
67+
return [one(T) (f_val - 1)/θ; zero(T) f_val]
68+
end
69+
end
70+
71+
function meancov(
72+
t::T, dyn::LangevinDynamics, path::GammaPath, nvm::NormalMeanVariance
73+
) where {T<:Real}
74+
μ = zeros(T, 2)
75+
Σ = zeros(T, (2, 2))
76+
let times = path.times, jumps = path.jumps, μw = nvm.μ, σw = nvm.σ
77+
for (v, z) in zip(times, jumps)
78+
ft = exp(dyn, (t - v)) * dyn.L
79+
μ += ft .* μw .* z
80+
Σ += ft * transpose(ft) .* σw^2 .* z
81+
end
82+
83+
# Guarantees positive semi-definiteness
84+
return μ, Σ + T(1e-6) * I
85+
end
86+
end
87+
6988
# Gamma Process
7089
C = 1.0
7190
β = 1.0
72-
ϵ = 1e-10
73-
process = GammaProcess(C, β, ϵ)
91+
process = GammaProcess(C, β)
7492

7593
# Normal Mean-Variance representation
7694
μw = 0.0
7795
σw = 1.0
7896
nvm = NormalMeanVariance(μw, σw)
7997

8098
# Levy SSM with Langevin dynamics
81-
# dx(t) = A x(t) dt + L dW(t)
82-
# y(t) = H x(t) + ϵ(t)
99+
# ```math
100+
# dx_{t} = A x_{t} dt + L dW_{t}
101+
# ```
102+
# ```math
103+
# y_{t} = H x_{t} + ϵ{t}
104+
# ```
83105
θ = -0.5
84106
A = [
85107
0.0 1.0
@@ -91,44 +113,17 @@ H = [1.0, 0]
91113
dyn = LangevinDynamics(A, L, θ, H, σe)
92114

93115
# Simulation parameters
94-
start, finish = 0, 100
95116
N = 200
96-
ts = range(start, finish; length=N)
97-
seed = 4
98-
rng = Random.MersenneTwister(seed)
99-
Np = 50
100-
Ns = 100
101-
102-
f(dt, θ) = exp* dt)
103-
function Base.exp(dyn::LangevinDynamics, dt::Real)
104-
let θ = dyn.θ
105-
f_val = f(dt, θ)
106-
return [1.0 (f_val - 1)/θ; 0 f_val]
107-
end
108-
end
109-
110-
function meancov(
111-
t::T, dyn::LangevinDynamics, path::GammaPath, nvm::NormalMeanVariance
112-
) where {T<:Real}
113-
μ = zeros(T, 2)
114-
Σ = zeros(T, (2, 2))
115-
let times = path.times, jumps = path.jumps, μw = nvm.μ, σw = nvm.σ
116-
for (v, z) in zip(times, jumps)
117-
ft = exp(dyn, (t - v)) * dyn.L
118-
μ += ft .* μw .* z
119-
Σ += ft * transpose(ft) .* σw^2 .* z
120-
end
121-
return μ, Σ
122-
end
123-
end
117+
ts = range(0, 100; length=N)
124118

119+
rng = Random.MersenneTwister(seed)
125120
X = zeros(Float64, (N, 2))
126121
Y = zeros(Float64, N)
127122
for (i, t) in enumerate(ts)
128123
if i > 1
129124
s = ts[i - 1]
130125
dt = t - s
131-
path = simulate(rng, process, dt, s, t, ϵ)
126+
path = simulate(rng, process, dt, s, t)
132127
μ, Σ = meancov(t, dyn, path, nvm)
133128
X[i, :] .= rand(rng, MultivariateNormal(exp(dyn, dt) * X[i - 1, :] + μ, Σ))
134129
end
@@ -138,77 +133,57 @@ for (i, t) in enumerate(ts)
138133
end
139134
end
140135

141-
# AdvancedPS
142-
Parameters = @NamedTuple begin
143-
dyn::LangevinDynamics
144-
process::GammaProcess
145-
nvm::NormalMeanVariance
146-
times::Vector{Float64}
147-
end
136+
# NOTE: doesn't match 1:1, but I think that's okay
137+
rng = Random.MersenneTwister(seed)
138+
_, x, y = sample(rng, levyssm, N)
148139

149-
struct MixedState{T}
150-
x::Vector{T}
151-
path::GammaPath{T}
140+
# TODO: this can surely be optimized
141+
struct LevyLangevin{T} <: LatentDynamics{T,Vector{T}}
142+
dt::T
143+
dyn::LangevinDynamics{T}
144+
process::GammaProcess{T}
145+
nvm::NormalMeanVariance{T}
152146
end
153147

154-
mutable struct LevyLangevin <: SSMProblems.AbstractStateSpaceModel
155-
X::Vector{MixedState{Float64}}
156-
observations::Vector{Float64}
157-
θ::Parameters
158-
LevyLangevin::Parameters) = new(Vector{MixedState{Float64}}(), θ)
159-
function LevyLangevin(y::Vector{Float64}, θ::Parameters)
160-
return new(Vector{MixedState{Float64}}(), y, θ)
161-
end
148+
function SSMProblems.distribution(proc::LevyLangevin{T}) where {T<:Real}
149+
return MultivariateNormal(zeros(T, 2), I)
162150
end
163151

164-
function SSMProblems.transition!!(rng::AbstractRNG, model::LevyLangevin)
165-
return MixedState(
166-
rand(rng, MultivariateNormal([0, 0], I)), GammaPath(Float64[], Float64[])
167-
)
152+
function SSMProblems.distribution(proc::LevyLangevin{T}, step::Int, state) where {T<:Real}
153+
dt = proc.dt
154+
path = simulate(rng, proc.process, dt, (step - 1) * dt, step * dt)
155+
μ, Σ = meancov(step * dt, proc.dyn, path, proc.nvm)
156+
return MultivariateNormal(exp(proc.dyn, dt) * state + μ, Σ)
168157
end
169158

170-
function SSMProblems.transition!!(
171-
rng::AbstractRNG, model::LevyLangevin, state::MixedState, step
172-
)
173-
times = model.θ.times
174-
s = times[step - 1]
175-
t = times[step]
176-
dt = t - s
177-
path = simulate(rng, model.θ.process, dt, s, t)
178-
μ, Σ = meancov(t, model.θ.dyn, path, model.θ.nvm)
179-
Σ += 1e-6 * I
180-
return MixedState(rand(rng, MultivariateNormal(exp(dyn, dt) * state.x + μ, Σ)), path)
159+
struct LinearGaussianObservation{T<:Real} <: ObservationProcess{T,T}
160+
H::Vector{T}
161+
R::T
181162
end
182163

183-
function SSMProblems.transition_logdensity(
184-
model::LevyLangevin, prev_state::MixedState, current_state::MixedState, step
185-
)
186-
times = model.θ.times
187-
s = times[step - 1]
188-
t = times[step]
189-
dt = t - s
190-
path = simulate(rng, model.θ.process, dt, s, t)
191-
μ, Σ = meancov(t, model.θ.dyn, path, model.θ.nvm)
192-
Σ += 1e-6 * I
193-
return logpdf(MultivariateNormal(exp(dyn, dt) * prev_state.x + μ, Σ), current_state.x)
164+
function SSMProblems.distribution(proc::LinearGaussianObservation, step::Int, state)
165+
return Normal(transpose(proc.H) * state, proc.R)
194166
end
195167

196-
function SSMProblems.emission_logdensity(model::LevyLangevin, state::MixedState, step)
197-
return logpdf(Normal(transpose(H) * state.x, σe), model.observations[step])
168+
function LevyModel(dt, A, L, θ, H, σe, C, β, ϵ, μw, σw)
169+
dyn = LevyLangevin(
170+
dt,
171+
LangevinDynamics(A, L, θ, H, σe),
172+
GammaProcess(C, β; ϵ),
173+
NormalMeanVariance(μw, σw),
174+
)
175+
176+
obs = LinearGaussianObservation(H, σe)
177+
return StateSpaceModel(dyn, obs)
198178
end
199179

200-
AdvancedPS.isdone(model::LevyLangevin, step) = step > length(model.θ.times)
180+
levyssm = LevyModel(0.5025125628140756, A, L, θ, H, σe, C, β, ϵ, μw, σw);
201181

202-
θ₀ = Parameters((dyn, process, nvm, ts))
203-
model = LevyLangevin(Y, θ₀)
204-
pg = AdvancedPS.PGAS(Np)
205-
chains = sample(rng, model, pg, Ns; progress=false);
182+
pg = AdvancedPS.PGAS(50);
183+
chains = sample(rng, levyssm(Y), pg, 100);
206184

207185
# Concat all sampled states
208-
particles = hcat([chain.trajectory.model.X for chain in chains]...)
209-
marginal_states = map(s -> s.x, particles);
210-
jump_times = map(s -> s.path.times, particles);
211-
jump_intensities = map(s -> s.path.jumps, particles);
186+
marginal_states = hcat([chain.trajectory.model.X for chain in chains]...)
212187

213188
# Plot marginal state and jump intensities for one trajectory
214189
p1 = plot(
@@ -224,12 +199,8 @@ plot!(
224199
label="Marginal State (x2)",
225200
)
226201

227-
p2 = scatter(
228-
vcat([t for t in jump_times[:, end]]...),
229-
vcat([j for j in jump_intensities[:, end]]...);
230-
color=:darkorange,
231-
label="Jumps",
232-
)
202+
# TODO: collect jumps from the model
203+
p2 = scatter([], []; color=:darkorange, label="Jumps")
233204

234205
plot(
235206
p1, p2; plot_title="Marginal State and Jump Intensities", layout=(2, 1), size=(600, 600)

0 commit comments

Comments
 (0)