@@ -8,74 +8,76 @@ using Distributions
88using Libtask
99using SSMProblems
1010
11- Parameters = @NamedTuple begin
12- a:: Float64
13- q:: Float64
14- kernel
11+ struct GaussianProcessDynamics{T<: Real ,KT<: Kernel } <: LatentDynamics{T,T}
12+ proc:: GP{ZeroMean{T},KT}
13+ function GaussianProcessDynamics (:: Type{T} , kernel:: KT ) where {T<: Real ,KT<: Kernel }
14+ return new {T,KT} (GP (ZeroMean {T} (), kernel))
15+ end
1516end
1617
17- mutable struct GPSSM <: SSMProblems.AbstractStateSpaceModel
18- X:: Vector{Float64}
19- observations:: Vector{Float64}
20- θ:: Parameters
21-
22- GPSSM (params:: Parameters ) = new (Vector {Float64} (), params)
23- GPSSM (y:: Vector{Float64} , params:: Parameters ) = new (Vector {Float64} (), y, params)
18+ struct LinearGaussianDynamics{T<: Real } <: LatentDynamics{T,T}
19+ a:: T
20+ b:: T
21+ q:: T
2422end
2523
26- seed = 1
27- T = 100
28- Nₚ = 20
29- Nₛ = 250
30- a = 0.9
31- q = 0.5
32-
33- params = Parameters ((a, q, SqExponentialKernel ()))
24+ function SSMProblems. distribution (proc:: LinearGaussianDynamics{T} ) where {T<: Real }
25+ return Normal (zero (T), proc. q)
26+ end
3427
35- f (θ :: Parameters , x, t) = Normal (θ . a * x, θ . q )
36- h (θ :: Parameters ) = Normal (0 , θ . q)
37- g (θ :: Parameters , x, t) = Normal ( 0 , exp ( 0.5 * x) ^ 2 )
28+ function SSMProblems . distribution (proc :: LinearGaussianDynamics , :: Int , state )
29+ return Normal (proc . a * state + proc . b, proc . q)
30+ end
3831
39- rng = Random . MersenneTwister (seed)
32+ struct StochasticVolatility{T <: Real } <: ObservationProcess{T,T} end
4033
41- x = zeros (T)
42- y = similar (x)
43- x[1 ] = rand (rng, h (params))
44- for t in 1 : T
45- if t < T
46- x[t + 1 ] = rand (rng, f (params, x[t], t))
47- end
48- y[t] = rand (rng, g (params, x[t], t))
34+ function SSMProblems. distribution (:: StochasticVolatility{T} , :: Int , state) where {T<: Real }
35+ return Normal (zero (T), exp ((1 / 2 ) * state))
4936end
5037
51- function gp_update (model:: GPSSM , state, step)
52- gp = GP (model. θ. kernel)
53- prior = gp (1 : (step - 1 ))
54- post = posterior (prior, model. X[1 : (step - 1 )])
55- μ, σ = mean_and_cov (post, [step])
56- return Normal (μ[1 ], σ[1 ])
38+ function LinearGaussianStochasticVolatilityModel (a:: T , q:: T ) where {T<: Real }
39+ dyn = LinearGaussianDynamics (a, zero (T), q)
40+ obs = StochasticVolatility {T} ()
41+ return SSMProblems. StateSpaceModel (dyn, obs)
5742end
5843
59- SSMProblems. transition!! (rng:: AbstractRNG , model:: GPSSM ) = rand (rng, h (model. θ))
60- function SSMProblems. transition!! (rng:: AbstractRNG , model:: GPSSM , state, step)
61- return rand (rng, gp_update (model, state, step))
44+ function GaussianProcessStateSpaceModel (:: Type{T} , kernel:: KT ) where {T<: Real ,KT<: Kernel }
45+ dyn = GaussianProcessDynamics (T, kernel)
46+ obs = StochasticVolatility {T} ()
47+ return SSMProblems. StateSpaceModel (dyn, obs)
6248end
6349
64- function SSMProblems. emission_logdensity (model:: GPSSM , state, step)
65- return logpdf (g (model. θ, state, step), model. observations[step])
66- end
67- function SSMProblems. transition_logdensity (model:: GPSSM , prev_state, current_state, step)
68- return logpdf (gp_update (model, prev_state, step), current_state)
50+ const GPSSM{T,KT<: Kernel } = SSMProblems. StateSpaceModel{
51+ T,
52+ GaussianProcessDynamics{T,KT},
53+ StochasticVolatility{T}
54+ };
55+
56+ # for non-markovian models, we can redefine dynamics to reference the trajectory
57+ function AdvancedPS. dynamics (
58+ ssm:: AdvancedPS.TracedSSM{<:GPSSM{T},T,T} , step:: Int
59+ ) where {T<: Real }
60+ prior = ssm. model. dyn. proc (1 : (step - 1 ))
61+ post = posterior (prior, ssm. X[1 : (step - 1 )])
62+ μ, σ = mean_and_cov (post, [step])
63+ return LinearGaussianDynamics (zero (T), μ[1 ], sqrt (σ[1 ]))
6964end
7065
71- AdvancedPS. isdone (:: GPSSM , step) = step > T
66+ # Everything is now ready to simulate some data.
67+ rng = MersenneTwister (1234 );
68+ true_model = LinearGaussianStochasticVolatilityModel (0.9 , 0.5 );
69+ _, x, y = sample (rng, true_model, 100 );
7270
73- model = GPSSM (y, params)
74- pg = AdvancedPS. PGAS (Nₚ)
75- chains = sample (rng, model, pg, Nₛ)
71+ # Create the model and run the sampler
72+ gpssm = GaussianProcessStateSpaceModel (Float64, SqExponentialKernel ());
73+ model = gpssm (y);
74+ pg = AdvancedPS. PGAS (20 );
75+ chains = sample (rng, model, pg, 250 ; progress= false );
76+ # md nothing #hide
7677
77- particles = hcat ([chain. trajectory. model. X for chain in chains]. .. )
78+ particles = hcat ([chain. trajectory. model. X for chain in chains]. .. );
7879mean_trajectory = mean (particles; dims= 2 );
80+ # md nothing #hide
7981
8082scatter (particles; label= false , opacity= 0.01 , color= :black , xlabel= " t" , ylabel= " state" )
8183plot! (x; color= :darkorange , label= " Original Trajectory" )
0 commit comments