Skip to content

Commit 9631ceb

Browse files
made sure that pgas can sample t=0 properly. Updated example scripts
1 parent 2923a3f commit 9631ceb

File tree

21 files changed

+985
-952
lines changed

21 files changed

+985
-952
lines changed

.github/workflows/CI.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ jobs:
2424
matrix:
2525
version:
2626
- '1.11'
27-
- 'pre'
2827
os:
2928
- ubuntu-latest
3029
arch:

docs/src/examples/LGSS.md

Lines changed: 264 additions & 265 deletions
Large diffs are not rendered by default.

docs/src/examples/Poisson.md

Lines changed: 18 additions & 13 deletions
Large diffs are not rendered by default.

docs/src/examples/SAR.md

Lines changed: 34 additions & 29 deletions
Large diffs are not rendered by default.

docs/src/examples/SV.md

Lines changed: 262 additions & 262 deletions
Large diffs are not rendered by default.

docs/src/examples/TvReg.md

Lines changed: 84 additions & 85 deletions
Large diffs are not rendered by default.

docs/src/examples/TvRegHeteroInnov.md

Lines changed: 194 additions & 194 deletions
Large diffs are not rendered by default.

examples/LGSS/script.jl

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ gr(legend = :topleft, grid = false, color = colors[2], lw = 2, legendfontsize=8,
2727
xtickfontsize=8, ytickfontsize=8, xguidefontsize=8, yguidefontsize=8,
2828
titlefontsize = 10, markerstrokecolor = :auto)
2929

30-
myquantile(A, p; dims, kwargs...) = mapslices(x -> quantile(x, p; kwargs...), A; dims)
3130
Random.seed!(123);
3231

3332
# ### Set up the state-space model
@@ -38,7 +37,7 @@ mutable struct LGSSParams
3837
σₑ::Float64
3938
end
4039

41-
prior(θ) = Normal(0, θ.σᵥ / ((1 - θ.a^2)));
40+
prior(θ) = Normal(0, 10*θ.σᵥ / ((1 - θ.a^2)));
4241
transition(θ, state, t) = Normal.a * state, θ.σᵥ);
4342
observation(θ, state, t) = Normal(state, θ.σₑ);
4443

@@ -65,25 +64,26 @@ plot!(y; seriestype=:scatter, label="observed, y", xlabel="t", markersize = 2,
6564
color = colors[1], markerstrokecolor = :auto)
6665

6766
# ### PGAS sampling
68-
Nₚ = 20 # Number of particles for PGAS
69-
Nₛ = 1000; # Number of samples from posterior
70-
PGASdraws = PGASsampler(y, θ, Nₛ, Nₚ, prior, transition, observation)
67+
Nₚ = 20 # Number of particles for PGAS
68+
Nₛ = 10000; # Number of samples from posterior
69+
sample_t0 = true # Sample state at t=0 ?
70+
PGASdraws = PGASsampler(y, θ, Nₛ, Nₚ, prior, transition, observation; sample_t0 = sample_t0)
7171
PGASmean = mean(PGASdraws, dims = 3)[:,:,1]
72-
PGASquantiles = myquantile(PGASdraws, [0.025, 0.975], dims = 3);
72+
PGASquantiles = quantile_multidim(PGASdraws, [0.025, 0.975], dims = 3);
7373

7474
# ### FFBS Sampling
7575
## Set up the LGSS for FFBS and sample
7676
Σₑ =.σₑ^2]
7777
Σₙ =.σᵥ^2]
7878
μ₀ = [0;;]
79-
Σ₀ =.σᵥ^2/(1-θ.a^2);;]
79+
Σ₀ = 10^2*.σᵥ^2/(1-θ.a^2);;]
8080
A = θ.a
8181
C = 1
8282
B = 0
8383
U = zeros(T,1);
84-
FFBSdraws = FFBS(U, y, A, B, C, Σₑ, Σₙ, μ₀, Σ₀, Nₛ);
85-
FFBSmean = mean(FFBSdraws, dims = 3)[2:end,:,1] # Exclude initial state at t=0
86-
FFBSquantiles = myquantile(FFBSdraws, [0.025, 0.975], dims = 3)[2:end,:,:];
84+
FFBSdraws = FFBS(U, y, A, B, C, Σₑ, Σₙ, μ₀, Σ₀, Nₛ; sample_t0 = sample_t0);
85+
FFBSmean = mean(FFBSdraws, dims = 3)
86+
FFBSquantiles = quantile_multidim(FFBSdraws, [0.025, 0.975], dims = 3);
8787

8888
# ### Plot the posterior mean and 95% C.I. intervals from both algorithms
8989
plottrue = true
@@ -93,7 +93,7 @@ for j in 1:p
9393

9494
#True state evolution
9595
if plottrue
96-
plt_tmp = plot(x, c = colors[3], lw = 1, label = "true state")
96+
plt_tmp = plot([NaN*ones(sample_t0);x], c = colors[3], lw = 1, label = "true state")
9797
else
9898
plt_tmp = plot()
9999
end

examples/Poisson/script.jl

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ gr(legend = :topleft, grid = false, color = colors[2], lw = 2, legendfontsize=8,
2424
xtickfontsize=8, ytickfontsize=8, xguidefontsize=8, yguidefontsize=8,
2525
titlefontsize = 10, markerstrokecolor = :auto)
2626

