Skip to content

unflatten can't handle non-float eltypes #842

@torfjelde

Description

@torfjelde

This issue was posted in Slack:

using Random, Distributions, Turing

mutable struct TestSampler <: AbstractMCMC.AbstractSampler
end

struct Transition{T}
    θ::AbstractVector{T}
end

struct SamplerState{T<:Transition}
    transition::T
end

function AbstractMCMC.step(
    rng::Random.AbstractRNG,
    model::AbstractMCMC.LogDensityModel,
    spl::TestSampler;
    kwargs...
)
    theta_init = rand(rng, Bernoulli(0.5), 5)

    transition = Transition(theta_init)
    return transition, SamplerState(transition)
end

function AbstractMCMC.step(
    rng::Random.AbstractRNG,
    model::AbstractMCMC.LogDensityModel,
    sampler::TestSampler,
    state::SamplerState;
    kwargs...
)
    theta_proposal = rand(rng, Bernoulli(0.5), 5)

    transition = Transition(theta_proposal)
    return transition, SamplerState(transition)
end


@model function test_model(y_obs)
    y ~ filldist(Bernoulli(0.2), length(y_obs))

    for i in eachindex(y_obs)
        y_obs[i] ~ y[i] ? Normal(1, 0.1) : Normal(0, 0.1)
    end
end

y_obs = rand(Bernoulli(0.2), 5) .+ rand(Normal(0, 0.1), 5)
model = test_model(y_obs)

sampler = externalsampler(TestSampler(), unconstrained=false)

chain = sample(model, sampler, 100)

with stacktrace

ERROR: InexactError: Bool(-87.3902668081344)
Stacktrace:
  [1] Bool
    @ ./float.jl:251 [inlined]
  [2] convert
    @ ./number.jl:7 [inlined]
  [3] RefValue
    @ ./refvalue.jl:8 [inlined]
  [4] unflatten
    @ ~/.julia/packages/DynamicPPL/senfM/src/varinfo.jl:225 [inlined]
  [5] unflatten(vi::DynamicPPL.TypedVarInfo{@NamedTuple{y::DynamicPPL.Metadata{…}}, Float64}, x::Vector{Bool})
    @ DynamicPPL ~/.julia/packages/DynamicPPL/senfM/src/varinfo.jl:220
  [6] transition_to_turing(f::LogDensityFunction{…}, transition::Transition{…})
    @ Turing.Inference ~/.julia/packages/Turing/oFGEb/src/mcmc/abstractmcmc.jl:11
  [7] transition_to_turing(f::LogDensityProblemsADForwardDiffExt.ForwardDiffLogDensity{…}, transition::Transition{…})
    @ Turing.Inference ~/.julia/packages/Turing/oFGEb/src/mcmc/abstractmcmc.jl:17
  [8] step(rng::TaskLocalRNG, model::DynamicPPL.Model{…}, sampler_wrapper::DynamicPPL.Sampler{…}; initial_state::Nothing, initial_params::Nothing, kwargs::@Kwargs{})
    @ Turing.Inference ~/.julia/packages/Turing/oFGEb/src/mcmc/abstractmcmc.jl:103
  [9] step
    @ ~/.julia/packages/Turing/oFGEb/src/mcmc/abstractmcmc.jl:58 [inlined]
 [10] macro expansion
    @ ~/.julia/packages/AbstractMCMC/FSyVk/src/sample.jl:159 [inlined]
 [11] macro expansion
    @ ~/.julia/packages/ProgressLogging/6KXlp/src/ProgressLogging.jl:328 [inlined]
 [12] macro expansion
    @ ~/.julia/packages/AbstractMCMC/FSyVk/src/logging.jl:9 [inlined]
 [13] mcmcsample(rng::TaskLocalRNG, model::DynamicPPL.Model{…}, sampler::DynamicPPL.Sampler{…}, N::Int64; progress::Bool, progressname::String, callback::Nothing, num_warmup::Int64, discard_initial::Int64, thinning::Int64, chain_type::Type, initial_state::Nothing, kwargs::@Kwargs{})
    @ AbstractMCMC ~/.julia/packages/AbstractMCMC/FSyVk/src/sample.jl:142
 [14] sample(rng::TaskLocalRNG, model::DynamicPPL.Model{…}, sampler::DynamicPPL.Sampler{…}, N::Int64; chain_type::Type, resume_from::Nothing, initial_state::Nothing, kwargs::@Kwargs{})
    @ DynamicPPL ~/.julia/packages/DynamicPPL/senfM/src/sampler.jl:107
 [15] sample
    @ ~/.julia/packages/DynamicPPL/senfM/src/sampler.jl:97 [inlined]
 [16] #sample#6
    @ ~/.julia/packages/Turing/oFGEb/src/mcmc/Inference.jl:321 [inlined]
 [17] sample
    @ ~/.julia/packages/Turing/oFGEb/src/mcmc/Inference.jl:312 [inlined]
 [18] #sample#5
    @ ~/.julia/packages/Turing/oFGEb/src/mcmc/Inference.jl:309 [inlined]
 [19] sample(model::DynamicPPL.Model{…}, alg::Turing.Inference.ExternalSampler{…}, N::Int64)
    @ Turing.Inference ~/.julia/packages/Turing/oFGEb/src/mcmc/Inference.jl:306
 [20] top-level scope
    @ ~/nectar-source/seromix/src/mh_sampler.jl:62

This is caused by the following line in DynamicPPL's `unflatten` (src/varinfo.jl:225, see frame [4] of the stacktrace above):

return VarInfo(md, Base.RefValue{eltype(x)}(getlogp(vi)), Ref(get_num_produce(vi)))

where we use `eltype(x)` to construct the type of the `logp` field. We do this — instead of extracting the type from the existing `logp` field — so that dual-number types from packages like ForwardDiff.jl can flow through. The downside is that when `x` has a non-float element type (e.g. `Vector{Bool}` from the Bernoulli proposal above), the stored log-probability cannot be converted to `eltype(x)`, producing the `InexactError` shown.

Fix incoming.

Ref: https://julialang.slack.com/archives/CCYDC34A0/p1741825508922639?thread_ts=1741825446.029949&cid=CCYDC34A0

Metadata

Metadata

Assignees

No one assigned

    Labels

    bug — Something isn't working

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions