-
Notifications
You must be signed in to change notification settings - Fork 12
Open
Description
I'm wondering whether a model like the one below is possible. The basic problem is that the state isn't a single distribution but a collection of distributions, all of which evolve in a Markovian manner. I don't know exactly how this works internally, so the code assumes that the state is propagated forward from initialization to transition to observation.
# Unpack the design-matrix dimensions: rows are trials (row `t` is used for
# trial `t` in `obs_density`), columns are the predictors multiplied by `mu`.
# Assumes `X` has been defined before this line; it is not defined in this snippet.
(n_trials, n_cols) = size(X)
# Fixed (non-sampled) inputs of the model, bundled as one NamedTuple type so
# `PF` can store them in a single field:
#   X              - design matrix; row `t` holds the predictors of trial `t`
#   lam_lapse_init - centre of the initial lapse-logit distribution
#   sigma_set_init - centres of the three initial sigma distributions
#   mu_init        - centre of the initial `mu` distribution
#   n_trials       - number of trials (rows of X)
#   n_cols         - number of predictors (columns of X)
Parameters = NamedTuple{
    (:X, :lam_lapse_init, :sigma_set_init, :mu_init, :n_trials, :n_cols),
    Tuple{Matrix, Float64, Array{Float64}, Array{Float64}, Int64, Int64},
}
# State-space model handed to AdvancedPS. Holds the fixed inputs `theta`
# together with per-trial buffers, all zero-initialised to the sizes recorded
# in `theta`.
# NOTE(review): `W`, `lam_lapse` and `sigma_set` are not read or written by the
# model methods in this script - confirm they are still needed.
mutable struct PF <: AdvancedPS.AbstractStateSpaceModel
    # Concrete field types (the constructor always stores Float64 arrays) so
    # that field access is type-stable; `Matrix`/`Array` alone are abstract.
    W::Matrix{Float64}            # n_trials x n_cols
    lam_lapse::Vector{Float64}    # length n_trials
    sigma_set::Matrix{Float64}    # n_trials x n_cols
    theta::Parameters
    PF(theta::Parameters) = new(
        zeros(Float64, theta.n_trials, theta.n_cols),
        zeros(Float64, theta.n_trials),
        zeros(Float64, theta.n_trials, theta.n_cols),
        theta,
    )
end
"""
    init_step(m::PF)

Return the initial-state distributions as a tuple ordered
`(lam_lapse, sigma1, sigma2, sigma3, mu)`:

- `lam_lapse` ~ Normal(`lam_lapse_init`, 0.1), truncated below at -10
- `sigma_k`   ~ Normal(`sigma_set_init[k]`, 0.1), truncated below at 0, for k = 1..3
- `mu`        ~ MvNormal(`mu_init`, identity covariance)

NOTE(review): if AdvancedPS draws the initial state with `rand(rng, d)` on the
value returned by `initialization` (TODO confirm for the installed version),
`rand` on a plain `Tuple` returns one randomly chosen *element* of the tuple
instead of sampling every distribution. A joint state likely needs a single
distribution object, e.g. a product distribution over a flattened state vector.
"""
function init_step(m::PF)
    p = m.theta
    mu0 = p.mu_init
    return (
        truncated(Normal(p.lam_lapse_init, 0.1), lower=-10),
        truncated(Normal(p.sigma_set_init[1], 0.1), lower=0.),
        truncated(Normal(p.sigma_set_init[2], 0.1), lower=0.),
        truncated(Normal(p.sigma_set_init[3], 0.1), lower=0.),
        # Unit isotropic covariance written out explicitly: the former
        # `MvNormal(mu, 1.)` form uses the deprecated `MvNormal(μ, σ::Real)`
        # constructor. Same distribution, no deprecation warning.
        MvNormal(mu0, Diagonal(ones(length(mu0)))),
    )
end
# AdvancedPS hook: the first latent state is drawn from what `init_step` returns.
function AdvancedPS.initialization(m::PF)
    return init_step(m)
