diff --git a/Project.toml b/Project.toml index cc7d8b8d..5bd32bd5 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001" keywords = ["markov chain monte carlo", "probabilistic programming"] license = "MIT" desc = "A lightweight interface for common MCMC methods." -version = "5.8.2" +version = "5.9.0" [deps] BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" @@ -17,6 +17,7 @@ ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" TerminalLoggers = "5d786b92-1e48-4d6f-9151-6b4477ca9bed" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Transducers = "28d57a85-8fef-5791-bfe6-a80928e7c999" UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" @@ -29,6 +30,7 @@ LoggingExtras = "0.4, 0.5, 1" ProgressLogging = "0.1" StatsBase = "0.32, 0.33, 0.34" TerminalLoggers = "0.1" +Test = "1" Transducers = "0.4.30" UUIDs = "<0.0.1, 1" julia = "1.6" diff --git a/docs/Project.toml b/docs/Project.toml index f74dfb58..d5fc343e 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,4 +1,5 @@ [deps] +AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" diff --git a/docs/src/api.md b/docs/src/api.md index 94b006ab..b11f3c17 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -144,13 +144,24 @@ AbstractMCMC defines the abstract type `AbstractChains` for Markov chains. AbstractMCMC.AbstractChains ``` -For chains of this type, AbstractMCMC defines the following two methods. +For chains of this type, AbstractMCMC defines the following two **public** methods. ```@docs AbstractMCMC.chainscat AbstractMCMC.chainsstack ``` +The following interface methods are considered experimental and may change even in formally non-breaking releases. + +```@docs +AbstractMCMC.Chains.get_data +AbstractMCMC.Chains.iter_indices +AbstractMCMC.Chains.chain_indices +AbstractMCMC.Chains.niters +AbstractMCMC.Chains.nchains +AbstractMCMC.Chains.test_interface +``` + ## Interacting with states of samplers To make it a bit easier to interact with some arbitrary sampler state, we encourage implementations of `AbstractSampler` to implement the following methods: diff --git a/src/AbstractMCMC.jl b/src/AbstractMCMC.jl index e103d5a5..a71fd435 100644 --- a/src/AbstractMCMC.jl +++ b/src/AbstractMCMC.jl @@ -22,14 +22,6 @@ export sample # Parallel sampling types export MCMCThreads, MCMCDistributed, MCMCSerial -""" - AbstractChains - -`AbstractChains` is an abstract type for an object that stores -parameter samples generated through a MCMC process. -""" -abstract type AbstractChains end - """ AbstractSampler @@ -137,6 +129,7 @@ function setparams!!(model::AbstractModel, state, params) return setparams!!(state, params) end +include("chains.jl") include("samplingstats.jl") include("logging.jl") include("interface.jl") diff --git a/src/chains.jl b/src/chains.jl new file mode 100644 index 00000000..fd57ba6e --- /dev/null +++ b/src/chains.jl @@ -0,0 +1,61 @@ +# AbstractChains interface +# +# NOTE: The entire interface is treated as experimental except for the AbstractChains type +# itself, along with `chainscat` and `chainsstack`. Thus, if you change any of those three, +# it is mandatory to release a breaking version. Other changes to the AbstractChains +# interface can be made in patch releases. + +""" + AbstractMCMC.AbstractChains + +An abstract type for Markov chains, i.e., a data structure which stores samples +obtained from Markov chain Monte Carlo (MCMC) sampling. + +!!! danger "Explicitly experimental" + + Although the abstract type `AbstractMCMC.AbstractChains` itself, along with the + functions `chainscat` and `chainsstack`, are exported and public, please note that *all + other parts of the interface remain experimental and subject to change*. In particular, + breaking changes to the interface may be introduced in formally non-breaking releases. + +Markov chains should generally have dictionary-like behaviour, where keys are mapped to +matrices of values. + +## Interface + +To implement a new subtype of `AbstractChains`, you need to define the following methods: + +- `Base.size` should return a tuple of ints (the exact meaning is left to you) +- `Base.keys` should return a list of keys +- [`AbstractMCMC.Chains.get_data`](@ref)`(chn, key)` +- [`AbstractMCMC.Chains.iter_indices`](@ref)`(chn)` +- [`AbstractMCMC.Chains.chain_indices`](@ref)`(chn)` + +You can optionally define the following methods for efficiency: + +- [`AbstractMCMC.Chains.niters`](@ref)`(chn)` +- [`AbstractMCMC.Chains.nchains`](@ref)`(chn)` +""" +abstract type AbstractChains end + +""" + chainscat(c::AbstractChains...) + +Concatenate multiple chains. + +By default, the chains are concatenated along the third dimension by calling +`cat(c...; dims=3)`. +""" +chainscat(c::AbstractChains...) = cat(c...; dims=3) + +""" + chainsstack(c::AbstractVector) + +Stack chains in `c`. + +By default, the vector of chains is returned unmodified. If `eltype(c) <: AbstractChains`, +then `reduce(chainscat, c)` is called. +""" +chainsstack(c) = c +chainsstack(c::AbstractVector{<:AbstractChains}) = reduce(chainscat, c) +include("experimental/chains.jl") diff --git a/src/experimental/chains.jl b/src/experimental/chains.jl new file mode 100644 index 00000000..29ce8aa3 --- /dev/null +++ b/src/experimental/chains.jl @@ -0,0 +1,123 @@ +module Chains + +using AbstractMCMC: AbstractMCMC, AbstractChains +using Test + +""" + AbstractMCMC.Chains.get_data(chn, key) + +Obtain the data associated with `key` from the `AbstractChain` object `chn`. + +This function should return an `AbstractMatrix` where the rows correspond to iterations and +columns correspond to chains. +""" +function get_data end + +""" + AbstractMCMC.Chains.iter_indices(chn) + +Obtain the indices of each iteration for the `AbstractChains` object `chn`. + +For example, if `chn` contains 1000 samples, but 1000 warmup steps and a thinning factor of +2 was used, then this function should return `1001:2:3000` (or an equivalent vector). + +This function should return an `AbstractVector{<:Integer}`. +""" +function iter_indices end + +""" + AbstractMCMC.Chains.chain_indices(chn) + +Obtain the indices of each chain in the `AbstractChains` object `chn`. + +If there is no special numbering associated with chains, then this function can simply +return `1:nchains(chn)`. However, this function provides the flexibility to have +non-standard chain numbering (e.g. if chains are combined from multiple sources). + +This function should return an `AbstractVector{<:Integer}`. +""" +function chain_indices end + +""" + AbstractMCMC.Chains.niters(chn) + +Obtain the number of iterations in the `AbstractChains` object `chn`. + +The default implementation calculates the length of `AbstractChains.iter_indices(chn)`. You +can define your own method if you have a more efficient way of obtaining this information. +""" +niters(c::AbstractChains) = length(iter_indices(c)) + +""" + AbstractMCMC.Chains.nchains(chn) + +Obtain the number of chains in the `AbstractChains` object `chn`. + +The default implementation calculates the length of `AbstractChains.chain_indices(chn)`. You +can define your own method if you have a more efficient way of obtaining this information. +""" +nchains(c::AbstractChains) = length(chain_indices(c)) + +""" + AbstractMCMC.Chains.test_interface(chn) + +Test that the `AbstractChains` object `chn` implements the required interface. +""" +function test_interface(chn::AbstractChains) + # TODO: Test chainscat, chainsstack + + @testset "Base.size, AbstractMCMC.Chains.niters, AbstractMCMC.Chains.nchains" begin + @test size(chn) isa NTuple{N,Int} where {N} + @test AbstractMCMC.Chains.niters(chn) isa Int + @test AbstractMCMC.Chains.nchains(chn) isa Int + end + + @testset "Base.keys" begin + @test collect(keys(chn)) isa AbstractVector + end + + @testset "AbstractMCMC.Chains.get_data" begin + for k in keys(chn) + data = AbstractMCMC.Chains.get_data(chn, k) + @test data isa AbstractMatrix + @test size(data) == + (AbstractMCMC.Chains.niters(chn), AbstractMCMC.Chains.nchains(chn)) + end + end + + @testset "AbstractMCMC.Chains.iter_indices" begin + ii = AbstractMCMC.Chains.iter_indices(chn) + @test ii isa AbstractVector{<:Integer} + @test length(ii) == AbstractMCMC.Chains.niters(chn) + end + + @testset "AbstractMCMC.Chains.chain_indices" begin + ci = AbstractMCMC.Chains.chain_indices(chn) + @test ci isa AbstractVector{<:Integer} + @test length(ci) == AbstractMCMC.Chains.nchains(chn) + end +end + +# Plotting functions; to be extended by individual chain libraries +function autocorplot end +function autocorplot! end +function energyplot end +function energyplot! end +function forestplot end +function forestplot! end +function meanplot end +function meanplot! end +function mixeddensity end +function mixeddensity! end +function ppcplot end +function ppcplot! end +function ridgelineplot end +function ridgelineplot! end +function traceplot end +function traceplot! end +# Note that other functions are provided by other libraries. In particular: +# Plots.histogram +# Plots.density +# StatsPlots.cornerplot + +end # AbstractMCMC.Chains diff --git a/src/interface.jl b/src/interface.jl index 902424d2..98c4d858 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -1,24 +1,3 @@ -""" - chainscat(c::AbstractChains...) - -Concatenate multiple chains. - -By default, the chains are concatenated along the third dimension by calling -`cat(c...; dims=3)`. -""" -chainscat(c::AbstractChains...) = cat(c...; dims=3) - -""" - chainsstack(c::AbstractVector) - -Stack chains in `c`. - -By default, the vector of chains is returned unmodified. If `eltype(c) <: AbstractChains`, -then `reduce(chainscat, c)` is called. -""" -chainsstack(c) = c -chainsstack(c::AbstractVector{<:AbstractChains}) = reduce(chainscat, c) - """ bundle_samples(samples, model, sampler, state, chain_type[; kwargs...]) diff --git a/test/chains.jl b/test/chains.jl new file mode 100644 index 00000000..38c7e5e9 --- /dev/null +++ b/test/chains.jl @@ -0,0 +1,26 @@ +module AbstractMCMCChainsTests + +using AbstractMCMC: AbstractMCMC +using Test + +# This is a test mock: it minimally satisfies the AbstractChains interface. We use this to +# test our `test_interface` function, i.e., to ensure that something that satisfies the +# interface passes the test. +# See: https://invenia.github.io/blog/2020/11/06/interfacetesting/ +struct MockChain <: AbstractMCMC.AbstractChains + iter_indices::Vector{Int} + chain_indices::Vector{Int} + data::Dict{Symbol,Matrix{Float64}} +end +const MOCK = MockChain(1:10, 1:3, Dict(:param1 => rand(10, 3), :param2 => rand(10, 3))) +AbstractMCMC.Chains.iter_indices(c::MockChain) = c.iter_indices +AbstractMCMC.Chains.chain_indices(c::MockChain) = c.chain_indices +Base.size(c::MockChain) = (AbstractMCMC.Chains.niters(c), AbstractMCMC.Chains.nchains(c)) +Base.keys(c::MockChain) = keys(c.data) +AbstractMCMC.Chains.get_data(c::MockChain, k::Symbol) = c.data[k] + +@testset "AbstractChains interface" begin + AbstractMCMC.Chains.test_interface(MOCK) +end + +end diff --git a/test/runtests.jl b/test/runtests.jl index 909ae8b3..64fcfe85 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -24,4 +24,5 @@ include("utils.jl") include("stepper.jl") include("transducer.jl") include("logdensityproblems.jl") + include("chains.jl") end