Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 26 additions & 8 deletions Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -549,10 +549,10 @@ uuid = "c87230d0-a227-11e9-1b43-d7ebe4e7570a"
version = "0.4.5"

[[deps.FFMPEG_jll]]
deps = ["Artifacts", "Bzip2_jll", "FreeType2_jll", "FriBidi_jll", "JLLWrappers", "LAME_jll", "Libdl", "Ogg_jll", "OpenSSL_jll", "Opus_jll", "PCRE2_jll", "Zlib_jll", "libaom_jll", "libass_jll", "libfdk_aac_jll", "libvorbis_jll", "x264_jll", "x265_jll"]
git-tree-sha1 = "01ba9d15e9eae375dc1eb9589df76b3572acd3f2"
deps = ["Artifacts", "Bzip2_jll", "FreeType2_jll", "FriBidi_jll", "JLLWrappers", "LAME_jll", "Libdl", "Ogg_jll", "OpenSSL_jll", "Opus_jll", "PCRE2_jll", "Zlib_jll", "libaom_jll", "libass_jll", "libfdk_aac_jll", "libva_jll", "libvorbis_jll", "x264_jll", "x265_jll"]
git-tree-sha1 = "66381d7059b5f3f6162f28831854008040a4e905"
uuid = "b22a6f82-2f65-5046-a5b2-351ab43fb4e5"
version = "8.0.1+0"
version = "8.0.1+1"

[[deps.FFTA]]
deps = ["AbstractFFTs", "DocStringExtensions", "LinearAlgebra", "MuladdMacro", "Primes", "Random", "Reexport"]
Expand Down Expand Up @@ -1445,9 +1445,9 @@ version = "1.3.1"

[[deps.Revise]]
deps = ["CodeTracking", "FileWatching", "InteractiveUtils", "JuliaInterpreter", "LibGit2", "LoweredCodeUtils", "OrderedCollections", "Preferences", "REPL", "UUIDs"]
git-tree-sha1 = "14d1bfb0a30317edc77e11094607ace3c800f193"
git-tree-sha1 = "d97d78d4fc5f858d8ce44f6b88bc972f2023f51d"
uuid = "295af30f-e4ad-537b-8983-00126c2a3abe"
version = "3.13.2"
version = "3.14.0"
weakdeps = ["Distributed"]

[deps.Revise.extensions]
Expand All @@ -1473,7 +1473,7 @@ version = "0.7.0"
deps = ["BenchmarkTools", "CSV", "Distributions", "ForwardDiff", "LaTeXStrings", "LineSearches", "LinearAlgebra", "Literate", "Measures", "Optim", "PDMats", "PProf", "Plots", "Profile", "Random", "Revise", "Statistics", "Utils"]
path = "."
uuid = "2328efba-24c2-4a10-a32f-a74f69d05fca"
version = "1.0.1"
version = "1.0.0"

[[deps.Scratch]]
deps = ["Dates"]
Expand Down Expand Up @@ -1552,9 +1552,9 @@ version = "1.0.4"

[[deps.StaticArrays]]
deps = ["LinearAlgebra", "PrecompileTools", "Random", "StaticArraysCore"]
git-tree-sha1 = "0f529006004a8be48f1be25f3451186579392d47"
git-tree-sha1 = "246a8bb2e6667f832eea063c3a56aef96429a3db"
uuid = "90137ffa-7385-5640-81b9-e52037218182"
version = "1.9.17"
version = "1.9.18"
weakdeps = ["ChainRulesCore", "Statistics"]

[deps.StaticArrays.extensions]
Expand Down Expand Up @@ -1854,6 +1854,12 @@ git-tree-sha1 = "7ed9347888fac59a618302ee38216dd0379c480d"
uuid = "ea2f1a96-1ddc-540d-b46f-429655e07cfa"
version = "0.9.12+0"

