Description
This would allow a Turing user to pass chain_type=InferenceData to, e.g., AbstractMCMC.sample and get the draws automatically packed into an InferenceData with suitable groups. I previously hacked together a proof of concept at https://github.com/sethaxen/DynamicPPLInferenceObjects.jl, but the DynamicPPL and AbstractMCMC APIs now appear mature enough that this can be done less hackily in a package extension.
Useful reference implementations are https://github.com/penelopeysm/FlexiChains.jl/blob/main/ext/FlexiChainsDynamicPPLExt.jl and https://github.com/TuringLang/DynamicPPL.jl/blob/main/ext/DynamicPPLMCMCChainsExt.jl.
Note that we can probably only support models where all variables are scalars or dense AbstractArrays, so an initial implementation would warn about, and discard, any variables of other, more structured Julia types. We could potentially extend support to these types by letting the user register global invertible maps that convert between such a type and a keyed collection of scalars or dense arrays. The MCMCChains extension linked above does something similar.
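To make the proposed packing behavior concrete, here is a minimal sketch in plain Julia (no Turing, DynamicPPL, or InferenceObjects dependency). It assumes each draw is a `Dict{Symbol,Any}` and that every draw has the same variables with the same shapes and element types. `pack_draws`, `expand_draw`, `register_flattener!`, and `FLATTENERS` are hypothetical names for illustration, not existing API, and the `(draw, chain, param dims...)` array layout is also an assumption. Supported values are stacked into one array per variable. Values of a registered type are split into named components first. Anything else triggers a warning and is dropped.

```julia
# Minimal sketch (plain Base Julia). Assumes every draw is a Dict{Symbol,Any}
# with the same variables, shapes, and element types across all draws.

# Hypothetical registry of invertible maps: exact type => (to_keyed, from_keyed).
# `to_keyed` returns a dict of scalars/dense arrays; `from_keyed` inverts it
# (only needed to convert back to draws, which this sketch does not do).
const FLATTENERS = Dict{Type,Tuple{Function,Function}}()

function register_flattener!(T::Type, to_keyed, from_keyed)
    FLATTENERS[T] = (to_keyed, from_keyed)
    return nothing
end

# Real scalars and dense arrays of Reals can be stored directly.
is_supported(x) = x isa Real || (x isa DenseArray && eltype(x) <: Real)

# Replace each value of a registered type by its components, named "<var>.<key>".
function expand_draw(draw::AbstractDict{Symbol})
    out = Dict{Symbol,Any}()
    for (name, val) in draw
        if haskey(FLATTENERS, typeof(val))
            to_keyed, _ = FLATTENERS[typeof(val)]
            for (k, v) in to_keyed(val)
                out[Symbol(name, ".", k)] = v
            end
        else
            out[name] = val
        end
    end
    return out
end

# Pack a (draw x chain) matrix of draws into one array per variable with
# dims (draw, chain, param dims...). Unsupported variables are warned about and dropped.
function pack_draws(draws::AbstractMatrix{<:AbstractDict{Symbol}})
    expanded = map(expand_draw, draws)
    ndraws, nchains = size(expanded)
    packed = Dict{Symbol,Array}()
    for (name, val) in expanded[1, 1]
        if !is_supported(val)
            @warn "Discarding variable of unsupported type" name T=typeof(val)
            continue
        end
        vsize = size(val)  # () for scalars
        arr = Array{eltype(val)}(undef, ndraws, nchains, vsize...)
        trailing = ntuple(_ -> Colon(), length(vsize))
        for c in 1:nchains, d in 1:ndraws
            arr[d, c, trailing...] = expanded[d, c][name]
        end
        packed[name] = arr
    end
    return packed
end
```

In this sketch the registry matches on exact types only, so subtypes would need their own registration. A real extension would also need the inverse maps to reconstruct the original values when converting an InferenceData back into draws.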
A side benefit is that, once this exists, we can remove ArviZ.from_mcmcchains and instead add a from_turing, which could accept any chain_type by simply converting it to an AbstractMatrix{<:ParamsWithStats} and then to an InferenceData.