diff --git a/Project.toml b/Project.toml index 2bde35e8..841571df 100644 --- a/Project.toml +++ b/Project.toml @@ -26,7 +26,7 @@ Random = "<0.0.1, 1" Random123 = "1.3" Requires = "1.0" StatsFuns = "0.9, 1" -SSMProblems = "0.5" +SSMProblems = "0.6" julia = "1.10.8" [extras] diff --git a/docs/Project.toml b/docs/Project.toml index fe1414d2..3454b6e8 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -2,6 +2,3 @@ AdvancedPS = "576499cb-2369-40b2-a588-c64705576edc" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306" - -[compat] -Documenter = "0.27" diff --git a/docs/make.jl b/docs/make.jl index c540c151..d5e40fc6 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -48,7 +48,7 @@ DocMeta.setdocmeta!(AdvancedPS, :DocTestSetup, :(using AdvancedPS); recursive=tr makedocs(; sitename="AdvancedPS", - format=Documenter.HTML(), + format=Documenter.HTML(; size_threshold=1000 * 2^11), # 1Mb per page modules=[AdvancedPS], pages=[ "Home" => "index.md", diff --git a/examples/gaussian-process/script.jl b/examples/gaussian-process/script.jl index 40957858..4bf031cb 100644 --- a/examples/gaussian-process/script.jl +++ b/examples/gaussian-process/script.jl @@ -8,59 +8,59 @@ using Distributions using Libtask using SSMProblems -struct GaussianProcessDynamics{T<:Real,KT<:Kernel} <: LatentDynamics{T,T} +struct GaussianProcessDynamics{T<:Real,KT<:Kernel} <: SSMProblems.LatentDynamics proc::GP{ZeroMean{T},KT} function GaussianProcessDynamics(::Type{T}, kernel::KT) where {T<:Real,KT<:Kernel} return new{T,KT}(GP(ZeroMean{T}(), kernel)) end end -struct LinearGaussianDynamics{T<:Real} <: LatentDynamics{T,T} - a::T - b::T - q::T +struct GaussianPrior{ΣT<:Real} <: SSMProblems.StatePrior + σ::ΣT end -function SSMProblems.distribution(proc::LinearGaussianDynamics{T}) where {T<:Real} - return Normal(zero(T), proc.q) +SSMProblems.distribution(proc::GaussianPrior) = Normal(0, proc.σ) + +struct LinearGaussianDynamics{AT<:Real,BT<:Real,QT<:Real} <: SSMProblems.LatentDynamics + a::AT + b::BT + q::QT end function SSMProblems.distribution(proc::LinearGaussianDynamics, ::Int, state) return Normal(proc.a * state + proc.b, proc.q) end -struct StochasticVolatility{T<:Real} <: ObservationProcess{T,T} end +struct StochasticVolatility <: SSMProblems.ObservationProcess end -function SSMProblems.distribution(::StochasticVolatility{T}, ::Int, state) where {T<:Real} - return Normal(zero(T), exp((1 / 2) * state)) +function SSMProblems.distribution(::StochasticVolatility, ::Int, state) + return Normal(0, exp(state / 2)) end -function LinearGaussianStochasticVolatilityModel(a::T, q::T) where {T<:Real} - dyn = LinearGaussianDynamics(a, zero(T), q) - obs = StochasticVolatility{T}() - return SSMProblems.StateSpaceModel(dyn, obs) +function LinearGaussianStochasticVolatilityModel(a, q) + prior = GaussianPrior(q) + dyn = LinearGaussianDynamics(a, 0, q) + obs = StochasticVolatility() + return SSMProblems.StateSpaceModel(prior, dyn, obs) end function GaussianProcessStateSpaceModel(::Type{T}, kernel::KT) where {T<:Real,KT<:Kernel} + prior = GaussianPrior(one(T)) dyn = GaussianProcessDynamics(T, kernel) - obs = StochasticVolatility{T}() - return SSMProblems.StateSpaceModel(dyn, obs) + obs = StochasticVolatility() + return SSMProblems.StateSpaceModel(prior, dyn, obs) end const GPSSM{T,KT<:Kernel} = SSMProblems.StateSpaceModel{ - T, - GaussianProcessDynamics{T,KT}, - StochasticVolatility{T} + <:GaussianPrior,<:GaussianProcessDynamics{T,KT},StochasticVolatility }; # for non-markovian models, we can redefine dynamics to reference the trajectory -function AdvancedPS.dynamics( - ssm::AdvancedPS.TracedSSM{<:GPSSM{T},T,T}, step::Int -) where {T<:Real} +function AdvancedPS.dynamics(ssm::AdvancedPS.TracedSSM{<:GPSSM}, step::Int) prior = ssm.model.dyn.proc(1:(step - 1)) - post = posterior(prior, ssm.X[1:(step - 1)]) - μ, σ = mean_and_cov(post, [step]) - return LinearGaussianDynamics(zero(T), μ[1], sqrt(σ[1])) + post = posterior(prior, ssm.X[1:(step - 1)]) + μ, σ = mean_and_cov(post, [step]) + return LinearGaussianDynamics(0, μ[1], sqrt(σ[1])) end # Everything is now ready to simulate some data. @@ -70,9 +70,9 @@ _, x, y = sample(rng, true_model, 100); # Create the model and run the sampler gpssm = GaussianProcessStateSpaceModel(Float64, SqExponentialKernel()); -model = gpssm(y); +model = AdvancedPS.TracedSSM(gpssm, y); pg = AdvancedPS.PGAS(20); -chains = sample(rng, model, pg, 250; progress=false); +chains = sample(rng, model, pg, 250); #md nothing #hide particles = hcat([chain.trajectory.model.X for chain in chains]...); diff --git a/examples/gaussian-ssm/script.jl b/examples/gaussian-ssm/script.jl index e8fd9ad3..097494c9 100644 --- a/examples/gaussian-ssm/script.jl +++ b/examples/gaussian-ssm/script.jl @@ -28,27 +28,31 @@ using SSMProblems # as well as the initial distribution $f_0(x) = \mathcal{N}(0, q^2/(1-a^2))$. # To use `AdvancedPS` we first need to define a model type that subtypes `AdvancedPS.AbstractStateSpaceModel`. -mutable struct Parameters{T<:Real} - a::T - q::T - r::T +mutable struct Parameters{AT<:Real,QT<:Real,RT<:Real} + a::AT + q::QT + r::RT end -struct LinearGaussianDynamics{T<:Real} <: SSMProblems.LatentDynamics{T,T} - a::T - q::T +struct GaussianPrior{ΣT<:Real} <: SSMProblems.StatePrior + σ::ΣT end -function SSMProblems.distribution(dyn::LinearGaussianDynamics{T}; kwargs...) where {T<:Real} - return Normal(zero(T), sqrt(dyn.q^2 / (1 - dyn.a^2))) +struct LinearGaussianDynamics{AT<:Real,QT<:Real} <: SSMProblems.LatentDynamics + a::AT + q::QT +end + +function SSMProblems.distribution(prior::GaussianPrior; kwargs...) + return Normal(0, prior.σ) end function SSMProblems.distribution(dyn::LinearGaussianDynamics, step::Int, state; kwargs...) return Normal(dyn.a * state, dyn.q) end -struct LinearGaussianObservation{T<:Real} <: SSMProblems.ObservationProcess{T,T} - r::T +struct LinearGaussianObservation{RT<:Real} <: SSMProblems.ObservationProcess + r::RT end function SSMProblems.distribution( @@ -58,9 +62,10 @@ function SSMProblems.distribution( end function LinearGaussianStateSpaceModel(θ::Parameters) + prior = GaussianPrior(sqrt(θ.q^2 / (1 - θ.a^2))) dyn = LinearGaussianDynamics(θ.a, θ.q) obs = LinearGaussianObservation(θ.r) - return SSMProblems.StateSpaceModel(dyn, obs) + return SSMProblems.StateSpaceModel(prior, dyn, obs) end # Everything is now ready to simulate some data. @@ -75,8 +80,9 @@ plot!(y; seriestype=:scatter, label="y", xlabel="t", mc=:red, ms=2, ma=0.5) # `AdvancedPS` subscribes to the `AbstractMCMC` API. To sample we just need to define a Particle Gibbs kernel # and a model interface. -pgas = AdvancedPS.PGAS(20) -chains = sample(rng, true_model(y), pgas, 500; progress=false); +N = 20 +pgas = AdvancedPS.PGAS(N) +chains = sample(rng, AdvancedPS.TracedSSM(true_model, y), pgas, 500; progress=false); #md nothing #hide # @@ -104,4 +110,4 @@ plot( xlabel="Iteration", ylabel="Update rate", ) -hline!([1 - 1 / length(chains)]; label="N: $(length(chains))") +hline!([1 - 1 / N]; label="N: $(N)") diff --git a/examples/levy-ssm/script.jl b/examples/levy-ssm/script.jl index e42a2eec..0ffa6199 100644 --- a/examples/levy-ssm/script.jl +++ b/examples/levy-ssm/script.jl @@ -27,17 +27,16 @@ function simulate( t = t0 truncated = last_jump < tolerance while !truncated - t += rand(rng, Exponential(one(T) / rate)) - xi = one(T) / (β * (exp(t / C) - one(T))) - prob = (one(T) + β * xi) * exp(-β * xi) + t += rand(rng, Exponential(1 / rate)) + xi = 1 / (β * (exp(t / C) - 1)) + prob = (1 + β * xi) * exp(-β * xi) if rand(rng) < prob push!(jumps, xi) last_jump = xi end truncated = last_jump < tolerance end - times = rand(rng, Uniform(start, finish), length(jumps)) - return GammaPath(jumps, times) + return GammaPath(jumps, rand(rng, Uniform(start, finish), length(jumps))) end end @@ -47,85 +46,66 @@ function integral(times::Array{<:Real}, path::GammaPath) end end -struct LangevinDynamics{T} - A::Matrix{T} - L::Vector{T} - θ::T - H::Vector{T} - σe::T +struct LangevinDynamics{AT<:AbstractMatrix,LT<:AbstractVector,θT<:Real} + A::AT + L::LT + θ::θT end -struct NormalMeanVariance{T} - μ::T - σ::T +function Base.exp(dyn::LangevinDynamics, dt) + f_val = exp(dyn.θ * dt) + return [1 (f_val - 1)/dyn.θ; 0 f_val] end -f(dt, θ) = exp(θ * dt) -function Base.exp(dyn::LangevinDynamics{T}, dt::T) where {T<:Real} - let θ = dyn.θ - f_val = f(dt, θ) - return [one(T) (f_val - 1)/θ; zero(T) f_val] - end +function meancov(t, dyn::LangevinDynamics, path::GammaPath, dist::Normal) + fts = exp.(Ref(dyn), (t .- path.times)) .* Ref(dyn.L) + μ = sum(@. fts * mean(dist) * path.jumps) + Σ = sum(@. fts * transpose(fts) * var(dist) * path.jumps) + return μ, Σ + eltype(Σ)(1e-6) * I end -function meancov( - t::T, dyn::LangevinDynamics, path::GammaPath, nvm::NormalMeanVariance -) where {T<:Real} - μ = zeros(T, 2) - Σ = zeros(T, (2, 2)) - let times = path.times, jumps = path.jumps, μw = nvm.μ, σw = nvm.σ - for (v, z) in zip(times, jumps) - ft = exp(dyn, (t - v)) * dyn.L - μ += ft .* μw .* z - Σ += ft * transpose(ft) .* σw^2 .* z - end - - # Guarantees positive semi-definiteness - return μ, Σ + T(1e-6) * I - end +struct LevyPrior{XT<:AbstractVector,ΣT<:AbstractMatrix} <: StatePrior + μ::XT + Σ::ΣT end -struct LevyLangevin{T} <: LatentDynamics{T,Vector{T}} - dt::T - dyn::LangevinDynamics{T} - process::GammaProcess{T} - nvm::NormalMeanVariance{T} -end +SSMProblems.distribution(proc::LevyPrior) = MvNormal(proc.μ, proc.Σ) -function SSMProblems.distribution(proc::LevyLangevin{T}) where {T<:Real} - return MultivariateNormal(zeros(T, 2), I) +struct LevyLangevin{T<:Real,LT<:LangevinDynamics,ΓT<:GammaProcess,DT<:Normal} <: + SSMProblems.LatentDynamics + dt::T + dyn::LT + process::ΓT + dist::DT end -function SSMProblems.distribution(proc::LevyLangevin{T}, step::Int, state) where {T<:Real} +function SSMProblems.distribution(proc::LevyLangevin, step::Int, state) dt = proc.dt path = simulate(rng, proc.process, dt, (step - 1) * dt, step * dt) - μ, Σ = meancov(step * dt, proc.dyn, path, proc.nvm) - return MultivariateNormal(exp(proc.dyn, dt) * state + μ, Σ) + μ, Σ = meancov(step * dt, proc.dyn, path, proc.dist) + return MvNormal(exp(proc.dyn, dt) * state + μ, Σ) end -struct LinearGaussianObservation{T<:Real} <: ObservationProcess{T,T} - H::Vector{T} - R::T +struct LinearGaussianObservation{HT<:AbstractVector,RT<:Real} <: + SSMProblems.ObservationProcess + H::HT + R::RT end -function SSMProblems.distribution(proc::LinearGaussianObservation, step::Int, state) +function SSMProblems.distribution(proc::LinearGaussianObservation, ::Int, state) return Normal(transpose(proc.H) * state, proc.R) end -function LevyModel(dt, θ, σe, C, β, μw, σw; ϵ=1e-10) - A = [0.0 1.0; 0.0 θ] - L = [0.0; 1.0] - H = [1.0, 0] - +function LevyModel(dt, θ, σe, C, β, μw, σw; kwargs...) dyn = LevyLangevin( dt, - LangevinDynamics(A, L, θ, H, σe), - GammaProcess(C, β; ϵ), - NormalMeanVariance(μw, σw), + LangevinDynamics([0 1; 0 θ], [0; 1], θ), + GammaProcess(C, β; kwargs...), + Normal(μw, σw), ) - obs = LinearGaussianObservation(H, σe) - return StateSpaceModel(dyn, obs) + obs = LinearGaussianObservation([1; 0], σe) + return SSMProblems.StateSpaceModel(LevyPrior(zeros(Bool, 2), I(2)), dyn, obs) end # Levy SSM with Langevin dynamics @@ -139,7 +119,7 @@ end # Simulation parameters N = 200 ts = range(0, 100; length=N) -levyssm = LevyModel(step(ts), θ, 1.0, 1.0, 1.0, 0.0, 1.0); +levyssm = LevyModel(step(ts), -0.5, 1, 1.0, 1.0, 0, 1); # Simulate data rng = Random.MersenneTwister(1234); @@ -147,10 +127,10 @@ _, X, Y = sample(rng, levyssm, N); # Run sampler pg = AdvancedPS.PGAS(50); -chains = sample(rng, levyssm(Y), pg, 100); +chains = sample(rng, AdvancedPS.TracedSSM(levyssm, Y), pg, 100; progress=false); # Concat all sampled states -marginal_states = hcat([chain.trajectory.model.X for chain in chains]...) +marginal_states = hcat([chain.trajectory.model.X for chain in chains]...); # Plot marginal state and jump intensities for one trajectory p1 = plot( @@ -166,7 +146,6 @@ plot!( label="Marginal State (x2)", ) -# TODO: collect jumps from the model p2 = scatter([], []; color=:darkorange, label="Jumps") plot( diff --git a/examples/particle-gibbs/script.jl b/examples/particle-gibbs/script.jl index 1eb782dc..b10647a8 100644 --- a/examples/particle-gibbs/script.jl +++ b/examples/particle-gibbs/script.jl @@ -52,29 +52,34 @@ end # ``` # with the initial distribution $f_0(x) = \mathcal{N}(0, q^2)$. # Here we assume the static parameters $\theta = (a^2, q^2)$ are known and we are only interested in sampling from the latent state $x_t$. -struct LinearGaussianDynamics{T<:Real} <: SSMProblems.LatentDynamics{T,T} - a::T - q::T +struct GaussianPrior{T<:Real} <: SSMProblems.StatePrior + σ::T end -function SSMProblems.distribution(dyn::LinearGaussianDynamics{T}) where {T<:Real} - return Normal(zero(T), dyn.q) +function SSMProblems.distribution(proc::GaussianPrior) + return Normal(0, proc.σ) +end + +struct LinearGaussianDynamics{AT<:Real,QT<:Real} <: SSMProblems.LatentDynamics + a::AT + q::QT end function SSMProblems.distribution(dyn::LinearGaussianDynamics, ::Int, state) return Normal(dyn.a * state, dyn.q) end -struct StochasticVolatility{T<:Real} <: SSMProblems.ObservationProcess{T,T} end +struct StochasticVolatility <: SSMProblems.ObservationProcess end -function SSMProblems.distribution(::StochasticVolatility{T}, ::Int, state) where {T<:Real} - return Normal(zero(T), exp((1 / 2) * state)) +function SSMProblems.distribution(::StochasticVolatility, ::Int, state) + return Normal(0, exp(state / 2)) end -function LinearGaussianStochasticVolatilityModel(a::T, q::T) where {T<:Real} +function LinearGaussianStochasticVolatilityModel(a, q) + prior = GaussianPrior(q) dyn = LinearGaussianDynamics(a, q) - obs = StochasticVolatility{T}() - return SSMProblems.StateSpaceModel(dyn, obs) + obs = StochasticVolatility() + return SSMProblems.StateSpaceModel(prior, dyn, obs) end #md nothing #hide @@ -90,7 +95,7 @@ plot(x; label="x", xlabel="t") plot(y; label="y", xlabel="t") # Here we use the particle gibbs kernel without adaptive resampling. -model = true_model(y) +model = AdvancedPS.TracedSSM(true_model, y) pg = AdvancedPS.PG(20, 1.0) chains = sample(rng, model, pg, 200; progress=false); #md nothing #hide diff --git a/src/model.jl b/src/model.jl index 1b89198d..2b76a714 100644 --- a/src/model.jl +++ b/src/model.jl @@ -10,21 +10,18 @@ const Particle = Trace const SSMTrace{R} = Trace{<:SSMProblems.AbstractStateSpaceModel,R} const GenericTrace{R} = Trace{<:AbstractGenericModel,R} -mutable struct TracedSSM{SSM,XT,YT} <: SSMProblems.AbstractStateSpaceModel +mutable struct TracedSSM{SSM} <: SSMProblems.AbstractStateSpaceModel model::SSM - X::Vector{XT} - Y::Vector{YT} - + X + Y function TracedSSM( - model::SSMProblems.StateSpaceModel{T,LD,OP}, Y::Vector{YT} - ) where {T,LD,OP,YT} - XT = eltype(LD) - @assert eltype(OP) == YT - return new{SSMProblems.StateSpaceModel{T,LD,OP},XT,YT}(model, Vector{XT}(), Y) + model::SSM, Y::AbstractVector + ) where {SSM<:SSMProblems.StateSpaceModel} + return new{SSM}(model, [], Y) end end -(model::SSMProblems.StateSpaceModel)(Y::AbstractVector) = TracedSSM(model, Y) +prior(ssm::TracedSSM) = ssm.model.prior dynamics(ssm::TracedSSM, step::Int) = ssm.model.dyn observation(ssm::TracedSSM, step::Int) = ssm.model.obs diff --git a/src/pgas.jl b/src/pgas.jl index c14bc736..5450e299 100644 --- a/src/pgas.jl +++ b/src/pgas.jl @@ -59,7 +59,7 @@ function advance!(particle::SSMTrace, isref::Bool=false) if !isref if running_step == 1 - new_state = SSMProblems.simulate(particle.rng, dynamics(model, running_step)) + new_state = SSMProblems.simulate(particle.rng, prior(model)) else current_state = model.X[running_step - 1] new_state = SSMProblems.simulate( @@ -76,7 +76,13 @@ function advance!(particle::SSMTrace, isref::Bool=false) ) # Accept transition and move the time index/rng counter - !isref && push!(model.X, new_state) + if !isref + if running_step == 1 + model.X = [new_state] + else + push!(model.X, new_state) + end + end inc_counter!(particle.rng) return score diff --git a/test/container.jl b/test/container.jl index 07308e21..545b379d 100644 --- a/test/container.jl +++ b/test/container.jl @@ -1,19 +1,20 @@ @testset "container.jl" begin # Since the extension would hide the low level function call API - struct LogPDynamics{T} <: LatentDynamics{T,T} end - struct LogPObservation{T} <: ObservationProcess{T,T} + struct LogPPrior <: StatePrior end + struct LogPDynamics <: LatentDynamics end + struct LogPObservation{T} <: ObservationProcess logp::T end SSMProblems.logdensity(proc::LogPObservation, ::Int, state, observation) = proc.logp - SSMProblems.distribution(proc::LogPDynamics, ::Int, state) = Uniform() - SSMProblems.distribution(::LogPDynamics) = Uniform() + SSMProblems.distribution(::LogPDynamics, ::Int, state) = Uniform() + SSMProblems.distribution(::LogPPrior) = Uniform() function LogPModel(logp::T) where {T<:Real} - ssm = StateSpaceModel(LogPDynamics{T}(), LogPObservation(logp)) + ssm = StateSpaceModel(LogPPrior(), LogPDynamics(), LogPObservation(logp)) # pick some arbitrarily large observables - return ssm(ones(T, 10)) + return AdvancedPS.TracedSSM(ssm, ones(T, 10)) end @testset "copy particle container" begin diff --git a/test/linear-gaussian.jl b/test/linear-gaussian.jl index 6409ab63..5e3fd4ce 100644 --- a/test/linear-gaussian.jl +++ b/test/linear-gaussian.jl @@ -56,14 +56,17 @@ end Xf, ll = kalmanfilter(M, 1 => G0, y_pairs) # Define AdvancedPS model - struct LinearGaussianDynamics{T<:Real} <: LatentDynamics{T,T} - a::T - b::T - q::T + struct GaussianPrior{XT,ΣT} <: StatePrior + μ::XT + σ::ΣT end - function SSMProblems.distribution(proc::LinearGaussianDynamics{T}; kwargs...) where {T} - return Normal(convert(T, X0), convert(T, P0)) + SSMProblems.distribution(proc::GaussianPrior; kwargs...) = Normal(proc.μ, proc.σ) + + struct LinearGaussianDynamics{AT<:Real,BT<:Real,QT<:Real} <: LatentDynamics + a::AT + b::BT + q::QT end function SSMProblems.distribution( @@ -72,9 +75,9 @@ end return Normal(proc.a * state + proc.b, proc.q) end - struct LinearGaussianObservation{T<:Real} <: ObservationProcess{T,T} - h::T - r::T + struct LinearGaussianObservation{HT<:Real,RT<:Real} <: ObservationProcess + h::HT + r::RT end function SSMProblems.distribution( @@ -83,14 +86,15 @@ end return Normal(proc.h * state, proc.r) end - function LinearGaussianStateSpaceModel(a, b, q, h, r) + function LinearGaussianStateSpaceModel(x0, σ0, a, b, q, h, r) + prior = GaussianPrior(x0, σ0) dyn = LinearGaussianDynamics(a, b, q) obs = LinearGaussianObservation(h, r) - return StateSpaceModel(dyn, obs) + return StateSpaceModel(prior, dyn, obs) end - lgssm = LinearGaussianStateSpaceModel(A, B, Q, H, R) - model = lgssm(ys) + lgssm = LinearGaussianStateSpaceModel(X0, P0, A, B, Q, H, R) + model = AdvancedPS.TracedSSM(lgssm, ys) @testset "PGAS" begin pgas = AdvancedPS.PGAS(N_PARTICLES) diff --git a/test/pgas.jl b/test/pgas.jl index 6a909ac1..bba6d28c 100644 --- a/test/pgas.jl +++ b/test/pgas.jl @@ -1,36 +1,42 @@ @testset "pgas.jl" begin - mutable struct Params{T<:Real} - a::T - q::T - r::T + mutable struct Params{AT<:Real,QT<:Real,RT<:Real} + a::AT + q::QT + r::RT end - struct BaseModelDynamics{T<:Real} <: LatentDynamics{T,T} - a::T - q::T + struct BaseModelPrior{QT<:Real} <: StatePrior + q::QT end - function SSMProblems.distribution(dyn::BaseModelDynamics{T}) where {T<:Real} - return Normal(zero(T), dyn.q) + struct BaseModelDynamics{AT<:Real,QT<:Real} <: LatentDynamics + a::AT + q::QT + end + + function SSMProblems.distribution(prior::BaseModelPrior) + return Normal(0, prior.q) end function SSMProblems.distribution(dyn::BaseModelDynamics, step::Int, state) return Normal(dyn.a * state, dyn.q) end - struct BaseModelObservation{T<:Real} <: ObservationProcess{T,T} - r::T + struct BaseModelObservation{RT<:Real} <: ObservationProcess + r::RT end function SSMProblems.distribution(obs::BaseModelObservation, step::Int, state) return Normal(state, obs.r) end - function BaseModel(θ::Params{T}) where {T<:Real} + function BaseModel(θ::Params{AT,QT,RT}) where {AT<:Real,QT<:Real,RT<:Real} + T = promote_type(AT, QT, RT) + prior = BaseModelPrior(θ.q) dyn = BaseModelDynamics(θ.a, θ.q) obs = BaseModelObservation(θ.r) - ssm = StateSpaceModel(dyn, obs) - return ssm(zeros(T, 3)) + ssm = StateSpaceModel(prior, dyn, obs) + return AdvancedPS.TracedSSM(ssm, zeros(T, 3)) end @testset "fork reference" begin