Skip to content

Commit ecfe977

Browse files
committed
Fix examples and streamline Levy SSM
1 parent fc63162 commit ecfe977

File tree

4 files changed

+104
-109
lines changed

4 files changed

+104
-109
lines changed

examples/gaussian-process/script.jl

Lines changed: 28 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -8,59 +8,63 @@ using Distributions
88
using Libtask
99
using SSMProblems
1010

11-
struct GaussianProcessDynamics{T<:Real,KT<:Kernel} <: LatentDynamics{T,T}
11+
struct GaussianProcessDynamics{T<:Real,KT<:Kernel} <: SSMProblems.LatentDynamics
1212
proc::GP{ZeroMean{T},KT}
1313
function GaussianProcessDynamics(::Type{T}, kernel::KT) where {T<:Real,KT<:Kernel}
1414
return new{T,KT}(GP(ZeroMean{T}(), kernel))
1515
end
1616
end
1717

18-
struct LinearGaussianDynamics{T<:Real} <: LatentDynamics{T,T}
19-
a::T
20-
b::T
21-
q::T
18+
struct GaussianPrior{ΣT<:Real} <: SSMProblems.StatePrior
19+
σ::ΣT
2220
end
2321

24-
function SSMProblems.distribution(proc::LinearGaussianDynamics{T}) where {T<:Real}
25-
return Normal(zero(T), proc.q)
22+
SSMProblems.distribution(proc::GaussianPrior) = Normal(0, proc.σ)
23+
24+
struct LinearGaussianDynamics{AT<:Real,BT<:Real,QT<:Real} <: SSMProblems.LatentDynamics
25+
a::AT
26+
b::BT
27+
q::QT
2628
end
2729

2830
function SSMProblems.distribution(proc::LinearGaussianDynamics, ::Int, state)
2931
return Normal(proc.a * state + proc.b, proc.q)
3032
end
3133

32-
struct StochasticVolatility{T<:Real} <: ObservationProcess{T,T} end
34+
struct StochasticVolatility <: SSMProblems.ObservationProcess end
3335

34-
function SSMProblems.distribution(::StochasticVolatility{T}, ::Int, state) where {T<:Real}
35-
return Normal(zero(T), exp((1 / 2) * state))
36+
function SSMProblems.distribution(::StochasticVolatility, ::Int, state)
37+
return Normal(0, exp(state / 2))
3638
end
3739

38-
function LinearGaussianStochasticVolatilityModel(a::T, q::T) where {T<:Real}
39-
dyn = LinearGaussianDynamics(a, zero(T), q)
40-
obs = StochasticVolatility{T}()
41-
return SSMProblems.StateSpaceModel(dyn, obs)
40+
function LinearGaussianStochasticVolatilityModel(a, q)
41+
prior = GaussianPrior(q)
42+
dyn = LinearGaussianDynamics(a, 0, q)
43+
obs = StochasticVolatility()
44+
return SSMProblems.StateSpaceModel(prior, dyn, obs)
4245
end
4346

4447
function GaussianProcessStateSpaceModel(::Type{T}, kernel::KT) where {T<:Real,KT<:Kernel}
48+
prior = GaussianPrior(one(T))
4549
dyn = GaussianProcessDynamics(T, kernel)
46-
obs = StochasticVolatility{T}()
47-
return SSMProblems.StateSpaceModel(dyn, obs)
50+
obs = StochasticVolatility()
51+
return SSMProblems.StateSpaceModel(prior, dyn, obs)
4852
end
4953

5054
const GPSSM{T,KT<:Kernel} = SSMProblems.StateSpaceModel{
51-
T,
52-
GaussianProcessDynamics{T,KT},
53-
StochasticVolatility{T}
55+
<:GaussianPrior,
56+
<:GaussianProcessDynamics{T,KT},
57+
StochasticVolatility
5458
};
5559

5660
# for non-markovian models, we can redefine dynamics to reference the trajectory
5761
function AdvancedPS.dynamics(
58-
ssm::AdvancedPS.TracedSSM{<:GPSSM{T},T,T}, step::Int
59-
) where {T<:Real}
62+
ssm::AdvancedPS.TracedSSM{<:GPSSM}, step::Int
63+
)
6064
prior = ssm.model.dyn.proc(1:(step - 1))
6165
post = posterior(prior, ssm.X[1:(step - 1)])
6266
μ, σ = mean_and_cov(post, [step])
63-
return LinearGaussianDynamics(zero(T), μ[1], sqrt(σ[1]))
67+
return LinearGaussianDynamics(0, μ[1], sqrt(σ[1]))
6468
end
6569

