@@ -11,81 +11,74 @@ using Distributions
11
11
using Libtask
12
12
using SSMProblems
13
13
14
- # Gaussian process encoded transition dynamics
15
- mutable struct GaussianProcessDynamics{T<: Real } <: SSMProblems.LatentDynamics{T,T}
16
- proc:: AbstractGPs.AbstractGP
17
- q:: T
18
- function GaussianProcessDynamics (q:: T , kernel:: KT ) where {T<: Real ,KT<: Kernel }
19
- return new {T} (GP (ZeroMean {T} (), kernel), q)
14
+ struct GaussianProcessDynamics{T<: Real ,KT<: Kernel } <: LatentDynamics{T,T}
15
+ proc:: GP{ZeroMean{T},KT}
16
+ function GaussianProcessDynamics (:: Type{T} , kernel:: KT ) where {T<: Real ,KT<: Kernel }
17
+ return new {T,KT} (GP (ZeroMean {T} (), kernel))
20
18
end
21
19
end
22
20
23
- function SSMProblems. distribution (dyn:: GaussianProcessDynamics{T} ) where {T<: Real }
24
- return Normal (zero (T), dyn. q)
25
- end
26
-
27
- # TODO : broken...
28
- function SSMProblems. simulate (
29
- rng:: AbstractRNG , dyn:: GaussianProcessDynamics , step:: Int , state
30
- )
31
- dyn. proc = posterior (dyn. proc (step: step, 0.1 ), [state])
32
- μ, σ = mean_and_cov (dyn. proc, [step])
33
- return rand (rng, Normal (μ[1 ], sqrt (σ[1 ])))
34
- end
35
-
36
- function SSMProblems. logdensity (dyn:: GaussianProcessDynamics , step:: Int , state, prev_state)
37
- μ, σ = mean_and_cov (dyn. proc, [step])
38
- return logpdf (Normal (μ[1 ], sqrt (σ[1 ])), state)
39
- end
40
-
41
- # Linear Gaussian dynamics used for simulation
42
- struct LinearGaussianDynamics{T<: Real } <: SSMProblems.LatentDynamics{T,T}
21
+ struct LinearGaussianDynamics{T<: Real } <: LatentDynamics{T,T}
43
22
a:: T
23
+ b:: T
44
24
q:: T
45
25
end
46
26
47
- function SSMProblems. distribution (dyn :: LinearGaussianDynamics{T} ) where {T<: Real }
48
- return Normal (zero (T), dyn . q)
27
+ function SSMProblems. distribution (proc :: LinearGaussianDynamics{T} ) where {T<: Real }
28
+ return Normal (zero (T), proc . q)
49
29
end
50
30
51
- function SSMProblems. distribution (dyn :: LinearGaussianDynamics , :: Int , state)
52
- return Normal (dyn . a * state, dyn . q)
31
+ function SSMProblems. distribution (proc :: LinearGaussianDynamics , :: Int , state)
32
+ return Normal (proc . a * state + proc . b, proc . q)
53
33
end
54
34
55
- # Observation process used in both variants of the model
56
- struct StochasticVolatility{T<: Real } <: SSMProblems.ObservationProcess{T,T} end
35
+ struct StochasticVolatility{T<: Real } <: ObservationProcess{T,T} end
57
36
58
37
function SSMProblems. distribution (:: StochasticVolatility{T} , :: Int , state) where {T<: Real }
59
38
return Normal (zero (T), exp ((1 / 2 ) * state))
60
39
end
61
40
62
- # Baseline model (for simulation)
63
41
function LinearGaussianStochasticVolatilityModel (a:: T , q:: T ) where {T<: Real }
64
- dyn = LinearGaussianDynamics (a, q)
42
+ dyn = LinearGaussianDynamics (a, zero (T), q)
65
43
obs = StochasticVolatility {T} ()
66
44
return SSMProblems. StateSpaceModel (dyn, obs)
67
45
end
68
46
69
- # Gaussian process model (for sampling)
70
- function GaussianProcessStateSpaceModel (q:: T , kernel:: KT ) where {T<: Real ,KT<: Kernel }
71
- dyn = GaussianProcessDynamics (q, kernel)
47
+ function GaussianProcessStateSpaceModel (:: Type{T} , kernel:: KT ) where {T<: Real ,KT<: Kernel }
48
+ dyn = GaussianProcessDynamics (T, kernel)
72
49
obs = StochasticVolatility {T} ()
73
50
return SSMProblems. StateSpaceModel (dyn, obs)
74
51
end
75
52
53
+ const GPSSM{T,KT<: Kernel } = SSMProblems. StateSpaceModel{
54
+ T,
55
+ GaussianProcessDynamics{T,KT},
56
+ StochasticVolatility{T}
57
+ };
58
+
59
+ # for non-markovian models, we can redefine dynamics to reference the trajectory
60
+ function AdvancedPS. dynamics (
61
+ ssm:: AdvancedPS.TracedSSM{<:GPSSM{T},T,T} , step:: Int
62
+ ) where {T<: Real }
63
+ prior = ssm. model. dyn. proc (1 : (step - 1 ))
64
+ post = posterior (prior, ssm. X[1 : (step - 1 )])
65
+ μ, σ = mean_and_cov (post, [step])
66
+ return LinearGaussianDynamics (zero (T), μ[1 ], sqrt (σ[1 ]))
67
+ end
68
+
76
69
# Everything is now ready to simulate some data.
77
- rng = Random . MersenneTwister (1234 )
78
- true_model = LinearGaussianStochasticVolatilityModel (0.9 , 0.5 )
70
+ rng = MersenneTwister (1234 );
71
+ true_model = LinearGaussianStochasticVolatilityModel (0.9 , 0.5 );
79
72
_, x, y = sample (rng, true_model, 100 );
80
73
81
74
# Create the model and run the sampler
82
- gpssm = GaussianProcessStateSpaceModel (0.5 , SqExponentialKernel ())
83
- model = gpssm (y)
84
- pg = AdvancedPS. PGAS (20 )
85
- chains = sample (rng, model, pg, 50 )
75
+ gpssm = GaussianProcessStateSpaceModel (Float64 , SqExponentialKernel ());
76
+ model = gpssm (y);
77
+ pg = AdvancedPS. PGAS (20 );
78
+ chains = sample (rng, model, pg, 250 ; progress = false );
86
79
# md nothing #hide
87
80
88
- particles = hcat ([chain. trajectory. model. X for chain in chains]. .. )
81
+ particles = hcat ([chain. trajectory. model. X for chain in chains]. .. );
89
82
mean_trajectory = mean (particles; dims= 2 );
90
83
# md nothing #hide
91
84
0 commit comments