27-
myquantile(A, p; dims, kwargs...) = mapslices(x -> quantile(x, p; kwargs...), A; dims)
2827
Random.seed!(123);
2928

3029
# ### Set up Poisson model
@@ -79,7 +78,7 @@ function plotEvolDistributions!(plt, postDraws, quantiles, trueEvol = nothing,
7978
label = nothing, shaded = false; plotSettings...)
8079

8180
T, nState, nSim = size(postDraws)
82-
postquantiles = myquantile(postDraws, quantiles, dims = 3);
81+
postquantiles = quantile_multidim(postDraws, quantiles, dims = 3);
8382

8483
if !isnothing(trueEvol)
8584
plot!(plt, 1:T, trueEvol, lw = 1, c = :black, label = "true"; plotSettings...)
@@ -122,17 +121,20 @@ plt
122121
# ### PGAS sampling
123122
nSim = 1000; # Number of samples from posterior
124123
nParticles = 100 # Number of particles for PGAS
125-
PGASdraws = PGASsampler(y, θ, nSim, nParticles, prior, transition, observation);
124+
sample_t0 = true # Sample state at t=0 ?
125+
PGASdraws = PGASsampler(y, θ, nSim, nParticles, prior, transition, observation;
126+
sample_t0 = sample_t0);
126127

127128
# Plot the true evolution and the posterior distributions from PGAS
128129
quantiles = [0.025, 0.5, 0.975]
129130
pltLogIntensity = plot(; title = "Log Intensity "*L" \log\lambda_t = x_t")
130-
plotEvolDistributions!(pltLogIntensity, PGASdraws, quantiles, x,
131+
plotEvolDistributions!(pltLogIntensity, PGASdraws, quantiles, [NaN*ones(sample_t0);x],
131132
"PGAS(N=$nParticles)", true; color = colors[3], lw = 1, legend = :topright)
132133

133134
pltIntensity = plot(; title = "Intensity "*L" \lambda_t = \exp(x_t)")
134-
plotEvolDistributions!(pltIntensity, exp.(PGASdraws), quantiles, exp.(x),
135-
"PGAS(N=$nParticles)", true; color = colors[3], lw = 1, legend = :topright)
135+
plotEvolDistributions!(pltIntensity, exp.(PGASdraws), quantiles,
136+
exp.([NaN*ones(sample_t0);x]), "PGAS(N=$nParticles)", true; color = colors[3],
137+
lw = 1, legend = :topright)
136138

137139
plot(pltLogIntensity, pltIntensity, layout = (1,2), size = (800, 300), bottommargin = 5mm)
138140

@@ -147,7 +149,8 @@ B = θ.μ*(1-θ.a) # The transition model is x_t = μ + a (x_{t-1} - μ) + η_t,
147149
U = ones(T,1);
148150

149151
# Simulate from the Laplace approximation of the posterior distribution
150-
LaplaceDraws = FFBS_laplace(U, y, A, B, Σₙ, μ₀, Σ₀, observation, θ, nSim);
152+
LaplaceDraws = FFBS_laplace(U, y, A, B, Σₙ, μ₀, Σ₀, observation, θ, nSim;
153+
sample_t0 = sample_t0);
151154

152155
# Plot the PGAS and Laplace posterior distributions, this time without true evolution
153156
pltLogIntensity = plot(; title = "Log Intensity "*L" \log\lambda_t = x_t")
@@ -179,10 +182,12 @@ x, y, plt = simTvPoisson(observation, transition, prior, θ, T);
179182
plt
180183

181184
# Simulate from the PGAS posterior distribution
182-
PGASdraws = PGASsampler(y, θ, nSim, nParticles, prior, transition, observation);
185+
PGASdraws = PGASsampler(y, θ, nSim, nParticles, prior, transition, observation;
186+
sample_t0 = sample_t0);
183187

184188
# Simulate from the Laplace approximation of the posterior distribution
185-
LaplaceDraws = FFBS_laplace(U, y, A, B, Σₙ, μ₀, Σ₀, observation, θ, nSim);
189+
LaplaceDraws = FFBS_laplace(U, y, A, B, Σₙ, μ₀, Σ₀, observation, θ, nSim;
190+
sample_t0 = sample_t0);
186191

187192
# Plot the PGAS and Laplace posterior distributions, without true evolution
188193
pltLogIntensity = plot(; title = "Log Intensity "*L" \log\lambda_t = x_t")

examples/SAR/Manifest.toml

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
julia_version = "1.11.6"
44
manifest_format = "2.0"
5-
project_hash = "69971822c41d250d3341c5f32e3985a0a89a97bb"
5+
project_hash = "1442dc65b9337561130506a3e2dc47606ca8686b"
66

77
[[deps.ADTypes]]
88
git-tree-sha1 = "7927b9af540ee964cc5d1b73293f1eb0b761a3a1"
@@ -297,11 +297,6 @@ version = "0.7.4"
297297
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
298298
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
299299

300-
[[deps.DisplayAs]]
301-
git-tree-sha1 = "43c017d5dd3a48d56486055973f443f8a39bb6d9"
302-
uuid = "0b91fe84-8a4c-11e9-3e1d-67c38462b6d6"
303-
version = "0.1.6"
304-
305300
[[deps.Distributed]]
306301
deps = ["Random", "Serialization", "Sockets"]
307302
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"

0 commit comments

Comments
 (0)