end
"""
    transition_step(m::PF, state)

Return the distributions of the next latent state given the previous `state`,
a collection ordered `(lam_lapse, sigma1, sigma2, sigma3, mu)`:

- `lam_lapse` and each `sigma_k` take a Normal random-walk step with sd 0.1,
  truncated below at -10 and 0 respectively;
- `mu` takes a Gaussian step whose covariance is `Diagonal([sigma1, sigma2, sigma3])`,
  built from the *previous* state's sigmas.

`m` is not used; it only selects this method.

NOTE(review): the `Diagonal` argument of `MvNormal` is a covariance, so the
sigmas act as variances rather than standard deviations - confirm intent.
It also fixes the length of `mu` to 3.
"""
function transition_step(m::PF, state)
    lam_prev = state[1]
    s1, s2, s3 = state[2], state[3], state[4]
    mu_prev = state[5]
    walk(centre, floor) = truncated(Normal(centre, 0.1), lower=floor)
    next_lam = walk(lam_prev, -10)
    next_s1 = walk(s1, 0.)
    next_s2 = walk(s2, 0.)
    next_s3 = walk(s3, 0.)
    next_mu = MvNormal(mu_prev, Diagonal([s1, s2, s3]))
    return (next_lam, next_s1, next_s2, next_s3, next_mu)
end
# AdvancedPS hook: distributions of the latent state at the next step.
# AdvancedPS's state-space-model examples define `transition(model, state, step)`
# with a step index (TODO confirm for the installed version). The 3-argument
# method forwards to the same dynamics; the step index is ignored because
# `transition_step` does not depend on time. The 2-argument method is kept for
# backward compatibility.
AdvancedPS.transition(m::PF, state) = transition_step(m, state)
AdvancedPS.transition(m::PF, state, t) = transition_step(m, state)
"""
    obs_density(m::PF, state, t)

Return the Bernoulli distribution of the binary response on trial `t`.

The success probability mixes a logistic regression on row `t` of the design
matrix with a fair coin flip, weighted by the lapse rate:

    lapse = logistic(lam_lapse)
    p     = (1 - lapse) * logistic(X[t, :]' * mu) + lapse * 0.5

Only the first (`lam_lapse`) and fifth (`mu`) entries of `state` are used.
"""
function obs_density(m::PF, state, t)
    lam_lapse, _, _, _, mu = state
    lapse = logistic(lam_lapse)
    # Take row `t` as a view rather than a copy; this is presumably evaluated
    # once per particle per step, so the avoided allocation adds up.
    xt = @view m.theta.X[t, :]
    prob = (1 - lapse) * logistic(xt' * mu) + lapse * 0.5
    return Bernoulli(prob)
end
# AdvancedPS hook: log-likelihood of the observed response on trial `t` under
# the latent `state`.
# NOTE(review): reads responses from a global `y` that is not defined in this
# snippet; a non-const global is also type-unstable. Consider storing the
# responses on the model instead.
AdvancedPS.observation(m::PF, state, t) = logpdf(obs_density(m, state, t), y[t])
# AdvancedPS hook: the sequence is finished once the step index passes the last trial.
function AdvancedPS.isdone(m::PF, t)
    return m.theta.n_trials < t
end
# Run particle Gibbs with ancestor sampling (AdvancedPS.PGAS) on the model.
n_particles = 20
n_samples = 200
rng = MersenneTwister(2342)
# `Parameters` is constructed positionally, so the tuple must list all six
# fields in declaration order:
#   (X, lam_lapse_init, sigma_set_init, mu_init, n_trials, n_cols)
# NOTE(review): `transition_step` gives `mu` a 3x3 covariance and `obs_density`
# multiplies a row of `X` by `mu`, so this model requires n_cols == 3.
theta0 = Parameters((
    X,
    -9.0,       # lam_lapse_init
    zeros(3),   # sigma_set_init
    zeros(3),   # mu_init
    n_trials,
    n_cols,
))
model = PF(theta0)
pgas = AdvancedPS.PGAS(n_particles)
chains = sample(rng, model, pgas, n_samples; progress=true);
Metadata
Assignees
Labels
No labels