Skip to content

Commit 5632c92

Browse files
committed
improved GP fix
1 parent 9e2f73f commit 5632c92

File tree

3 files changed

+43
-50
lines changed

3 files changed

+43
-50
lines changed

examples/gaussian-process/script.jl

Lines changed: 37 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -11,81 +11,74 @@ using Distributions
1111
using Libtask
1212
using SSMProblems
1313

14-
# Gaussian process encoded transition dynamics
15-
mutable struct GaussianProcessDynamics{T<:Real} <: SSMProblems.LatentDynamics{T,T}
16-
proc::AbstractGPs.AbstractGP
17-
q::T
18-
function GaussianProcessDynamics(q::T, kernel::KT) where {T<:Real,KT<:Kernel}
19-
return new{T}(GP(ZeroMean{T}(), kernel), q)
14+
struct GaussianProcessDynamics{T<:Real,KT<:Kernel} <: LatentDynamics{T,T}
15+
proc::GP{ZeroMean{T},KT}
16+
function GaussianProcessDynamics(::Type{T}, kernel::KT) where {T<:Real,KT<:Kernel}
17+
return new{T,KT}(GP(ZeroMean{T}(), kernel))
2018
end
2119
end
2220

23-
function SSMProblems.distribution(dyn::GaussianProcessDynamics{T}) where {T<:Real}
24-
return Normal(zero(T), dyn.q)
25-
end
26-
27-
# TODO: broken...
28-
function SSMProblems.simulate(
29-
rng::AbstractRNG, dyn::GaussianProcessDynamics, step::Int, state
30-
)
31-
dyn.proc = posterior(dyn.proc(step:step, 0.1), [state])
32-
μ, σ = mean_and_cov(dyn.proc, [step])
33-
return rand(rng, Normal(μ[1], sqrt(σ[1])))
34-
end
35-
36-
function SSMProblems.logdensity(dyn::GaussianProcessDynamics, step::Int, state, prev_state)
37-
μ, σ = mean_and_cov(dyn.proc, [step])
38-
return logpdf(Normal(μ[1], sqrt(σ[1])), state)
39-
end
40-
41-
# Linear Gaussian dynamics used for simulation
42-
struct LinearGaussianDynamics{T<:Real} <: SSMProblems.LatentDynamics{T,T}
21+
struct LinearGaussianDynamics{T<:Real} <: LatentDynamics{T,T}
4322
a::T
23+
b::T
4424
q::T
4525
end
4626

47-
function SSMProblems.distribution(dyn::LinearGaussianDynamics{T}) where {T<:Real}
48-
return Normal(zero(T), dyn.q)
27+
function SSMProblems.distribution(proc::LinearGaussianDynamics{T}) where {T<:Real}
28+
return Normal(zero(T), proc.q)
4929
end
5030

51-
function SSMProblems.distribution(dyn::LinearGaussianDynamics, ::Int, state)
52-
return Normal(dyn.a * state, dyn.q)
31+
function SSMProblems.distribution(proc::LinearGaussianDynamics, ::Int, state)
32+
return Normal(proc.a * state + proc.b, proc.q)
5333
end
5434

55-
# Observation process used in both variants of the model
56-
struct StochasticVolatility{T<:Real} <: SSMProblems.ObservationProcess{T,T} end
35+
struct StochasticVolatility{T<:Real} <: ObservationProcess{T,T} end
5736

5837
function SSMProblems.distribution(::StochasticVolatility{T}, ::Int, state) where {T<:Real}
5938
return Normal(zero(T), exp((1 / 2) * state))
6039
end
6140

62-
# Baseline model (for simulation)
6341
function LinearGaussianStochasticVolatilityModel(a::T, q::T) where {T<:Real}
64-
dyn = LinearGaussianDynamics(a, q)
42+
dyn = LinearGaussianDynamics(a, zero(T), q)
6543
obs = StochasticVolatility{T}()
6644
return SSMProblems.StateSpaceModel(dyn, obs)
6745
end
6846

69-
# Gaussian process model (for sampling)
70-
function GaussianProcessStateSpaceModel(q::T, kernel::KT) where {T<:Real,KT<:Kernel}
71-
dyn = GaussianProcessDynamics(q, kernel)
47+
function GaussianProcessStateSpaceModel(::Type{T}, kernel::KT) where {T<:Real,KT<:Kernel}
48+
dyn = GaussianProcessDynamics(T, kernel)
7249
obs = StochasticVolatility{T}()
7350
return SSMProblems.StateSpaceModel(dyn, obs)
7451
end
7552

