Skip to content

Commit 48f83f6

Browse files
committed
Find an edge case
1 parent ff48fae commit 48f83f6

File tree

3 files changed

+8
-7
lines changed

3 files changed

+8
-7
lines changed

ext/DynamicPPLMCMCChainsExt.jl

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ end
4444

4545
function chain_sample_to_varname_dict(c::MCMCChains.Chains, sample_idx, chain_idx)
4646
_check_varname_indexing(c)
47-
d = Dict{VarName}()
47+
d = Dict{DynamicPPL.VarName,Any}()
4848
for vn in DynamicPPL.varnames(c)
4949
d[vn] = DynamicPPL.getindex_varname(c, sample_idx, vn, chain_idx)
5050
end
@@ -271,10 +271,7 @@ function DynamicPPL.returned(model::DynamicPPL.Model, chain_full::MCMCChains.Cha
271271
# return the model's retval (`first`).
272272
first(
273273
DynamicPPL.init!!(
274-
rng,
275-
model,
276-
varinfo,
277-
DynamicPPL.ParamsInit(values_dict, DynamicPPL.PriorInit()),
274+
model, varinfo, DynamicPPL.ParamsInit(values_dict, DynamicPPL.PriorInit())
278275
),
279276
)
280277
end

src/simple_varinfo.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -463,7 +463,6 @@ end
463463

464464
# Context implementations
465465

466-
# NOTE: We don't implement `settrans!!(vi, trans, vn)`.
467466
function settrans!!(vi::SimpleVarInfo, trans)
468467
return settrans!!(vi, trans ? DynamicTransformation() : NoTransformation())
469468
end
@@ -473,6 +472,9 @@ end
473472
function settrans!!(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, trans)
474473
return Accessors.@set vi.varinfo = settrans!!(vi.varinfo, trans)
475474
end
475+
function settrans!!(::SimpleOrThreadSafeSimple, trans::Bool, vn::VarName)
476+
@info "Attempting to call `settrans!!` on a `SimpleVarInfo` for a specific variable `$vn`; this will be ignored"
477+
end
476478

477479
istrans(vi::SimpleVarInfo) = !(vi.transformation isa NoTransformation)
478480
istrans(vi::SimpleVarInfo, ::VarName) = istrans(vi)

test/test_util.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,10 @@ function make_chain_from_prior(rng::Random.AbstractRNG, model::Model, n_iters::I
8181
varnames = collect(varnames)
8282
# Construct matrix of values
8383
vals = [get(dict, vn, missing) for dict in dicts, vn in varnames]
84+
# Construct mapping of varnames to symbols
85+
vns_to_syms = Dict{VarName,Symbol}(zip(varnames, Symbol.(varnames)))
8486
# Construct and return the Chains object
85-
return Chains(vals, varnames)
87+
return Chains(vals, varnames; info=(varname_to_symbol=vns_to_syms,))
8688
end
8789
function make_chain_from_prior(model::Model, n_iters::Int)
8890
return make_chain_from_prior(Random.default_rng(), model, n_iters)

0 commit comments

Comments
 (0)