diff --git a/Project.toml b/Project.toml
index 2e1b5e3d3a..e7068c140a 100644
--- a/Project.toml
+++ b/Project.toml
@@ -18,6 +18,7 @@ DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
 DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"
 EllipticalSliceSampling = "cad2338a-1db2-11e9-3401-43bc07c9ede2"
 ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
+InferenceObjects = "b5cf5a8d-e756-4ee3-b014-01d49d192c00"
 Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
@@ -53,6 +54,7 @@ DynamicHMC = "3.4"
 DynamicPPL = "0.23"
 EllipticalSliceSampling = "0.5, 1"
 ForwardDiff = "0.10.3"
+InferenceObjects = "0.2"
 Libtask = "0.7, 0.8"
 LogDensityProblems = "2"
 LogDensityProblemsAD = "1.4"
diff --git a/src/inference/Inference.jl b/src/inference/Inference.jl
index 73f190dcc6..bd7bc4e2b6 100644
--- a/src/inference/Inference.jl
+++ b/src/inference/Inference.jl
@@ -35,6 +35,7 @@ import LogDensityProblems
 import LogDensityProblemsAD
 import Random
 import MCMCChains
+using InferenceObjects: InferenceObjects
 import StatsBase: predict
 
 export InferenceAlgorithm,
@@ -70,6 +71,15 @@ export InferenceAlgorithm,
     isgibbscomponent,
     externalsampler
 
+const turing_inferencedata_key_map = (  # stat name => replacement used by `_rename_sample_stats`
+    hamiltonian_energy = :energy,
+    hamiltonian_energy_error = :energy_error,
+    is_adapt = :tune,
+    max_hamiltonian_energy_error = :max_energy_error,
+    nom_step_size = :step_size_nom,
+    numerical_error = :diverging,
+)
+
 #######################
 # Sampler abstraction #
 #######################
@@ -450,6 +460,81 @@ end
 
 DynamicPPL.loadstate(chain::MCMCChains.Chains) = chain.info[:samplerstate]
 
+# Default InferenceObjects constructor: bundles `ts` via `convert_to_inference_data`.
+# This is type piracy! NOTE(review): the extended function belongs to AbstractMCMC — revisit.
+function AbstractMCMC.bundle_samples(
+    ts::Vector,  # items to bundle; each is read via `getparams` and `metadata`
+    model::AbstractModel,  # only stored in attrs when `save_state` is true
+    spl::Union{Sampler{<:InferenceAlgorithm},SampleFromPrior},
+    state,  # passed to `getlogevidence`; stored in attrs when `save_state` is true
+    chain_type::Type{InferenceObjects.InferenceData};  # selects this method; not read below
+    group = spl isa SampleFromPrior ? :prior : :posterior,  # group receiving the draws
+    save_state = false,  # when true, keep model, sampler and sampler state in attrs
+    stats = missing,  # if given, its `.start` and `.stop` are stored in attrs
+    dims=(;),  # forwarded unchanged to `convert_to_inference_data`
+    coords=(;),  # forwarded unchanged to `convert_to_inference_data`
+    kwargs...  # accepted but not used by this method
+)
+    sample = map(t -> map(v -> length(v[1]) == 1 ? v[1][1] : v[1], getparams(t)), ts)  # length-1 values are unwrapped
+    sample_stats = map(_rename_sample_stats ∘ metadata, ts)  # per-item stats with renamed keys
+
+    # Set up the attributes dictionary that is attached to the result.
+    attrs = OrderedDict{String,Any}()
+    if save_state
+        attrs["model"] = model
+        attrs["sampler"] = spl
+        attrs["samplerstate"] = state
+    end
+
+    # Merge in the timing info, if available.
+    if !ismissing(stats)
+        attrs["start_time"] = stats.start
+        attrs["stop_time"] = stats.stop
+    end
+
+    # Record the log evidence returned by `getlogevidence`, unless it is missing.
+    le = getlogevidence(ts, spl, state)
+    if !ismissing(le)
+        attrs["log_evidence"] = le
+    end
+
+    # Stats go to :sample_stats_prior when `group` is :prior, otherwise to :sample_stats.
+    sample_stats_group = group === :prior ? :sample_stats_prior : :sample_stats
+
+    # InferenceData construction. NOTE(review): `[...]` presumably marks one chain — confirm.
+    idata = InferenceObjects.convert_to_inference_data(
+        [sample];
+        group=group,
+        sample_stats_group => [sample_stats],  # keyword whose name is chosen at runtime
+        attrs=attrs,
+        dims=dims,
+        coords=coords,
+    )
+    return idata
+end
+
+function AbstractMCMC.chainsstack(c::AbstractVector{<:InferenceObjects.InferenceData})
+    nchains = length(c)
+    nchains == 1 && return c[1]  # a single element is returned as is
+    groups = map(keys(first(c))) do k  # group names are taken from the first element only
+        k => AbstractMCMC.chainsstack(map(idata -> idata[k], c))  # stack group `k` across `c`
+    end
+    return InferenceObjects.InferenceData(; groups...)
+end
+function AbstractMCMC.chainsstack(c::AbstractVector{<:InferenceObjects.Dataset})
+    nchains = length(c)
+    nchains == 1 && return c[1]  # a single element is returned as is
+    # TODO: gather our metadata into vectors instead of replacing
+    group = cat(c...; dims=:chain)  # concatenate along the `:chain` dimension
+    # Give each chain a different index: `:chain` is set to `Base.OneTo(nchains)`.
+    return InferenceObjects.DimensionalData.set(group, :chain => Base.OneTo(nchains))
+end
+
+function _rename_sample_stats(stats::NamedTuple)
+    new_keys = map(k -> get(turing_inferencedata_key_map, k, k), keys(stats))  # unmapped keys are kept
+    return NamedTuple{new_keys}(values(stats))  # same values, same order, new names
+end
+
 #######################################
 # Concrete algorithm implementations. #
 #######################################