53+
const GPSSM{T,KT<:Kernel} = SSMProblems.StateSpaceModel{
54+
T,
55+
GaussianProcessDynamics{T,KT},
56+
StochasticVolatility{T}
57+
};
58+
59+
# for non-markovian models, we can redefine dynamics to reference the trajectory
60+
function AdvancedPS.dynamics(
61+
ssm::AdvancedPS.TracedSSM{<:GPSSM{T},T,T}, step::Int
62+
) where {T<:Real}
63+
prior = ssm.model.dyn.proc(1:(step - 1))
64+
post = posterior(prior, ssm.X[1:(step - 1)])
65+
μ, σ = mean_and_cov(post, [step])
66+
return LinearGaussianDynamics(zero(T), μ[1], sqrt(σ[1]))
67+
end
68+
7669
# Everything is now ready to simulate some data.
77-
rng = Random.MersenneTwister(1234)
78-
true_model = LinearGaussianStochasticVolatilityModel(0.9, 0.5)
70+
rng = MersenneTwister(1234);
71+
true_model = LinearGaussianStochasticVolatilityModel(0.9, 0.5);
7972
_, x, y = sample(rng, true_model, 100);
8073

8174
# Create the model and run the sampler
82-
gpssm = GaussianProcessStateSpaceModel(0.5, SqExponentialKernel())
83-
model = gpssm(y)
84-
pg = AdvancedPS.PGAS(20)
85-
chains = sample(rng, model, pg, 50)
75+
gpssm = GaussianProcessStateSpaceModel(Float64, SqExponentialKernel());
76+
model = gpssm(y);
77+
pg = AdvancedPS.PGAS(20);
78+
chains = sample(rng, model, pg, 250; progress=false);
8679
#md nothing #hide
8780

88-
particles = hcat([chain.trajectory.model.X for chain in chains]...)
81+
particles = hcat([chain.trajectory.model.X for chain in chains]...);
8982
mean_trajectory = mean(particles; dims=2);
9083
#md nothing #hide
9184

src/model.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,8 @@ mutable struct TracedSSM{SSM,XT,YT} <: SSMProblems.AbstractStateSpaceModel
3131
end
3232

3333
(model::SSMProblems.StateSpaceModel)(Y::AbstractVector) = TracedSSM(model, Y)
34-
dynamics(ssm::TracedSSM) = ssm.model.dyn
35-
observation(ssm::TracedSSM) = ssm.model.obs
34+
dynamics(ssm::TracedSSM, step::Int) = ssm.model.dyn
35+
observation(ssm::TracedSSM, step::Int) = ssm.model.obs
3636

3737
isdone(ssm::TracedSSM, step::Int) = step > length(ssm.Y)
3838

src/pgas.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ Get the log weight of the transition from previous state of `model` to `x`
2626
function transition_logweight(particle::SSMTrace, x; kwargs...)
2727
iter = current_step(particle) - 1
2828
score = SSMProblems.logdensity(
29-
dynamics(particle.model), iter, particle.model.X[iter - 1], x, kwargs...
29+
dynamics(particle.model, iter), iter, particle.model.X[iter - 1], x, kwargs...
3030
)
3131
return score
3232
end
@@ -59,11 +59,11 @@ function advance!(particle::SSMTrace, isref::Bool=false)
5959

6060
if !isref
6161
if running_step == 1
62-
new_state = SSMProblems.simulate(particle.rng, dynamics(model))
62+
new_state = SSMProblems.simulate(particle.rng, dynamics(model, running_step))
6363
else
6464
current_state = model.X[running_step - 1]
6565
new_state = SSMProblems.simulate(
66-
particle.rng, dynamics(model), running_step, current_state
66+
particle.rng, dynamics(model, running_step), running_step, current_state
6767
)
6868
end
6969
else
@@ -72,7 +72,7 @@ function advance!(particle::SSMTrace, isref::Bool=false)
7272
end
7373

7474
score = SSMProblems.logdensity(
75-
observation(model), running_step, new_state, model.Y[running_step]
75+
observation(model, running_step), running_step, new_state, model.Y[running_step]
7676
)
7777

7878
# Accept transition and move the time index/rng counter

0 commit comments

Comments
 (0)