[[deps.Xorg_libpciaccess_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Zlib_jll"]
git-tree-sha1 = "4909eb8f1cbf6bd4b1c30dd18b2ead9019ef2fad"
uuid = "a65dc6b1-eb27-53a1-bb3e-dea574b5389e"
version = "0.18.1+0"

[[deps.Xorg_libxcb_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Xorg_libXau_jll", "Xorg_libXdmcp_jll"]
git-tree-sha1 = "bfcaf7ec088eaba362093393fe11aa141fa15422"
Expand Down Expand Up @@ -1966,6 +1972,12 @@ git-tree-sha1 = "9bf7903af251d2050b467f76bdbe57ce541f7f4f"
uuid = "1183f4f0-6f2a-5f1a-908b-139f9cdfea6f"
version = "0.2.2+0"

[[deps.libdrm_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Xorg_libpciaccess_jll"]
git-tree-sha1 = "63aac0bcb0b582e11bad965cef4a689905456c03"
uuid = "8e53e030-5e6c-5a89-a30b-be5b7263a166"
version = "2.4.125+1"

[[deps.libevdev_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl"]
git-tree-sha1 = "56d643b57b188d30cccc25e331d416d3d358e557"
Expand All @@ -1990,6 +2002,12 @@ git-tree-sha1 = "e015f211ebb898c8180887012b938f3851e719ac"
uuid = "b53b4c65-9356-5827-b1ea-8c7a1a84506f"
version = "1.6.55+0"

[[deps.libva_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Xorg_libX11_jll", "Xorg_libXext_jll", "Xorg_libXfixes_jll", "libdrm_jll"]
git-tree-sha1 = "7dbf96baae3310fe2fa0df0ccbb3c6288d5816c9"
uuid = "9a156e7d-b971-5f62-b2c9-67348b8fb97c"
version = "2.23.0+0"

[[deps.libvorbis_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Ogg_jll"]
git-tree-sha1 = "11e1772e7f3cc987e9d3de991dd4f6b2602663a5"
Expand Down
9 changes: 6 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,15 @@

## Description

Julia implementation of some posterior samplers for state-space models with general non-linear/non-Gaussian observation models and linear (heteroscedastic) transition models. Some example scripts can be found in the `examples` folder.
Julia implementation of some posterior samplers for state-space models with general non-linear/non-Gaussian observation models and linear (heteroscedastic) transition models. Some example scripts can be found in the `examples` folder, and in the Examples section of the documentation. See below of a simple PGAS example.

## Installation
Install from the Julia package manager (via Github) by typing `]` in the Julia REPL:
The package is in the [CompBayesRegistry](https://github.com/compbayes/CompBayesRegistry), which must first be added to your Julia. The package can then be installed by the usual `add` mechanism in the Julia Package manager.

Install from the Julia package manager by typing `]` in the Julia REPL, followed by
```
] add git@github.com:compbayes/SMCsamplers.jl.git
registry add https://github.com/compbayes/CompBayesRegistry.git
add SMCsamplers
```

## Example
Expand Down
9 changes: 5 additions & 4 deletions docs/src/FFBS.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
```@docs
FFBS
FFBSx
FFBS_unscented
FFBS_laplace
FFBS!
FFBSx!
FFBS_unscented!
FFBS_SLR!
FFBS_laplace!
```
38 changes: 28 additions & 10 deletions examples/AR/Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -549,10 +549,10 @@ uuid = "c87230d0-a227-11e9-1b43-d7ebe4e7570a"
version = "0.4.5"

[[deps.FFMPEG_jll]]
deps = ["Artifacts", "Bzip2_jll", "FreeType2_jll", "FriBidi_jll", "JLLWrappers", "LAME_jll", "Libdl", "Ogg_jll", "OpenSSL_jll", "Opus_jll", "PCRE2_jll", "Zlib_jll", "libaom_jll", "libass_jll", "libfdk_aac_jll", "libvorbis_jll", "x264_jll", "x265_jll"]
git-tree-sha1 = "01ba9d15e9eae375dc1eb9589df76b3572acd3f2"
deps = ["Artifacts", "Bzip2_jll", "FreeType2_jll", "FriBidi_jll", "JLLWrappers", "LAME_jll", "Libdl", "Ogg_jll", "OpenSSL_jll", "Opus_jll", "PCRE2_jll", "Zlib_jll", "libaom_jll", "libass_jll", "libfdk_aac_jll", "libva_jll", "libvorbis_jll", "x264_jll", "x265_jll"]
git-tree-sha1 = "66381d7059b5f3f6162f28831854008040a4e905"
uuid = "b22a6f82-2f65-5046-a5b2-351ab43fb4e5"
version = "8.0.1+0"
version = "8.0.1+1"

[[deps.FFTA]]
deps = ["AbstractFFTs", "DocStringExtensions", "LinearAlgebra", "MuladdMacro", "Primes", "Random", "Reexport"]
Expand Down Expand Up @@ -727,9 +727,9 @@ version = "1.0.2"

[[deps.HTTP]]
deps = ["Base64", "CodecZlib", "ConcurrentUtilities", "Dates", "ExceptionUnwrapping", "Logging", "LoggingExtras", "MbedTLS", "NetworkOptions", "OpenSSL", "PrecompileTools", "Random", "SimpleBufferStream", "Sockets", "URIs", "UUIDs"]
git-tree-sha1 = "5e6fe50ae7f23d171f44e311c2960294aaa0beb5"
git-tree-sha1 = "51059d23c8bb67911a2e6fd5130229113735fc7e"
uuid = "cd3eb016-35fb-5094-929b-558a96fad6f3"
version = "1.10.19"
version = "1.11.0"

[[deps.HarfBuzz_jll]]
deps = ["Artifacts", "Cairo_jll", "Fontconfig_jll", "FreeType2_jll", "Glib_jll", "Graphite2_jll", "JLLWrappers", "Libdl", "Libffi_jll"]
Expand Down Expand Up @@ -1473,7 +1473,7 @@ version = "0.7.0"
deps = ["BenchmarkTools", "CSV", "Distributions", "ForwardDiff", "LaTeXStrings", "LineSearches", "LinearAlgebra", "Literate", "Measures", "Optim", "PDMats", "PProf", "Plots", "Profile", "Random", "Revise", "Statistics", "Utils"]
path = "/home/mv/.julia/dev/SMCsamplers/docs/.."
uuid = "2328efba-24c2-4a10-a32f-a74f69d05fca"
version = "1.0.0-DEV"
version = "1.0.0"

[[deps.Scratch]]
deps = ["Dates"]
Expand Down Expand Up @@ -1552,9 +1552,9 @@ version = "1.0.4"

[[deps.StaticArrays]]
deps = ["LinearAlgebra", "PrecompileTools", "Random", "StaticArraysCore"]
git-tree-sha1 = "0f529006004a8be48f1be25f3451186579392d47"
git-tree-sha1 = "246a8bb2e6667f832eea063c3a56aef96429a3db"
uuid = "90137ffa-7385-5640-81b9-e52037218182"
version = "1.9.17"
version = "1.9.18"
weakdeps = ["ChainRulesCore", "Statistics"]

[deps.StaticArrays.extensions]
Expand Down Expand Up @@ -1737,9 +1737,9 @@ version = "0.2.0"

[[deps.Utils]]
deps = ["AdvancedMH", "ColorSchemes", "DataFrames", "Distances", "Distributions", "ForwardDiff", "KernelDensity", "LaTeXStrings", "LinearAlgebra", "Loess", "LogExpFunctions", "Optim", "PDMats", "Plots", "ProgressMeter", "Random", "SpecialFunctions", "Statistics", "StatsPlots"]
git-tree-sha1 = "5e80f00b61919281f2e0cc8b6193f6952856c6fb"
git-tree-sha1 = "db691c7685339e2d3bebe4d22b800c8c10a5d930"
uuid = "a6a4a4df-ba9f-45e9-bafc-bec3296762b9"
version = "0.1.1"
version = "1.0.1"

[[deps.Vulkan_Loader_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Wayland_jll", "Xorg_libX11_jll", "Xorg_libXrandr_jll", "xkbcommon_jll"]
Expand Down Expand Up @@ -1854,6 +1854,12 @@ git-tree-sha1 = "7ed9347888fac59a618302ee38216dd0379c480d"
uuid = "ea2f1a96-1ddc-540d-b46f-429655e07cfa"
version = "0.9.12+0"

[[deps.Xorg_libpciaccess_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Zlib_jll"]
git-tree-sha1 = "4909eb8f1cbf6bd4b1c30dd18b2ead9019ef2fad"
uuid = "a65dc6b1-eb27-53a1-bb3e-dea574b5389e"
version = "0.18.1+0"

[[deps.Xorg_libxcb_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Xorg_libXau_jll", "Xorg_libXdmcp_jll"]
git-tree-sha1 = "bfcaf7ec088eaba362093393fe11aa141fa15422"
Expand Down Expand Up @@ -1966,6 +1972,12 @@ git-tree-sha1 = "9bf7903af251d2050b467f76bdbe57ce541f7f4f"
uuid = "1183f4f0-6f2a-5f1a-908b-139f9cdfea6f"
version = "0.2.2+0"

[[deps.libdrm_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Xorg_libpciaccess_jll"]
git-tree-sha1 = "63aac0bcb0b582e11bad965cef4a689905456c03"
uuid = "8e53e030-5e6c-5a89-a30b-be5b7263a166"
version = "2.4.125+1"

[[deps.libevdev_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl"]
git-tree-sha1 = "56d643b57b188d30cccc25e331d416d3d358e557"
Expand All @@ -1990,6 +2002,12 @@ git-tree-sha1 = "e015f211ebb898c8180887012b938f3851e719ac"
uuid = "b53b4c65-9356-5827-b1ea-8c7a1a84506f"
version = "1.6.55+0"

[[deps.libva_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Xorg_libX11_jll", "Xorg_libXext_jll", "Xorg_libXfixes_jll", "libdrm_jll"]
git-tree-sha1 = "7dbf96baae3310fe2fa0df0ccbb3c6288d5816c9"
uuid = "9a156e7d-b971-5f62-b2c9-67348b8fb97c"
version = "2.23.0+0"

[[deps.libvorbis_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Ogg_jll"]
git-tree-sha1 = "11e1772e7f3cc987e9d3de991dd4f6b2602663a5"
Expand Down
23 changes: 15 additions & 8 deletions examples/AR/script.jl
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,9 @@ Cargs = [Z[t,:] for t in 1:T];
∂C(state, z) = ForwardDiff.gradient(state -> C(state, z), state)';

# ### FFBS posterior sampling using the Extended Kalman filter (EKF)
EKFdraws, μ_filterEKF, Σ_filterEKF = FFBSx(U, Y, A, B, C, ∂C, Cargs, Σₑ, Σₙ, μ₀, Σ₀, nSim;
filter_output = true);
EKFdraws = zeros(T + sample_t0, nState, nSim)
μ_filterEKF, Σ_filterEKF = FFBSx!(EKFdraws, U, Y, A, B, C, ∂C, Cargs, Σₑ, Σₙ,
μ₀, Σ₀, nSim; filter_output = true);
EKFdraws = restr.(EKFdraws) # Apply the restriction to the draws
EKFmedian = median(EKFdraws, dims = 3)[:,:,1];
EKFquantiles = quantile_multidim(EKFdraws, [0.025, 0.975], dims = 3);
Expand All @@ -166,7 +167,8 @@ plot(plt..., layout = (1,2), size = (800, 300), ylims = (-1.7,1.7), xlabel = "ti

# ### FFBS posterior sampling using the Unscented Kalman filter (UKF)
α = 1; β = 0; κ = 0;
UKFdraws = FFBS_unscented(U, Y, A, B, C, Cargs, Σₑ, Σₙ, μ₀, Σ₀, nSim;
UKFdraws = zeros(T + sample_t0, nState, nSim)
FFBS_unscented!(UKFdraws, U, Y, A, B, C, Cargs, Σₑ, Σₙ, μ₀, Σ₀, nSim;
α = α, β = β, κ = κ);
UKFdraws = restr.(UKFdraws) # Apply the restriction to the draws
UKFmedian = median(UKFdraws, dims = 3)[:,:,1]
Expand All @@ -188,8 +190,9 @@ plotIEKF = true
if plotIEKF
maxIter = 10
tol = 10^-4 # tolerance for convergence
IEKFdraws, μ_filterIEKF, Σ_filterIEKF = FFBSx(U, Y, A, B, C, ∂C, Cargs, Σₑ, Σₙ, μ₀, Σ₀,
nSim, maxIter, tol; filter_output = true);
IEKFdraws = zeros(T + sample_t0, nState, nSim)
μ_filterIEKF, Σ_filterIEKF = FFBSx!(IEKFdraws, U, Y, A, B, C, ∂C, Cargs, Σₑ, Σₙ,
μ₀, Σ₀, nSim, maxIter, tol; filter_output = true);
IEKFdraws = restr.(IEKFdraws) # Apply the restriction to the draws
IEKFmedian = median(IEKFdraws, dims = 3)[:,:,1];
IEKFquantiles = quantile_multidim(IEKFdraws, [0.025, 0.975], dims = 3);
Expand All @@ -211,7 +214,8 @@ if plotIEKFL
linesearch = true
maxIter = 10
tol = 10^-4 # tolerance for convergence
IEKFLdraws, μ_filterIEKFL, Σ_filterIEKFL = FFBSx(U, Y, A, B, C, ∂C, Cargs, Σₑ, Σₙ,
IEKFLdraws = zeros(T + sample_t0, nState, nSim)
μ_filterIEKFL, Σ_filterIEKFL = FFBSx!(IEKFLdraws, U, Y, A, B, C, ∂C, Cargs, Σₑ, Σₙ,
μ₀, Σ₀, nSim, maxIter, tol, linesearch; filter_output = true);
IEKFLdraws = restr.(IEKFLdraws) # Apply the restriction to the draws
IEKFLmedian = median(IEKFLdraws, dims = 3)[:,:,1];
Expand All @@ -233,8 +237,10 @@ plotLaplace = true
if plotLaplace
maxIter = 10
tol = 10^-4 # tolerance for convergence
LaplaceDraws, μ_filterLaplace, Σ_filterLaplace = FFBS_laplace(U, Y, A, B, Σₙ, μ₀, Σ₀,
observation, θ, nSim; filter_output = true);
nFailure = Ref(0) # Persistent counter that gets updated when failures occur
LaplaceDraws = zeros(T + sample_t0, nState, nSim)
μ_filterLaplace, Σ_filterLaplace = FFBS_laplace!(LaplaceDraws, U, Y, A, B, Σₙ,
μ₀, Σ₀, observation, θ, nSim; filter_output = true, nFailure = nFailure);
LaplaceDraws = restr.(LaplaceDraws) # Apply the restriction to the draws
Laplacemedian = median(LaplaceDraws, dims = 3)[:,:,1];
Laplacequantiles = quantile_multidim(LaplaceDraws, [0.025, 0.975], dims = 3);
Expand All @@ -249,3 +255,4 @@ if plotLaplace
plot(plt..., layout = (1,2), size = (1400, 600), ylims = (-1.7,1.7), xlabel = "time",
bottommargin = 5mm)
end
println("Laplace failed at $(100*nFailure[]/nSim)% of the simulated trajectories")
4 changes: 3 additions & 1 deletion examples/LGSS/script.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,9 @@ A = θ.a
C = 1
B = 0
U = zeros(T,1);
FFBSdraws = FFBS(U, y, A, B, C, Σₑ, Σₙ, μ₀, Σ₀, nSim; sample_t0 = sample_t0);
nState = length(μ₀)
FFBSdraws = zeros(T + sample_t0, nState, nSim);
FFBS!(FFBSdraws, U, y, A, B, C, Σₑ, Σₙ, μ₀, Σ₀, nSim; sample_t0 = sample_t0);
FFBSmean = mean(FFBSdraws, dims = 3)
FFBSquantiles = quantile_multidim(FFBSdraws, [0.025, 0.975], dims = 3);

Expand Down
15 changes: 10 additions & 5 deletions examples/SAR/script.jl
Original file line number Diff line number Diff line change
Expand Up @@ -160,8 +160,10 @@ Cargs = [Z[t,:] for t in 1:T];
∂C(state, z) = ForwardDiff.gradient(state -> C(state, z), state)';

# ### FFBS posterior sampling using the Extended Kalman filter (EKF)
EKFdraws, μ_filterEKF, Σ_filterEKF = FFBSx(U, Y, A, B, C, ∂C, Cargs, Σₑ, Σₙ, μ₀, Σ₀, nSim;
filter_output = true, sample_t0 = sample_t0);

EKFdraws = zeros(T + sample_t0, nState, nSim);
μ_filterEKF, Σ_filterEKF = FFBSx!(EKFdraws, U, Y, A, B, C, ∂C, Cargs, Σₑ, Σₙ,
μ₀, Σ₀, nSim; filter_output = true, sample_t0 = sample_t0);
EKFdraws = restr.(EKFdraws) # Apply the restriction to the draws
EKFmedian = median(EKFdraws, dims = 3)[:,:,1];
EKFquantiles = quantile_multidim(EKFdraws, [0.025, 0.975], dims = 3);
Expand All @@ -178,7 +180,8 @@ plot(plt..., layout = (1,2), size = (800, 300), ylims = (-1.7,1.7), xlabel = "ti

# ### FFBS posterior sampling using the Unscented Kalman filter (UKF)
α = 1; β = 0; κ = 0;
UKFdraws = FFBS_unscented(U, Y, A, B, C, Cargs, Σₑ, Σₙ, μ₀, Σ₀, nSim;
UKFdraws = zeros(T + sample_t0, nState, nSim);
FFBS_unscented!(UKFdraws, U, Y, A, B, C, Cargs, Σₑ, Σₙ, μ₀, Σ₀, nSim;
α = α, β = β, κ = κ, sample_t0 = sample_t0);
UKFdraws = restr.(UKFdraws) # Apply the restriction to the draws
UKFmedian = median(UKFdraws, dims = 3)[:,:,1]
Expand All @@ -199,7 +202,8 @@ plotIEKF = true
if plotIEKF
maxIter = 10
tol = 10^-4 # tolerance for convergence
IEKFdraws, μ_filterIEKF, Σ_filterIEKF = FFBSx(U, Y, A, B, C, ∂C, Cargs, Σₑ, Σₙ, μ₀, Σ₀,
IEKFdraws = zeros(T + sample_t0, nState, nSim);
FFBSx!(IEKFdraws, U, Y, A, B, C, ∂C, Cargs, Σₑ, Σₙ, μ₀, Σ₀,
nSim, maxIter, tol; filter_output = true, sample_t0 = sample_t0);
IEKFdraws = restr.(IEKFdraws) # Apply the restriction to the draws
IEKFmedian = median(IEKFdraws, dims = 3)[:,:,1];
Expand All @@ -222,7 +226,8 @@ if plotIEKFL
linesearch = true
maxIter = 10
tol = 10^-4 # tolerance for convergence
IEKFLdraws, μ_filterIEKFL, Σ_filterIEKFL = FFBSx(U, Y, A, B, C, ∂C, Cargs, Σₑ, Σₙ,
IEKFLdraws = zeros(T + sample_t0, nState, nSim);
FFBSx!(IEKFLdraws, U, Y, A, B, C, ∂C, Cargs, Σₑ, Σₙ,
μ₀, Σ₀, nSim, maxIter, tol, linesearch; filter_output = true,
sample_t0 = sample_t0);
IEKFLdraws = restr.(IEKFLdraws) # Apply the restriction to the draws
Expand Down
5 changes: 3 additions & 2 deletions examples/TvReg/script.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ function SimTVReg(T, p, σₑ, Σₙ, Σ₀ = Σₙ)
return y, Z, β
end

p = 2 # State size - number of β parameters, including intercept
p = nState = 2 # State size - number of β parameters, including intercept
T = 200 # Number of observations
σₑ = 1
Σₙ = PDMat([1 0;0 0.1]) # State innovation covariance matrix
Expand Down Expand Up @@ -85,7 +85,8 @@ end
B = 0.0
U = zeros(T,1)

FFBSdraws = FFBS(U, y, A, B, C, Σₑ, Σₙ, μ₀, Σ₀, Nₛ);
FFBSdraws = zeros(T + 1, nState, Nₛ);
FFBS!(FFBSdraws, U, y, A, B, C, Σₑ, Σₙ, μ₀, Σ₀, Nₛ);
FFBSmean = mean(FFBSdraws, dims = 3)[2:end,:,1] # Exclude initial state at t=0
FFBSquantiles = quantile_multidim(FFBSdraws, [0.025, 0.975], dims = 3)[2:end,:,:];

Expand Down
Loading
Loading