Skip to content

Commit c7766f5

Browse files
authored
Linear Gaussian unit test (#98)
* Add unit test for linear Gaussian SSM * Replace matrix model dynamics with scalar * Remove redundant stack Fixes CI error that came from `stack` not being available in Julia 1.7. * Increase particle count and ensure reproducibility * Update test to SSMProblems interface
1 parent 2880bd3 commit c7766f5

File tree

3 files changed

+120
-0
lines changed

3 files changed

+120
-0
lines changed

test/Project.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,14 @@
11
[deps]
22
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
33
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
4+
DynamicIterators = "6c76993d-992e-5bf1-9e63-34920a5a5a38"
5+
GaussianDistributions = "43dcc890-d446-5863-8d1a-14597580bb8d"
6+
HypothesisTests = "09f84164-cd44-5f33-b23f-e6b0d136a0d5"
7+
Kalman = "d59c0ba6-2ef2-5409-8dc5-1fd9a2b46832"
48
Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
59
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
610
Random123 = "74087812-796a-5b5d-8853-05524746bad3"
11+
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
712
SSMProblems = "26aad666-b158-4e64-9d35-0e672562fa48"
813
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
914

test/linear-gaussian.jl

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
"""
2+
Unit tests for the validity of the SMC algorithms included in this package.
3+
4+
We test each SMC algorithm on a one-dimensional linear Gaussian state space model for which
5+
an analytic filtering distribution can be computed using the Kalman filter provided by the
6+
`Kalman.jl` package.
7+
8+
The validity of the algorithm is tested by comparing the final estimated filtering
9+
distribution ground truth using a one-sided Kolmogorov-Smirnov test.
10+
"""
11+
12+
using DynamicIterators
13+
using GaussianDistributions
14+
using HypothesisTests
15+
using Kalman
16+
17+
function test_algorithm(rng, algorithm, model, N_SAMPLES, Xf)
18+
chains = sample(rng, model, algorithm, N_SAMPLES; progress=false)
19+
particles = hcat([chain.trajectory.model.X for chain in chains]...)
20+
final_particles = particles[:, end]
21+
22+
test = ExactOneSampleKSTest(final_particles, Normal(Xf.x[end].μ, sqrt(Xf.x[end].Σ)))
23+
return pvalue(test)
24+
end
25+
26+
@testset "linear-gaussian.jl" begin
27+
T = 3
28+
N_PARTICLES = 100
29+
N_SAMPLES = 50
30+
31+
# Model dynamics
32+
a = 0.5
33+
b = 0.2
34+
q = 0.1
35+
E = LinearEvolution(a, Gaussian(b, q))
36+
37+
H = 1.0
38+
R = 0.1
39+
Obs = LinearObservationModel(H, R)
40+
41+
x0 = 0.0
42+
P0 = 1.0
43+
G0 = Gaussian(x0, P0)
44+
45+
M = LinearStateSpaceModel(E, Obs)
46+
O = LinearObservation(E, H, R)
47+
48+
# Simulate from model
49+
rng = StableRNG(1234)
50+
initial = rand(rng, StateObs(G0, M.obs))
51+
trajectory = trace(DynamicIterators.Sampled(M, rng), 1 => initial, endtime(T))
52+
y_pairs = collect(t => y for (t, (x, y)) in pairs(trajectory))
53+
ys = [y for (t, (x, y)) in pairs(trajectory)]
54+
55+
# Ground truth smoothing
56+
Xf, ll = kalmanfilter(M, 1 => G0, y_pairs)
57+
58+
# Define AdvancedPS model
59+
mutable struct LinearGaussianParams
60+
a::Float64
61+
b::Float64
62+
q::Float64
63+
h::Float64
64+
r::Float64
65+
x0::Float64
66+
p0::Float64
67+
end
68+
69+
mutable struct LinearGaussianModel <: SSMProblems.AbstractStateSpaceModel
70+
X::Vector{Float64}
71+
observations::Vector{Float64}
72+
θ::LinearGaussianParams
73+
function LinearGaussianModel(y::Vector{Float64}, θ::LinearGaussianParams)
74+
return new(Vector{Float64}(), y, θ)
75+
end
76+
end
77+
78+
function SSMProblems.transition!!(rng::AbstractRNG, model::LinearGaussianModel)
79+
return rand(rng, Normal(model.θ.x0, model.θ.p0))
80+
end
81+
function SSMProblems.transition!!(
82+
rng::AbstractRNG, model::LinearGaussianModel, state, step
83+
)
84+
return rand(rng, Normal(model.θ.a * state + model.θ.b, model.θ.q))
85+
end
86+
function SSMProblems.transition_logdensity(
87+
model::LinearGaussianModel, prev_state, current_state, step
88+
)
89+
return logpdf(Normal(model.θ.a * prev_state + model.θ.b, model.θ.q), current_state)
90+
end
91+
function SSMProblems.emission_logdensity(model::LinearGaussianModel, state, step)
92+
return logpdf(Normal(model.θ.h * state, model.θ.r), model.observations[step])
93+
end
94+
95+
AdvancedPS.isdone(::LinearGaussianModel, step) = step > T
96+
97+
params = LinearGaussianParams(a, b, q, H, R, x0, P0)
98+
model = LinearGaussianModel(ys, params)
99+
100+
@testset "PGAS" begin
101+
pgas = AdvancedPS.PGAS(N_PARTICLES)
102+
p = test_algorithm(rng, pgas, model, N_SAMPLES, Xf)
103+
@test p > 0.05
104+
end
105+
106+
@testset "PG" begin
107+
pg = AdvancedPS.PG(N_PARTICLES)
108+
p = test_algorithm(rng, pg, model, N_SAMPLES, Xf)
109+
@test p > 0.05
110+
end
111+
end

test/runtests.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ using AbstractMCMC
33
using Distributions
44
using Libtask
55
using Random
6+
using StableRNGs
67
using Test
78
using SSMProblems
89

@@ -22,4 +23,7 @@ using SSMProblems
2223
@testset "PG-AS" begin
2324
include("pgas.jl")
2425
end
26+
@testset "Linear Gaussian SSM tests" begin
27+
include("linear-gaussian.jl")
28+
end
2529
end

0 commit comments

Comments
 (0)