@@ -11,81 +11,74 @@ using Distributions
1111using Libtask
1212using SSMProblems
1313
14- # Gaussian process encoded transition dynamics
15- mutable struct GaussianProcessDynamics{T<: Real } <: SSMProblems.LatentDynamics{T,T}
16- proc:: AbstractGPs.AbstractGP
17- q:: T
18- function GaussianProcessDynamics (q:: T , kernel:: KT ) where {T<: Real ,KT<: Kernel }
19- return new {T} (GP (ZeroMean {T} (), kernel), q)
14+ struct GaussianProcessDynamics{T<: Real ,KT<: Kernel } <: LatentDynamics{T,T}
15+ proc:: GP{ZeroMean{T},KT}
16+ function GaussianProcessDynamics (:: Type{T} , kernel:: KT ) where {T<: Real ,KT<: Kernel }
17+ return new {T,KT} (GP (ZeroMean {T} (), kernel))
2018 end
2119end
2220
23- function SSMProblems. distribution (dyn:: GaussianProcessDynamics{T} ) where {T<: Real }
24- return Normal (zero (T), dyn. q)
25- end
26-
27- # TODO : broken...
28- function SSMProblems. simulate (
29- rng:: AbstractRNG , dyn:: GaussianProcessDynamics , step:: Int , state
30- )
31- dyn. proc = posterior (dyn. proc (step: step, 0.1 ), [state])
32- μ, σ = mean_and_cov (dyn. proc, [step])
33- return rand (rng, Normal (μ[1 ], sqrt (σ[1 ])))
34- end
35-
36- function SSMProblems. logdensity (dyn:: GaussianProcessDynamics , step:: Int , state, prev_state)
37- μ, σ = mean_and_cov (dyn. proc, [step])
38- return logpdf (Normal (μ[1 ], sqrt (σ[1 ])), state)
39- end
40-
41- # Linear Gaussian dynamics used for simulation
42- struct LinearGaussianDynamics{T<: Real } <: SSMProblems.LatentDynamics{T,T}
21+ struct LinearGaussianDynamics{T<: Real } <: LatentDynamics{T,T}
4322 a:: T
23+ b:: T
4424 q:: T
4525end
4626
47- function SSMProblems. distribution (dyn :: LinearGaussianDynamics{T} ) where {T<: Real }
48- return Normal (zero (T), dyn . q)
27+ function SSMProblems. distribution (proc :: LinearGaussianDynamics{T} ) where {T<: Real }
28+ return Normal (zero (T), proc . q)
4929end
5030
51- function SSMProblems. distribution (dyn :: LinearGaussianDynamics , :: Int , state)
52- return Normal (dyn . a * state, dyn . q)
31+ function SSMProblems. distribution (proc :: LinearGaussianDynamics , :: Int , state)
32+ return Normal (proc . a * state + proc . b, proc . q)
5333end
5434
55- # Observation process used in both variants of the model
56- struct StochasticVolatility{T<: Real } <: SSMProblems.ObservationProcess{T,T} end
35+ struct StochasticVolatility{T<: Real } <: ObservationProcess{T,T} end
5736
5837function SSMProblems. distribution (:: StochasticVolatility{T} , :: Int , state) where {T<: Real }
5938 return Normal (zero (T), exp ((1 / 2 ) * state))
6039end
6140
62- # Baseline model (for simulation)
6341function LinearGaussianStochasticVolatilityModel (a:: T , q:: T ) where {T<: Real }
64- dyn = LinearGaussianDynamics (a, q)
42+ dyn = LinearGaussianDynamics (a, zero (T), q)
6543 obs = StochasticVolatility {T} ()
6644 return SSMProblems. StateSpaceModel (dyn, obs)
6745end
6846
69- # Gaussian process model (for sampling)
70- function GaussianProcessStateSpaceModel (q:: T , kernel:: KT ) where {T<: Real ,KT<: Kernel }
71- dyn = GaussianProcessDynamics (q, kernel)
47+ function GaussianProcessStateSpaceModel (:: Type{T} , kernel:: KT ) where {T<: Real ,KT<: Kernel }
48+ dyn = GaussianProcessDynamics (T, kernel)
7249 obs = StochasticVolatility {T} ()
7350 return SSMProblems. StateSpaceModel (dyn, obs)
7451end
7552
53+ const GPSSM{T,KT<: Kernel } = SSMProblems. StateSpaceModel{
54+ T,
55+ GaussianProcessDynamics{T,KT},
56+ StochasticVolatility{T}
57+ };
58+
59+ # for non-markovian models, we can redefine dynamics to reference the trajectory
60+ function AdvancedPS. dynamics (
61+ ssm:: AdvancedPS.TracedSSM{<:GPSSM{T},T,T} , step:: Int
62+ ) where {T<: Real }
63+ prior = ssm. model. dyn. proc (1 : (step - 1 ))
64+ post = posterior (prior, ssm. X[1 : (step - 1 )])
65+ μ, σ = mean_and_cov (post, [step])
66+ return LinearGaussianDynamics (zero (T), μ[1 ], sqrt (σ[1 ]))
67+ end
68+
7669# Everything is now ready to simulate some data.
77- rng = Random . MersenneTwister (1234 )
78- true_model = LinearGaussianStochasticVolatilityModel (0.9 , 0.5 )
70+ rng = MersenneTwister (1234 );
71+ true_model = LinearGaussianStochasticVolatilityModel (0.9 , 0.5 );
7972_, x, y = sample (rng, true_model, 100 );
8073
8174# Create the model and run the sampler
82- gpssm = GaussianProcessStateSpaceModel (0.5 , SqExponentialKernel ())
83- model = gpssm (y)
84- pg = AdvancedPS. PGAS (20 )
85- chains = sample (rng, model, pg, 50 )
75+ gpssm = GaussianProcessStateSpaceModel (Float64 , SqExponentialKernel ());
76+ model = gpssm (y);
77+ pg = AdvancedPS. PGAS (20 );
78+ chains = sample (rng, model, pg, 250 ; progress = false );
8679# md nothing #hide
8780
88- particles = hcat ([chain. trajectory. model. X for chain in chains]. .. )
81+ particles = hcat ([chain. trajectory. model. X for chain in chains]. .. );
8982mean_trajectory = mean (particles; dims= 2 );
9083# md nothing #hide
9184
0 commit comments