PGAS example where state consists of a tuple of distributions? #82

@YSanchezAraujo

I'm wondering whether a model like the one below is possible. The difficulty is that the state isn't a single distribution but a collection of distributions (a lapse logit, three noise scales, and a weight vector `mu`), all of which evolve in a Markovian manner. I don't know exactly how this works internally, so the code assumes that the state is propagated forward from initialization to transition to observation.
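One thing worth checking before relying on a plain `Tuple` as the state: in Julia, `rand` applied to a tuple treats it as a collection and returns one randomly chosen element, not one draw from each component. If the sampler draws states with something like `rand(rng, d)` (an assumption about AdvancedPS internals, not verified here), a tuple of distributions would not behave like a joint distribution. A small demonstration with two made-up components:

```julia
using Distributions, Random

rng = MersenneTwister(2342)
components = (Normal(0.0, 1.0), Normal(5.0, 1.0))  # illustrative components only

# `rand` on a Tuple picks one element of the collection:
# the result is a distribution object, not a sample
picked = rand(rng, components)

# Drawing one value per component has to be done explicitly, e.g. with `map`
draws = map(c -> rand(rng, c), components)
```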

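```julia
using AdvancedPS, Distributions, LinearAlgebra, Random
using LogExpFunctions: logistic

# X (n_trials × n_cols design matrix) and y (vector of 0/1 responses) must already be defined
n_trials, n_cols = size(X)

Parameters = @NamedTuple begin
    X::Matrix
    lam_lapse_init::Float64
    sigma_set_init::Array{Float64}
    mu_init::Array{Float64}
    n_trials::Int64
    n_cols::Int64
end

mutable struct PF <: AdvancedPS.AbstractStateSpaceModel
    W::Matrix           # per-trial storage (not referenced by the functions below)
    lam_lapse::Array    # per-trial storage (not referenced by the functions below)
    sigma_set::Matrix   # per-trial storage (not referenced by the functions below)
    theta::Parameters
    PF(theta::Parameters) = new(
        zeros(Float64, theta.n_trials, theta.n_cols),
        zeros(Float64, theta.n_trials),
        zeros(Float64, theta.n_trials, theta.n_cols),
        theta,
    )
end
```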


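```julia
# Initial distribution of each state component, returned as a plain Tuple
function init_step(m::PF)
    return (
        truncated(Normal(m.theta.lam_lapse_init, 0.1), lower=-10),    # lam_lapse
        truncated(Normal(m.theta.sigma_set_init[1], 0.1), lower=0.),  # sigma1
        truncated(Normal(m.theta.sigma_set_init[2], 0.1), lower=0.),  # sigma2
        truncated(Normal(m.theta.sigma_set_init[3], 0.1), lower=0.),  # sigma3
        MvNormal(m.theta.mu_init, I),  # mu; MvNormal(μ, 1.) is deprecated, I = identity covariance
    )
end

AdvancedPS.initialization(m::PF) = init_step(m)
```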
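A minimal sketch of one way to package such a state as a single object: a hypothetical `TupleDist` wrapper (not part of AdvancedPS or Distributions) that draws one value per component and scores a state as the sum of the component log-densities. This is only valid when the components are independent given the previous state, which holds for `init_step` and also for `transition_step` (where `mu` uses the *previous* sigmas). Whether AdvancedPS needs anything beyond `rand` and `logpdf` from the returned object is an assumption not checked here.

```julia
using Distributions, LinearAlgebra, Random

# Hypothetical helper: joint distribution over a tuple-valued state whose
# components are independent (given whatever parameters built them)
struct TupleDist{T<:Tuple}
    components::T
end

# One draw from every component, returned as a tuple with the same layout
Random.rand(rng::AbstractRNG, d::TupleDist) = map(c -> rand(rng, c), d.components)

# Independence means the joint log-density is the sum of the parts
Distributions.logpdf(d::TupleDist, x::Tuple) = sum(map(logpdf, d.components, x))

# Same components as init_step with lam_lapse_init = -9 and zero sigma/mu inits
init = TupleDist((
    truncated(Normal(-9.0, 0.1); lower = -10),
    truncated(Normal(0.0, 0.1); lower = 0.0),
    truncated(Normal(0.0, 0.1); lower = 0.0),
    truncated(Normal(0.0, 0.1); lower = 0.0),
    MvNormal(zeros(3), I),
))

rng = MersenneTwister(2342)
x0 = rand(rng, init)    # (lam_lapse, sigma1, sigma2, sigma3, mu)
lp = logpdf(init, x0)   # joint log-density of that draw
```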

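```julia
# Markov transition of each component, conditioned on the previous state
function transition_step(m::PF, state)
    return (
        truncated(Normal(state[1], 0.1), lower=-10), # lam_lapse
        truncated(Normal(state[2], 0.1), lower=0.),  # sigma1
        truncated(Normal(state[3], 0.1), lower=0.),  # sigma2
        truncated(Normal(state[4], 0.1), lower=0.),  # sigma3
        # mu: Diagonal(...) is used as a covariance, so the sigmas act as variances
        MvNormal(state[5], Diagonal([state[2], state[3], state[4]])),
    )
end

# The state-space interface passes the current step as a third argument
AdvancedPS.transition(m::PF, state, step) = transition_step(m, state)
```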
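Read off the code, with $\lambda$ for `lam_lapse`, `transition_step` specifies the following transition density. `Normal(m, 0.1)` has standard deviation 0.1, and `MvNormal` interprets `Diagonal([...])` as a covariance matrix, so the $\sigma_k$ enter the weight update as variances taken from the previous step:

```math
\begin{aligned}
\lambda_t \mid \lambda_{t-1} &\sim \mathcal{N}\big(\lambda_{t-1},\, 0.1^2\big) \text{ truncated to } [-10, \infty), \\
\sigma_{k,t} \mid \sigma_{k,t-1} &\sim \mathcal{N}\big(\sigma_{k,t-1},\, 0.1^2\big) \text{ truncated to } [0, \infty), \quad k = 1, 2, 3, \\
\mu_t \mid \mu_{t-1}, \sigma_{\cdot,t-1} &\sim \mathcal{N}\big(\mu_{t-1},\, \operatorname{diag}(\sigma_{1,t-1}, \sigma_{2,t-1}, \sigma_{3,t-1})\big).
\end{aligned}
```

Given the previous state, the joint transition density is therefore the product of these five terms.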

```julia
# Lapse-mixture Bernoulli likelihood for trial t
function obs_density(m::PF, state, t)
    lam_lapse, _, _, _, mu = state
    lapse = logistic(lam_lapse)
    prob = (1 - lapse) * logistic(m.theta.X[t, :]' * mu) + lapse * 0.5
    return Bernoulli(prob)
end

# Log-likelihood of the observed response y[t] (y is a global vector of 0/1 responses)
function AdvancedPS.observation(m::PF, state, t)
    return logpdf(obs_density(m, state, t), y[t])
end

AdvancedPS.isdone(m::PF, t) = t > m.theta.n_trials
```
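The observation model is a lapse mixture: with probability $\ell = \operatorname{logistic}(\lambda)$ the response is a coin flip, otherwise it follows $\operatorname{logistic}(x_t^\top \mu)$, so $p_t = (1-\ell)\,\operatorname{logistic}(x_t^\top\mu) + \ell/2$. The stdlib-only sketch below computes the same per-trial log-likelihood so it can be sanity-checked in isolation; `lapse_prob` and `lapse_loglik` are hypothetical names, and `logistic` is written out by hand instead of taken from LogExpFunctions.

```julia
using LinearAlgebra: dot

# Plain logistic function (LogExpFunctions.logistic is a more careful version)
logistic(z) = 1 / (1 + exp(-z))

# P(y_t = 1): with probability `lapse` the answer is a coin flip,
# otherwise it follows logistic(x' * mu)
function lapse_prob(x::AbstractVector, mu::AbstractVector, lam_lapse::Real)
    lapse = logistic(lam_lapse)
    return (1 - lapse) * logistic(dot(x, mu)) + lapse * 0.5
end

# Bernoulli log-likelihood of a single response y ∈ {0, 1}
function lapse_loglik(x, mu, lam_lapse, y::Integer)
    p = lapse_prob(x, mu, lam_lapse)
    return y == 1 ? log(p) : log1p(-p)
end

# With mu = 0 the choice probability is 1/2 whatever the lapse rate
p0 = lapse_prob([1.0, 2.0, 3.0], zeros(3), -9.0)
```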

```julia
n_particles = 20
n_samples = 200
rng = MersenneTwister(2342)

# Positional construction must supply all six fields, including X
theta0 = Parameters((X, -9.0, zeros(3), zeros(3), n_trials, n_cols))

model = PF(theta0)
pgas = AdvancedPS.PGAS(n_particles)
chains = sample(rng, model, pgas, n_samples; progress=true);
```
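The snippet assumes a design matrix `X` and a response vector `y` that exist before it runs. For a self-contained test, synthetic data could be generated first with a hypothetical helper like the one below. It uses fixed, made-up weights and lapse logit (a simplification of the drifting state in the model), only the standard library, and three columns because the model hard-codes three sigmas.

```julia
using LinearAlgebra: dot
using Random

# Hypothetical data generator: X is n_trials × 3, y holds 0/1 responses drawn
# from the lapse model with fixed (made-up) weights mu_true and lapse logit lam_true
function simulate_data(rng::AbstractRNG; n_trials::Int = 200,
                       mu_true = [1.0, -0.5, 0.25], lam_true = -3.0)
    n_cols = length(mu_true)
    X = randn(rng, n_trials, n_cols)
    lapse = 1 / (1 + exp(-lam_true))
    y = Vector{Int}(undef, n_trials)
    for t in 1:n_trials
        z = dot(X[t, :], mu_true)
        p = (1 - lapse) / (1 + exp(-z)) + lapse * 0.5  # same mixture as obs_density
        y[t] = rand(rng) < p ? 1 : 0
    end
    return X, y
end

X, y = simulate_data(MersenneTwister(2342))
```

With these in scope, `n_trials, n_cols = size(X)` gives `(200, 3)`.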
