11# # Levy-SSM latent state inference
2- using AdvancedPS: SSMProblems
3- using AdvancedPS
42using Random
53using Plots
64using Distributions
75using AdvancedPS
86using LinearAlgebra
97using SSMProblems
108
"""
    GammaProcess(C, β; ϵ=1e-10)

Parameters of a Gamma process used by `simulate`.

# Fields
- `C`: scale appearing in the jump-size formula `1 / (β * (exp(t / C) - 1))`.
- `β`: rate appearing in the same formula and in the acceptance probability.
- `tol`: truncation tolerance; `simulate` stops once the last accepted jump is
  smaller than this value.

`C` and `β` must share the same real type `T`. The tolerance `ϵ` may be any real
number and is converted to `T`, so the default also works for non-`Float64`
processes such as `GammaProcess(1f0, 2f0)`.
"""
struct GammaProcess{T}
    C::T
    β::T
    tol::T
    function GammaProcess(C::T, β::T; ϵ::Real=1e-10) where {T<:Real}
        # `ϵ` is accepted as any Real: asserting the Float64 default against `T`
        # would throw for every T other than Float64.
        return new{T}(C, β, convert(T, ϵ))
    end
end
1615
"""
    GammaPath(jumps, times)

A realised path of jumps: `jumps[i]` is the size of the jump occurring at
`times[i]`. Both vectors share the element type `T`.
"""
struct GammaPath{T}
    jumps::Vector{T}
    times::Vector{T}