6670
# Everything is now ready to simulate some data.
@@ -70,9 +74,9 @@ _, x, y = sample(rng, true_model, 100);
7074

7175
# Create the model and run the sampler
7276
gpssm = GaussianProcessStateSpaceModel(Float64, SqExponentialKernel());
73-
model = gpssm(y);
77+
model = AdvancedPS.TracedSSM(gpssm, y);
7478
pg = AdvancedPS.PGAS(20);
75-
chains = sample(rng, model, pg, 250; progress=false);
79+
chains = sample(rng, model, pg, 250);
7680
#md nothing #hide
7781

7882
particles = hcat([chain.trajectory.model.X for chain in chains]...);

examples/gaussian-ssm/script.jl

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -28,27 +28,31 @@ using SSMProblems
2828
# as well as the initial distribution $f_0(x) = \mathcal{N}(0, q^2/(1-a^2))$.
2929

3030
# To use `AdvancedPS` we first need to define a model type that subtypes `AdvancedPS.AbstractStateSpaceModel`.
31-
mutable struct Parameters{T<:Real}
32-
a::T
33-
q::T
34-
r::T
31+
mutable struct Parameters{AT<:Real,QT<:Real,RT<:Real}
32+
a::AT
33+
q::QT
34+
r::RT
3535
end
3636

37-
struct LinearGaussianDynamics{T<:Real} <: SSMProblems.LatentDynamics{T,T}
38-
a::T
39-
q::T
37+
struct GaussianPrior{ΣT<:Real} <: SSMProblems.StatePrior
38+
σ::ΣT
4039
end
4140

42-
function SSMProblems.distribution(dyn::LinearGaussianDynamics{T}; kwargs...) where {T<:Real}
43-
return Normal(zero(T), sqrt(dyn.q^2 / (1 - dyn.a^2)))
41+
struct LinearGaussianDynamics{AT<:Real,QT<:Real} <: SSMProblems.LatentDynamics
42+
a::AT
43+
q::QT
44+
end
45+
46+
function SSMProblems.distribution(dyn::GaussianPrior; kwargs...)
47+
return Normal(0, dyn.σ)
4448
end
4549

4650
function SSMProblems.distribution(dyn::LinearGaussianDynamics, step::Int, state; kwargs...)
4751
return Normal(dyn.a * state, dyn.q)
4852
end
4953

50-
struct LinearGaussianObservation{T<:Real} <: SSMProblems.ObservationProcess{T,T}
51-
r::T
54+
struct LinearGaussianObservation{RT<:Real} <: SSMProblems.ObservationProcess
55+
r::RT
5256
end
5357

