Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions ext/DynamicPPLMCMCChainsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,6 @@ else
using ..MCMCChains: MCMCChains
end

_has_varname_to_symbol(info::NamedTuple{names}) where {names} = :varname_to_symbol in names
function _check_varname_indexing(c::MCMCChains.Chains)
return DynamicPPL.supports_varname_indexing(c) ||
error("Chains do not support indexing using $vn.")
end

# Load state from a `Chains`: By convention, it is stored in `:samplerstate` metadata
function DynamicPPL.loadstate(chain::MCMCChains.Chains)
if !haskey(chain.info, :samplerstate)
Expand All @@ -26,10 +20,17 @@ function DynamicPPL.loadstate(chain::MCMCChains.Chains)
return chain.info[:samplerstate]
end

# A few methods needed.
_has_varname_to_symbol(info::NamedTuple{names}) where {names} = :varname_to_symbol in names

function DynamicPPL.supports_varname_indexing(chain::MCMCChains.Chains)
return _has_varname_to_symbol(chain.info)
end

function _check_varname_indexing(c::MCMCChains.Chains)
return DynamicPPL.supports_varname_indexing(c) ||
error("Chains do not support indexing using `VarName`s.")
end

function DynamicPPL.getindex_varname(
c::MCMCChains.Chains, sample_idx, vn::DynamicPPL.VarName, chain_idx
)
Expand Down
Loading