end
2120
22- struct LangevinDynamics{T}
23- A:: Matrix{T}
24- L:: Vector{T}
25- θ:: T
26- H:: Vector{T}
27- σe:: T
28- end
29-
30- struct NormalMeanVariance{T}
31- μ:: T
32- σ:: T
33- end
34-
3521function simulate (
36- rng:: AbstractRNG ,
37- process:: GammaProcess ,
38- rate:: Float64 ,
39- start:: Float64 ,
40- finish:: Float64 ,
41- t0:: Float64 = 0.0 ,
42- )
22+ rng:: AbstractRNG , process:: GammaProcess{T} , rate:: T , start:: T , finish:: T , t0:: T = zero (T)
23+ ) where {T<: Real }
4324 let β = process. β, C = process. C, tolerance = process. tol
44- jumps = Float64 []
25+ jumps = T []
4526 last_jump = Inf
4627 t = t0
4728 truncated = last_jump < tolerance
4829 while ! truncated
49- t += rand (rng, Exponential (1.0 / rate))
50- xi = 1.0 / (β * (exp (t / C) - 1 ))
51- prob = (1.0 + β * xi) * exp (- β * xi)
30+ t += rand (rng, Exponential (one (T) / rate))
31+ xi = one (T) / (β * (exp (t / C) - one (T) ))
32+ prob = (one (T) + β * xi) * exp (- β * xi)
5233 if rand (rng) < prob
5334 push! (jumps, xi)
5435 last_jump = xi
@@ -60,26 +41,67 @@ function simulate(
6041 end
6142end
6243
"""
    integral(times, path::GammaPath)

Return, for every `t` in `times`, the sum of all jumps in `path` whose jump time
is less than or equal to `t`.

`times` may be any array of reals, including ranges such as
`range(0, 100; length=200)`; the result is a freshly allocated array with one
entry per element of `times`.
"""
function integral(times::AbstractArray{<:Real}, path::GammaPath)
    jumps, jump_times = path.jumps, path.times
    return [sum(jumps[jump_times .<= t]) for t in times]
end
6849
"""
    LangevinDynamics(A, L, θ, H, σe)

Container for the parameters of the Langevin state-space model

    dx_t = A x_t dt + L dW_t
    y_t  = H x_t + ϵ_t

# Fields
- `A`: drift matrix.
- `L`: vector multiplying the driving noise.
- `θ`: scalar rate used by `exp(dyn, dt)` to build the transition matrix.
- `H`: observation vector.
- `σe`: scale of the observation noise.
"""
struct LangevinDynamics{T}
    A::Matrix{T}
    L::Vector{T}
    θ::T
    H::Vector{T}
    σe::T
end
57+
"""
    NormalMeanVariance(μ, σ)

Parameters of the normal mean-variance representation: `meancov` scales each
jump's mean contribution by `μ` and its covariance contribution by `σ^2`.
"""
struct NormalMeanVariance{T}
    μ::T
    σ::T
end
62+
# Scalar decay factor e^{θ dt}.
f(dt, θ) = exp(θ * dt)

"""
    exp(dyn::LangevinDynamics{T}, dt::T)

Return the 2×2 transition matrix over a time step `dt`:

    [1  (e^{θ dt} - 1) / θ
     0   e^{θ dt}         ]

with `θ = dyn.θ`. When `θ == 0` the off-diagonal entry is replaced by its limit
`dt`, avoiding the `0 / 0 = NaN` the raw formula would produce.

NOTE(review): presumably the matrix exponential `exp(A * dt)` for
`A = [0 1; 0 θ]`; confirm against how `dyn.A` is built.
"""
function Base.exp(dyn::LangevinDynamics{T}, dt::T) where {T<:Real}
    θ = dyn.θ
    f_val = f(dt, θ)
    # lim_{θ→0} (e^{θ dt} - 1) / θ = dt; `oftype` keeps both branches the same type.
    offdiag = iszero(θ) ? oftype(f_val, dt) : (f_val - 1) / θ
    return [one(T) offdiag; zero(T) f_val]
end
70+
"""
    meancov(t, dyn::LangevinDynamics, path::GammaPath, nvm::NormalMeanVariance)

Accumulate the conditional mean and covariance contributed by the jumps of
`path`, evaluated at time `t`.

For every jump of size `z` at time `v`, with `ft = exp(dyn, t - v) * dyn.L`:
- the mean gains `ft * nvm.μ * z`;
- the covariance gains `ft * ftᵀ * nvm.σ^2 * z`.

Returns `(μ, Σ)`, a length-2 vector and a 2×2 matrix. `1e-6` is added to the
diagonal of `Σ` so that the returned covariance is not singular when the path
contributes little or nothing.
"""
function meancov(
    t::T, dyn::LangevinDynamics, path::GammaPath, nvm::NormalMeanVariance
) where {T<:Real}
    mean_acc = zeros(T, 2)
    cov_acc = zeros(T, (2, 2))
    noise_var = nvm.σ^2

    for (jump_time, jump_size) in zip(path.times, path.jumps)
        ft = exp(dyn, (t - jump_time)) * dyn.L
        mean_acc = mean_acc + ft .* nvm.μ .* jump_size
        outer = ft * transpose(ft)
        cov_acc = cov_acc + outer .* noise_var .* jump_size
    end

    # Diagonal jitter applied even when the path has no jumps.
    jitter = T(1e-6) * I
    return mean_acc, cov_acc + jitter
end
87+
# Gamma Process
C = 1.0
β = 1.0
# Truncation tolerance; kept as a named global because the model construction
# further down (`LevyModel(..., C, β, ϵ, μw, σw)`) still refers to `ϵ`.
ϵ = 1e-10
process = GammaProcess(C, β; ϵ)
7492
7593# Normal Mean-Variance representation
7694μw = 0.0
7795σw = 1.0
7896nvm = NormalMeanVariance (μw, σw)
7997
8098# Levy SSM with Langevin dynamics
81- # dx(t) = A x(t) dt + L dW(t)
82- # y(t) = H x(t) + ϵ(t)
99+ # ```math
100+ # dx_{t} = A x_{t} dt + L dW_{t}
101+ # ```
102+ # ```math
103+ # y_{t} = H x_{t} + ϵ_{t}
104+ # ```
83105θ = - 0.5
84106A = [
85107 0.0 1.0
@@ -91,44 +113,17 @@ H = [1.0, 0]
91113dyn = LangevinDynamics (A, L, θ, H, σe)
92114
# Simulation parameters
N = 200
ts = range(0, 100; length=N)

# `seed` must stay defined: the RNG is (re)created from it here and again
# before the model-based `sample` call further down.
seed = 4
rng = Random.MersenneTwister(seed)
125120X = zeros (Float64, (N, 2 ))
126121Y = zeros (Float64, N)
127122for (i, t) in enumerate (ts)
128123 if i > 1
129124 s = ts[i - 1 ]
130125 dt = t - s
131- path = simulate (rng, process, dt, s, t, ϵ )
126+ path = simulate (rng, process, dt, s, t)
132127 μ, Σ = meancov (t, dyn, path, nvm)
133128 X[i, :] .= rand (rng, MultivariateNormal (exp (dyn, dt) * X[i - 1 , :] + μ, Σ))
134129 end
@@ -138,77 +133,57 @@ for (i, t) in enumerate(ts)
138133 end
139134end
140135
# TODO: this can surely be optimized
"""
    LevyLangevin(dt, dyn, process, nvm)

Latent dynamics of the Lévy state-space model on a regular time grid with step
`dt`, bundling the Langevin parameters `dyn`, the Gamma `process` that generates
jumps, and the normal mean-variance parameters `nvm`.
"""
struct LevyLangevin{T} <: LatentDynamics{T,Vector{T}}
    dt::T
    dyn::LangevinDynamics{T}
    process::GammaProcess{T}
    nvm::NormalMeanVariance{T}
end

# Initial state distribution: standard bivariate normal.
function SSMProblems.distribution(proc::LevyLangevin{T}) where {T<:Real}
    return MultivariateNormal(zeros(T, 2), I)
end

# Transition distribution for `step`: simulate the jumps on
# ((step - 1) * dt, step * dt], turn them into a conditional mean/covariance with
# `meancov`, and shift the mean by the propagated previous state.
#
# NOTE(review): `rng` below is the script-level global RNG, because this method
# signature does not receive one. The returned distribution is therefore random
# and depends on global state — confirm this is intended.
function SSMProblems.distribution(proc::LevyLangevin{T}, step::Int, state) where {T<:Real}
    dt = proc.dt
    path = simulate(rng, proc.process, dt, (step - 1) * dt, step * dt)
    μ, Σ = meancov(step * dt, proc.dyn, path, proc.nvm)
    return MultivariateNormal(exp(proc.dyn, dt) * state + μ, Σ)
end

"""
    LinearGaussianObservation(H, R)

Scalar observation process: the observation at each step is distributed as
`Normal(transpose(H) * state, R)`.
"""
struct LinearGaussianObservation{T<:Real} <: ObservationProcess{T,T}
    H::Vector{T}
    R::T
end

function SSMProblems.distribution(proc::LinearGaussianObservation, step::Int, state)
    return Normal(transpose(proc.H) * state, proc.R)
end

"""
    LevyModel(dt, A, L, θ, H, σe, C, β, ϵ, μw, σw)

Assemble the full state-space model: `LevyLangevin` latent dynamics with step
`dt` paired with a `LinearGaussianObservation(H, σe)`.
"""
function LevyModel(dt, A, L, θ, H, σe, C, β, ϵ, μw, σw)
    dyn = LevyLangevin(
        dt,
        LangevinDynamics(A, L, θ, H, σe),
        GammaProcess(C, β; ϵ),
        NormalMeanVariance(μw, σw),
    )

    obs = LinearGaussianObservation(H, σe)
    return StateSpaceModel(dyn, obs)
end

# The grid step is taken from `ts` rather than hard-coded, and the tolerance is
# read back from `process` so the model uses the same value as the manual
# simulation above.
levyssm = LevyModel(step(ts), A, L, θ, H, σe, C, β, process.tol, μw, σw);

# Sampling has to come after `levyssm` is defined.
# NOTE: doesn't match 1:1, but I think that's okay
rng = Random.MersenneTwister(seed)
_, x, y = sample(rng, levyssm, N)

pg = AdvancedPS.PGAS(50);
chains = sample(rng, levyssm(Y), pg, 100);
206184
# Gather the sampled state container of every chain and join them column-wise
marginal_states = hcat(map(chain -> chain.trajectory.model.X, chains)...)
212187
213188# Plot marginal state and jump intensities for one trajectory
214189p1 = plot (
@@ -224,12 +199,8 @@ plot!(
224199 label= " Marginal State (x2)" ,
225200)
226201
227- p2 = scatter (
228- vcat ([t for t in jump_times[:, end ]]. .. ),
229- vcat ([j for j in jump_intensities[:, end ]]. .. );
230- color= :darkorange ,
231- label= " Jumps" ,
232- )
202+ # TODO : collect jumps from the model
203+ p2 = scatter ([], []; color= :darkorange , label= " Jumps" )
233204
234205plot (
235206 p1, p2; plot_title= " Marginal State and Jump Intensities" , layout= (2 , 1 ), size= (600 , 600 )
0 commit comments