5458
function SSMProblems.distribution(
@@ -58,6 +62,7 @@ function SSMProblems.distribution(
5862
end
5963

6064
function LinearGaussianStateSpaceModel(θ::Parameters)
65+
prior = GaussianPrior(sqrt(θ.q^2 / (1 - θ.a^2)))
6166
dyn = LinearGaussianDynamics(θ.a, θ.q)
6267
obs = LinearGaussianObservation(θ.r)
6368
return SSMProblems.StateSpaceModel(dyn, obs)

examples/levy-ssm/script.jl

Lines changed: 43 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -27,17 +27,16 @@ function simulate(
2727
t = t0
2828
truncated = last_jump < tolerance
2929
while !truncated
30-
t += rand(rng, Exponential(one(T) / rate))
31-
xi = one(T) / (β * (exp(t / C) - one(T)))
32-
prob = (one(T) + β * xi) * exp(-β * xi)
30+
t += rand(rng, Exponential(1 / rate))
31+
xi = 1 / (β * (exp(t / C) - 1))
32+
prob = (1 + β * xi) * exp(-β * xi)
3333
if rand(rng) < prob
3434
push!(jumps, xi)
3535
last_jump = xi
3636
end
3737
truncated = last_jump < tolerance
3838
end
39-
times = rand(rng, Uniform(start, finish), length(jumps))
40-
return GammaPath(jumps, times)
39+
return GammaPath(jumps, rand(rng, Uniform(start, finish), length(jumps)))
4140
end
4241
end
4342

@@ -47,85 +46,67 @@ function integral(times::Array{<:Real}, path::GammaPath)
4746
end
4847
end
4948

50-
struct LangevinDynamics{T}
51-
A::Matrix{T}
52-
L::Vector{T}
53-
θ::T
54-
H::Vector{T}
55-
σe::T
49+
struct LangevinDynamics{AT<:AbstractMatrix,LT<:AbstractVector,θT<:Real}
50+
A::AT
51+
L::LT
52+
θ::θT
5653
end
5754

58-
struct NormalMeanVariance{T}
59-
μ::T
60-
σ::T
55+
function Base.exp(dyn::LangevinDynamics, dt)
56+
f_val = exp(dyn.θ * dt)
57+
return [1 (f_val - 1)/dyn.θ; 0 f_val]
6158
end
6259

63-
f(dt, θ) = exp(θ * dt)
64-
function Base.exp(dyn::LangevinDynamics{T}, dt::T) where {T<:Real}
65-
let θ = dyn.θ
66-
f_val = f(dt, θ)
67-
return [one(T) (f_val - 1)/θ; zero(T) f_val]
68-
end
69-
end
70-
71-
function meancov(
72-
t::T, dyn::LangevinDynamics, path::GammaPath, nvm::NormalMeanVariance
73-
) where {T<:Real}
74-
μ = zeros(T, 2)
75-
Σ = zeros(T, (2, 2))
76-
let times = path.times, jumps = path.jumps, μw = nvm.μ, σw = nvm.σ
77-
for (v, z) in zip(times, jumps)
78-
ft = exp(dyn, (t - v)) * dyn.L
79-
μ += ft .* μw .* z
80-
Σ += ft * transpose(ft) .* σw^2 .* z
81-
end
60+
function meancov(t, dyn::LangevinDynamics, path::GammaPath, dist::Normal)
61+
fts = exp.(Ref(dyn), (t .- path.times)) .* Ref(dyn.L)
62+
μ = sum(@. fts * mean(dist) * path.jumps)
63+
Σ = sum(@. fts * transpose(fts) * var(dist) * path.jumps)
8264

83-
# Guarantees positive semi-definiteness
84-
return μ, Σ + T(1e-6) * I
85-
end
65+
# Guarantees positive semi-definiteness
66+
return μ, Σ + eltype(Σ)(1e-6) * I
8667
end
8768

88-
struct LevyLangevin{T} <: LatentDynamics{T,Vector{T}}
89-
dt::T
90-
dyn::LangevinDynamics{T}
91-
process::GammaProcess{T}
92-
nvm::NormalMeanVariance{T}
69+
struct LevyPrior{XT<:AbstractVector,ΣT<:AbstractMatrix} <: StatePrior
70+
μ::XT
71+
Σ::ΣT
9372
end
9473

95-
function SSMProblems.distribution(proc::LevyLangevin{T}) where {T<:Real}
96-
return MultivariateNormal(zeros(T, 2), I)
74+
SSMProblems.distribution(proc::LevyPrior) = MvNormal(proc.μ, proc.Σ)
75+
76+
struct LevyLangevin{T<:Real,LT<:LangevinDynamics,ΓT<:GammaProcess,DT<:Normal} <:
77+
SSMProblems.LatentDynamics
78+
dt::T
79+
dyn::LT
80+
process::ΓT
81+
dist::DT
9782
end
9883

99-
function SSMProblems.distribution(proc::LevyLangevin{T}, step::Int, state) where {T<:Real}
84+
function SSMProblems.distribution(proc::LevyLangevin, step::Int, state)
10085
dt = proc.dt
10186
path = simulate(rng, proc.process, dt, (step - 1) * dt, step * dt)
102-
μ, Σ = meancov(step * dt, proc.dyn, path, proc.nvm)
103-
return MultivariateNormal(exp(proc.dyn, dt) * state + μ, Σ)
87+
μ, Σ = meancov(step * dt, proc.dyn, path, proc.dist)
88+
return MvNormal(exp(proc.dyn, dt) * state + μ, Σ)
10489
end
10590

106-
struct LinearGaussianObservation{T<:Real} <: ObservationProcess{T,T}
107-
H::Vector{T}
108-
R::T
91+
struct LinearGaussianObservation{HT<:AbstractVector,RT<:Real} <: SSMProblems.ObservationProcess
92+
H::HT
93+
R::RT
10994
end
11095

111-
function SSMProblems.distribution(proc::LinearGaussianObservation, step::Int, state)
96+
function SSMProblems.distribution(proc::LinearGaussianObservation, ::Int, state)
11297
return Normal(transpose(proc.H) * state, proc.R)
11398
end
11499

115-
function LevyModel(dt, θ, σe, C, β, μw, σw; ϵ=1e-10)
116-
A = [0.0 1.0; 0.0 θ]
117-
L = [0.0; 1.0]
118-
H = [1.0, 0]
119-
100+
function LevyModel(dt, θ, σe, C, β, μw, σw; kwargs...)
120101
dyn = LevyLangevin(
121102
dt,
122-
LangevinDynamics(A, L, θ, H, σe),
123-
GammaProcess(C, β; ϵ),
124-
NormalMeanVariance(μw, σw),
103+
LangevinDynamics([0 1; 0 θ], [0; 1], θ),
104+
GammaProcess(C, β; kwargs...),
105+
Normal(μw, σw),
125106
)
126107

127-
obs = LinearGaussianObservation(H, σe)
128-
return StateSpaceModel(dyn, obs)
108+
obs = LinearGaussianObservation([1; 0], σe)
109+
return SSMProblems.StateSpaceModel(LevyPrior(zeros(Bool, 2), I(2)), dyn, obs)
129110
end
130111

131112
# Levy SSM with Langevin dynamics
@@ -139,15 +120,15 @@ end
139120
# Simulation parameters
140121
N = 200
141122
ts = range(0, 100; length=N)
142-
levyssm = LevyModel(step(ts), θ, 1.0, 1.0, 1.0, 0.0, 1.0);
123+
levyssm = LevyModel(step(ts), -0.5, 1, 1.0, 1.0, 0, 1);
143124

144125
# Simulate data
145126
rng = Random.MersenneTwister(1234);
146127
_, X, Y = sample(rng, levyssm, N);
147128

148129
# Run sampler
149130
pg = AdvancedPS.PGAS(50);
150-
chains = sample(rng, levyssm(Y), pg, 100);
131+
chains = sample(rng, AdvancedPS.TracedSSM(levyssm, Y), pg, 100; progress=false);
151132

152133
# Concat all sampled states
153134
marginal_states = hcat([chain.trajectory.model.X for chain in chains]...)

examples/particle-gibbs/script.jl

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -52,29 +52,34 @@ end
5252
# ```
5353
# with the initial distribution $f_0(x) = \mathcal{N}(0, q^2)$.
5454
# Here we assume the static parameters $\theta = (a^2, q^2)$ are known and we are only interested in sampling from the latent state $x_t$.
55-
struct LinearGaussianDynamics{T<:Real} <: SSMProblems.LatentDynamics{T,T}
56-
a::T
57-
q::T
55+
struct GaussianPrior{T<:Real} <: SSMProblems.StatePrior
56+
σ::T
5857
end
5958

60-
function SSMProblems.distribution(dyn::LinearGaussianDynamics{T}) where {T<:Real}
61-
return Normal(zero(T), dyn.q)
59+
function SSMProblems.distribution(proc::GaussianPrior)
60+
return Normal(0, proc.σ)
61+
end
62+
63+
struct LinearGaussianDynamics{AT<:Real,QT<:Real} <: SSMProblems.LatentDynamics
64+
a::AT
65+
q::QT
6266
end
6367

6468
function SSMProblems.distribution(dyn::LinearGaussianDynamics, ::Int, state)
6569
return Normal(dyn.a * state, dyn.q)
6670
end
6771

68-
struct StochasticVolatility{T<:Real} <: SSMProblems.ObservationProcess{T,T} end
72+
struct StochasticVolatility <: SSMProblems.ObservationProcess end
6973

70-
function SSMProblems.distribution(::StochasticVolatility{T}, ::Int, state) where {T<:Real}
71-
return Normal(zero(T), exp((1 / 2) * state))
74+
function SSMProblems.distribution(::StochasticVolatility, ::Int, state)
75+
return Normal(0, exp(state / 2))
7276
end
7377

74-
function LinearGaussianStochasticVolatilityModel(a::T, q::T) where {T<:Real}
78+
function LinearGaussianStochasticVolatilityModel(a, q)
79+
prior = GaussianPrior(q)
7580
dyn = LinearGaussianDynamics(a, q)
76-
obs = StochasticVolatility{T}()
77-
return SSMProblems.StateSpaceModel(dyn, obs)
81+
obs = StochasticVolatility()
82+
return SSMProblems.StateSpaceModel(prior, dyn, obs)
7883
end
7984
#md nothing #hide
8085

@@ -90,7 +95,7 @@ plot(x; label="x", xlabel="t")
9095
plot(y; label="y", xlabel="t")
9196

9297
# Here we use the particle gibbs kernel without adaptive resampling.
93-
model = true_model(y)
98+
model = AdvancedPS.TracedSSM(true_model, y)
9499
pg = AdvancedPS.PG(20, 1.0)
95100
chains = sample(rng, model, pg, 200; progress=false);
96101
#md nothing #hide

0 commit comments

Comments
 (0)