-
Notifications
You must be signed in to change notification settings - Fork 19
AbstractChains interface #180
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
44b50c6
2045d4f
873df97
a587aac
2ac37ad
522917a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,61 @@ | ||
| # AbstractChains interface | ||
| # | ||
| # NOTE: The entire interface is treated as experimental except for the AbstractChains type | ||
| # itself, along with `chainscat` and `chainsstack`. Thus, if you change any of those three, | ||
| # it is mandatory to release a breaking version. Other changes to the AbstractChains | ||
| # interface can be made in patch releases. | ||
|
|
||
| """ | ||
| AbstractMCMC.AbstractChains | ||
|
|
||
| An abstract type for Markov chains, i.e., a data structure which stores samples | ||
| obtained from Markov chain Monte Carlo (MCMC) sampling. | ||
|
|
||
| !!! danger "Explicitly experimental" | ||
|
|
||
| Although the abstract type `AbstractMCMC.AbstractChains` itself, along with the | ||
| functions `chainscat` and `chainsstack`, are exported and public, please note that *all | ||
| other parts of the interface remain experimental and subject to change*. In particular, | ||
| breaking changes to the interface may be introduced in formally non-breaking releases. | ||
|
|
||
| Markov chains should generally have dictionary-like behaviour, where keys are mapped to | ||
| matrices of values. | ||
|
|
||
| ## Interface | ||
|
|
||
| To implement a new subtype of `AbstractChains`, you need to define the following methods: | ||
|
|
||
| - `Base.size` should return a tuple of ints (the exact meaning is left to you) | ||
| - `Base.keys` should return a list of keys | ||
| - [`AbstractMCMC.Chains.get_data`](@ref)`(chn, key)` | ||
| - [`AbstractMCMC.Chains.iter_indices`](@ref)`(chn)` | ||
| - [`AbstractMCMC.Chains.chain_indices`](@ref)`(chn)` | ||
|
|
||
| You can optionally define the following methods for efficiency: | ||
|
|
||
| - [`AbstractMCMC.Chains.niters`](@ref)`(chn)` | ||
| - [`AbstractMCMC.Chains.nchains`](@ref)`(chn)` | ||
| """ | ||
| abstract type AbstractChains end | ||
|
|
||
| """ | ||
| chainscat(c::AbstractChains...) | ||
|
|
||
| Concatenate multiple chains. | ||
|
|
||
| By default, the chains are concatenated along the third dimension by calling | ||
| `cat(c...; dims=3)`. | ||
| """ | ||
| chainscat(c::AbstractChains...) = cat(c...; dims=3) | ||
|
|
||
| """ | ||
| chainsstack(c::AbstractVector) | ||
|
|
||
| Stack chains in `c`. | ||
|
|
||
| By default, the vector of chains is returned unmodified. If `eltype(c) <: AbstractChains`, | ||
| then `reduce(chainscat, c)` is called. | ||
| """ | ||
| chainsstack(c) = c | ||
| chainsstack(c::AbstractVector{<:AbstractChains}) = reduce(chainscat, c) | ||
| include("experimental/chains.jl") |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,123 @@ | ||
| module Chains | ||
|
|
||
| using AbstractMCMC: AbstractMCMC, AbstractChains | ||
| using Test | ||
|
|
||
| """ | ||
| AbstractMCMC.Chains.get_data(chn, key) | ||
|
|
||
| Obtain the data associated with `key` from the `AbstractChain` object `chn`. | ||
|
|
||
| This function should return an `AbstractMatrix` where the rows correspond to iterations and | ||
| columns correspond to chains. | ||
| """ | ||
| function get_data end | ||
|
|
||
| """ | ||
| AbstractMCMC.Chains.iter_indices(chn) | ||
|
|
||
| Obtain the indices of each iteration for the `AbstractChains` object `chn`. | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is not very clear to me. Maybe add a working example to this and related APIs below for readability? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't really see how to add an example without a concrete implementation. I'll flesh out the docstring though. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added more words! |
||
|
|
||
| For example, if `chn` contains 1000 samples, but 1000 warmup steps and a thinning factor of | ||
| 2 was used, then this function should return `1001:2:3000` (or an equivalent vector). | ||
|
|
||
| This function should return an `AbstractVector{<:Integer}`. | ||
| """ | ||
| function iter_indices end | ||
|
|
||
| """ | ||
| AbstractMCMC.Chains.chain_indices(chn) | ||
|
|
||
| Obtain the indices of each chain in the `AbstractChains` object `chn`. | ||
|
|
||
| If there is no special numbering associated with chains, then this function can simply | ||
| return `1:nchains(chn)`. However, this function provides the flexibility to have | ||
| non-standard chain numbering (e.g. if chains are combined from multiple sources). | ||
|
|
||
| This function should return an `AbstractVector{<:Integer}`. | ||
| """ | ||
| function chain_indices end | ||
|
|
||
| """ | ||
| AbstractMCMC.Chains.niters(chn) | ||
|
|
||
| Obtain the number of iterations in the `AbstractChains` object `chn`. | ||
|
|
||
| The default implementation calculates the length of `AbstractChains.iter_indices(chn)`. You | ||
| can define your own method if you have a more efficient way of obtaining this information. | ||
| """ | ||
| niters(c::AbstractChains) = length(iter_indices(c)) | ||
|
|
||
| """ | ||
| AbstractMCMC.Chains.nchains(chn) | ||
|
|
||
| Obtain the number of chains in the `AbstractChains` object `chn`. | ||
|
|
||
| The default implementation calculates the length of `AbstractChains.chain_indices(chn)`. You | ||
| can define your own method if you have a more efficient way of obtaining this information. | ||
| """ | ||
| nchains(c::AbstractChains) = length(chain_indices(c)) | ||
|
|
||
| """ | ||
| AbstractMCMC.Chains.test_interface(chn) | ||
|
|
||
| Test that the `AbstractChains` object `chn` implements the required interface. | ||
| """ | ||
| function test_interface(chn::AbstractChains) | ||
| # TODO: Test chainscat, chainsstack | ||
|
|
||
| @testset "Base.size, AbstractMCMC.Chains.niters, AbstractMCMC.Chains.nchains" begin | ||
| @test size(chn) isa NTuple{N,Int} where {N} | ||
| @test AbstractMCMC.Chains.niters(chn) isa Int | ||
| @test AbstractMCMC.Chains.nchains(chn) isa Int | ||
| end | ||
|
|
||
| @testset "Base.keys" begin | ||
| @test collect(keys(chn)) isa AbstractVector | ||
| end | ||
|
|
||
| @testset "AbstractMCMC.Chains.get_data" begin | ||
| for k in keys(chn) | ||
| data = AbstractMCMC.Chains.get_data(chn, k) | ||
| @test data isa AbstractMatrix | ||
| @test size(data) == | ||
| (AbstractMCMC.Chains.niters(chn), AbstractMCMC.Chains.nchains(chn)) | ||
| end | ||
| end | ||
|
|
||
| @testset "AbstractMCMC.Chains.iter_indices" begin | ||
| ii = AbstractMCMC.Chains.iter_indices(chn) | ||
| @test ii isa AbstractVector{<:Integer} | ||
| @test length(ii) == AbstractMCMC.Chains.niters(chn) | ||
| end | ||
|
|
||
| @testset "AbstractMCMC.Chains.chain_indices" begin | ||
| ci = AbstractMCMC.Chains.chain_indices(chn) | ||
| @test ci isa AbstractVector{<:Integer} | ||
| @test length(ci) == AbstractMCMC.Chains.nchains(chn) | ||
| end | ||
| end | ||
|
|
||
| # Plotting functions; to be extended by individual chain libraries | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not sure we want to define the plotting APIs in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The problem here is if AbstractMCMC owns the chains type, then any plotting recipe that uses it also has to belong in AbstractMCMC. Otherwise, if you shifted the plotting functionality to a separate library, then defining something like Of course, there are differing levels of type piracy. Maybe it's not so bad for a package in TuringLang to pirate a type that is owned by a separate package in TuringLang. But I still think it's bad. |
||
| function autocorplot end | ||
| function autocorplot! end | ||
| function energyplot end | ||
| function energyplot! end | ||
| function forestplot end | ||
| function forestplot! end | ||
| function meanplot end | ||
| function meanplot! end | ||
| function mixeddensity end | ||
| function mixeddensity! end | ||
| function ppcplot end | ||
| function ppcplot! end | ||
| function ridgelineplot end | ||
| function ridgelineplot! end | ||
| function traceplot end | ||
| function traceplot! end | ||
| # Note that other functions are provided by other libraries. In particular: | ||
| # Plots.histogram | ||
| # Plots.density | ||
| # StatsPlots.cornerplot | ||
|
|
||
| end # AbstractMCMC.Chains | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,26 @@ | ||
| module AbstractMCMCChainsTests | ||
|
|
||
| using AbstractMCMC: AbstractMCMC | ||
| using Test | ||
|
|
||
| # This is a test mock: it minimally satisfies the AbstractChains interface. We use this to | ||
| # test our `test_interface` function, i.e., to ensure that something that satisfies the | ||
| # interface passes the test. | ||
| # See: https://invenia.github.io/blog/2020/11/06/interfacetesting/ | ||
| struct MockChain <: AbstractMCMC.AbstractChains | ||
| iter_indices::Vector{Int} | ||
| chain_indices::Vector{Int} | ||
| data::Dict{Symbol,Matrix{Float64}} | ||
| end | ||
| const MOCK = MockChain(1:10, 1:3, Dict(:param1 => rand(10, 3), :param2 => rand(10, 3))) | ||
| AbstractMCMC.Chains.iter_indices(c::MockChain) = c.iter_indices | ||
| AbstractMCMC.Chains.chain_indices(c::MockChain) = c.chain_indices | ||
| Base.size(c::MockChain) = (AbstractMCMC.Chains.niters(c), AbstractMCMC.Chains.nchains(c)) | ||
| Base.keys(c::MockChain) = keys(c.data) | ||
| AbstractMCMC.Chains.get_data(c::MockChain, k::Symbol) = c.data[k] | ||
|
|
||
| @testset "AbstractChains interface" begin | ||
| AbstractMCMC.Chains.test_interface(MOCK) | ||
| end | ||
|
|
||
| end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since this function always returns
AbstractMatrix, maybe rename it toto_arrayfor clarity:There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's more like a dictionary
getbecause it requires a key. IMO,to_arraysounds like you're converting the entire chain to an array.