@@ -8,74 +8,83 @@ using Distributions
88using Libtask
99using SSMProblems
1010
11- Parameters = @NamedTuple begin
12- a:: Float64
13- q:: Float64
14- kernel
11+ # Gaussian process encoded transition dynamics
12+ mutable struct GaussianProcessDynamics{T<: Real } <: SSMProblems.LatentDynamics{T,T}
13+ proc:: AbstractGPs.AbstractGP
14+ q:: T
15+ function GaussianProcessDynamics (q:: T , kernel:: KT ) where {T<: Real ,KT<: Kernel }
16+ return new {T} (GP (ZeroMean {T} (), kernel), q)
17+ end
1518end
1619
17- mutable struct GPSSM <: SSMProblems.AbstractStateSpaceModel
18- X:: Vector{Float64}
19- observations:: Vector{Float64}
20- θ:: Parameters
21-
22- GPSSM (params:: Parameters ) = new (Vector {Float64} (), params)
23- GPSSM (y:: Vector{Float64} , params:: Parameters ) = new (Vector {Float64} (), y, params)
20+ function SSMProblems. distribution (dyn:: GaussianProcessDynamics{T} ) where {T<: Real }
21+ return Normal (zero (T), dyn. q)
2422end
2523
26- seed = 1
27- T = 100
28- Nₚ = 20
29- Nₛ = 250
30- a = 0.9
31- q = 0.5
32-
33- params = Parameters ((a, q, SqExponentialKernel ()))
24+ # TODO : broken...
25+ function SSMProblems . simulate (
26+ rng :: AbstractRNG , dyn :: GaussianProcessDynamics , step :: Int , state
27+ )
28+ dyn . proc = posterior (dyn . proc (step : step), [state])
29+ μ, σ = mean_and_cov (dyn . proc, [step])
30+ return rand (rng, Normal (μ[ 1 ], sqrt (σ[ 1 ])))
31+ end
3432
35- f (θ:: Parameters , x, t) = Normal (θ. a * x, θ. q)
36- h (θ:: Parameters ) = Normal (0 , θ. q)
37- g (θ:: Parameters , x, t) = Normal (0 , exp (0.5 * x)^ 2 )
33+ function SSMProblems. logdensity (dyn:: GaussianProcessDynamics , step:: Int , state, prev_state)
34+ μ, σ = mean_and_cov (dyn. proc, [step])
35+ return logpdf (Normal (μ, sqrt (σ)), state)
36+ end
3837
39- rng = Random. MersenneTwister (seed)
38+ # Linear Gaussian dynamics used for simulation
39+ struct LinearGaussianDynamics{T<: Real } <: SSMProblems.LatentDynamics{T,T}
40+ a:: T
41+ q:: T
42+ end
4043
41- x = zeros (T)
42- y = similar (x)
43- x[1 ] = rand (rng, h (params))
44- for t in 1 : T
45- if t < T
46- x[t + 1 ] = rand (rng, f (params, x[t], t))
47- end
48- y[t] = rand (rng, g (params, x[t], t))
44+ function SSMProblems. distribution (dyn:: LinearGaussianDynamics{T} ) where {T<: Real }
45+ return Normal (zero (T), dyn. q)
4946end
5047
51- function gp_update (model:: GPSSM , state, step)
52- gp = GP (model. θ. kernel)
53- prior = gp (1 : (step - 1 ))
54- post = posterior (prior, model. X[1 : (step - 1 )])
55- μ, σ = mean_and_cov (post, [step])
56- return Normal (μ[1 ], σ[1 ])
48+ function SSMProblems. distribution (dyn:: LinearGaussianDynamics , :: Int , state)
49+ return Normal (dyn. a * state, dyn. q)
5750end
5851
59- SSMProblems. transition!! (rng:: AbstractRNG , model:: GPSSM ) = rand (rng, h (model. θ))
60- function SSMProblems. transition!! (rng:: AbstractRNG , model:: GPSSM , state, step)
61- return rand (rng, gp_update (model, state, step))
52+ # Observation process used in both variants of the model
53+ struct StochasticVolatility{T<: Real } <: SSMProblems.ObservationProcess{T,T} end
54+
55+ function SSMProblems. distribution (:: StochasticVolatility{T} , :: Int , state) where {T<: Real }
56+ return Normal (zero (T), exp ((1 / 2 ) * state))
6257end
6358
64- function SSMProblems. emission_logdensity (model:: GPSSM , state, step)
65- return logpdf (g (model. θ, state, step), model. observations[step])
59+ # Baseline model (for simulation)
60+ function LinearGaussianStochasticVolatilityModel (a:: T , q:: T ) where {T<: Real }
61+ dyn = LinearGaussianDynamics (a, q)
62+ obs = StochasticVolatility {T} ()
63+ return SSMProblems. StateSpaceModel (dyn, obs)
6664end
67- function SSMProblems. transition_logdensity (model:: GPSSM , prev_state, current_state, step)
68- return logpdf (gp_update (model, prev_state, step), current_state)
65+
66+ # Gaussian process model (for sampling)
67+ function GaussianProcessStateSpaceModel (q:: T , kernel:: KT ) where {T<: Real ,KT<: Kernel }
68+ dyn = GaussianProcessDynamics (q, kernel)
69+ obs = StochasticVolatility {T} ()
70+ return SSMProblems. StateSpaceModel (dyn, obs)
6971end
7072
71- AdvancedPS. isdone (:: GPSSM , step) = step > T
73+ # Everything is now ready to simulate some data.
74+ rng = Random. MersenneTwister (1234 )
75+ true_model = LinearGaussianStochasticVolatilityModel (0.9 , 0.5 )
76+ _, x, y = sample (rng, true_model, 100 );
7277
73- model = GPSSM (y, params)
74- pg = AdvancedPS. PGAS (Nₚ)
75- chains = sample (rng, model, pg, Nₛ)
78+ # Create the model and run the sampler
79+ gpssm = GaussianProcessStateSpaceModel (0.5 , SqExponentialKernel ())
80+ model = gpssm (y)
81+ pg = AdvancedPS. PGAS (20 )
82+ chains = sample (rng, model, pg, 250 )
83+ # md nothing #hide
7684
7785particles = hcat ([chain. trajectory. model. X for chain in chains]. .. )
7886mean_trajectory = mean (particles; dims= 2 );
87+ # md nothing #hide
7988
8089scatter (particles; label= false , opacity= 0.01 , color= :black , xlabel= " t" , ylabel= " state" )
8190plot! (x; color= :darkorange , label= " Original Trajectory" )
0 commit comments