4 changes: 3 additions & 1 deletion Project.toml
@@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
keywords = ["markov chain monte carlo", "probabilistic programming"]
license = "MIT"
desc = "A lightweight interface for common MCMC methods."
-version = "5.8.2"
+version = "5.9.0"

[deps]
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
@@ -17,6 +17,7 @@ ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
TerminalLoggers = "5d786b92-1e48-4d6f-9151-6b4477ca9bed"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Transducers = "28d57a85-8fef-5791-bfe6-a80928e7c999"
UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"

@@ -29,6 +30,7 @@ LoggingExtras = "0.4, 0.5, 1"
ProgressLogging = "0.1"
StatsBase = "0.32, 0.33, 0.34"
TerminalLoggers = "0.1"
Test = "1"
Transducers = "0.4.30"
UUIDs = "<0.0.1, 1"
julia = "1.6"
1 change: 1 addition & 0 deletions docs/Project.toml
@@ -1,4 +1,5 @@
[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

13 changes: 12 additions & 1 deletion docs/src/api.md
@@ -144,13 +144,24 @@ AbstractMCMC defines the abstract type `AbstractChains` for Markov chains.
AbstractMCMC.AbstractChains
```

-For chains of this type, AbstractMCMC defines the following two methods.
+For chains of this type, AbstractMCMC defines the following two **public** methods.

```@docs
AbstractMCMC.chainscat
AbstractMCMC.chainsstack
```

The following interface methods are considered experimental and may change even in formally non-breaking releases.

```@docs
AbstractMCMC.Chains.get_data
AbstractMCMC.Chains.iter_indices
AbstractMCMC.Chains.chain_indices
AbstractMCMC.Chains.niters
AbstractMCMC.Chains.nchains
AbstractMCMC.Chains.test_interface
```

## Interacting with states of samplers

To make it a bit easier to interact with some arbitrary sampler state, we encourage implementations of `AbstractSampler` to implement the following methods:
9 changes: 1 addition & 8 deletions src/AbstractMCMC.jl
@@ -22,14 +22,6 @@ export sample
# Parallel sampling types
export MCMCThreads, MCMCDistributed, MCMCSerial

"""
AbstractChains

`AbstractChains` is an abstract type for an object that stores
parameter samples generated through a MCMC process.
"""
abstract type AbstractChains end

"""
AbstractSampler

@@ -137,6 +129,7 @@ function setparams!!(model::AbstractModel, state, params)
    return setparams!!(state, params)
end

include("chains.jl")
include("samplingstats.jl")
include("logging.jl")
include("interface.jl")
61 changes: 61 additions & 0 deletions src/chains.jl
@@ -0,0 +1,61 @@
# AbstractChains interface
#
# NOTE: The entire interface is treated as experimental except for the AbstractChains type
# itself, along with `chainscat` and `chainsstack`. Thus, if you change any of those three,
# it is mandatory to release a breaking version. Other changes to the AbstractChains
# interface can be made in patch releases.

"""
AbstractMCMC.AbstractChains

An abstract type for Markov chains, i.e., a data structure which stores samples
obtained from Markov chain Monte Carlo (MCMC) sampling.

!!! danger "Explicitly experimental"

    Although the abstract type `AbstractMCMC.AbstractChains` itself, along with the
    functions `chainscat` and `chainsstack`, are exported and public, please note that *all
    other parts of the interface remain experimental and subject to change*. In particular,
    breaking changes to the interface may be introduced in formally non-breaking releases.

Markov chains should generally have dictionary-like behaviour, where keys are mapped to
matrices of values.

## Interface

To implement a new subtype of `AbstractChains`, you need to define the following methods:

- `Base.size` should return a tuple of ints (the exact meaning is left to you)
- `Base.keys` should return a list of keys
- [`AbstractMCMC.Chains.get_data`](@ref)`(chn, key)`
- [`AbstractMCMC.Chains.iter_indices`](@ref)`(chn)`
- [`AbstractMCMC.Chains.chain_indices`](@ref)`(chn)`

You can optionally define the following methods for efficiency:

- [`AbstractMCMC.Chains.niters`](@ref)`(chn)`
- [`AbstractMCMC.Chains.nchains`](@ref)`(chn)`
"""
abstract type AbstractChains end

"""
chainscat(c::AbstractChains...)

Concatenate multiple chains.

By default, the chains are concatenated along the third dimension by calling
`cat(c...; dims=3)`.
"""
chainscat(c::AbstractChains...) = cat(c...; dims=3)

"""
chainsstack(c::AbstractVector)

Stack chains in `c`.

By default, the vector of chains is returned unmodified. If `eltype(c) <: AbstractChains`,
then `reduce(chainscat, c)` is called.
"""
chainsstack(c) = c
chainsstack(c::AbstractVector{<:AbstractChains}) = reduce(chainscat, c)
include("experimental/chains.jl")
123 changes: 123 additions & 0 deletions src/experimental/chains.jl
@@ -0,0 +1,123 @@
module Chains

using AbstractMCMC: AbstractMCMC, AbstractChains
using Test

"""
AbstractMCMC.Chains.get_data(chn, key)
Review comment (Member):

Since this function always returns `AbstractMatrix`, maybe rename it to `to_array` for clarity:

Suggested change:
- AbstractMCMC.Chains.get_data(chn, key)
+ AbstractMCMC.Chains.to_array(chn, key)

Reply (Member Author):

It's more like a dictionary get because it requires a key. IMO, to_array sounds like you're converting the entire chain to an array.


Obtain the data associated with `key` from the `AbstractChains` object `chn`.

This function should return an `AbstractMatrix` where the rows correspond to iterations and
columns correspond to chains.
"""
function get_data end

"""
AbstractMCMC.Chains.iter_indices(chn)

Obtain the indices of each iteration for the `AbstractChains` object `chn`.
Review comment (Member):

This is not very clear to me. Maybe add a working example to this and related APIs below for readability?

Reply (Member Author):

I don't really see how to add an example without a concrete implementation. I'll flesh out the docstring though.

Reply (Member Author):

Added more words!
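
For illustration only, here is a minimal sketch of the kind of working example requested in this thread. It assumes a hypothetical `ToyChain` type and a local stand-in for `AbstractMCMC.AbstractChains`, so that it runs without the package installed. The function names mirror those in this PR, but here they are plain local functions; with the real package you would presumably extend `AbstractMCMC.Chains.iter_indices` and the related functions instead.

```julia
# Local stand-in for AbstractMCMC.AbstractChains (assumption: keeps the sketch self-contained).
abstract type AbstractChains end

# ToyChain is a hypothetical type for illustration; it is not part of the PR.
struct ToyChain <: AbstractChains
    iters::StepRange{Int,Int}            # iteration numbers of the stored samples
    chains::UnitRange{Int}               # chain numbers
    data::Dict{Symbol,Matrix{Float64}}   # key => (iterations x chains) matrix
end

# Required interface functions, defined locally for this sketch.
iter_indices(c::ToyChain) = c.iters
chain_indices(c::ToyChain) = c.chains
get_data(c::ToyChain, key::Symbol) = c.data[key]

# Defaults matching the ones in the PR: count the indices.
niters(c::AbstractChains) = length(iter_indices(c))
nchains(c::AbstractChains) = length(chain_indices(c))

# 1000 warmup steps were discarded and every second draw was kept afterwards,
# so the stored samples carry the iteration numbers 1001, 1003, ..., 2999.
iters = 1001:2:3000
chn = ToyChain(iters, 1:3, Dict(:mu => zeros(length(iters), 3)))

println(niters(chn), " iterations, ", nchains(chn), " chains")
```

The point of the sketch is that `iter_indices` reports the original iteration numbers rather than `1:niters(chn)`, while `get_data` still returns a matrix with one row per stored iteration and one column per chain.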


For example, if `chn` contains 1000 samples, but 1000 warmup steps and a thinning factor of
2 were used, then this function should return `1001:2:3000` (or an equivalent vector).

This function should return an `AbstractVector{<:Integer}`.
"""
function iter_indices end

"""
AbstractMCMC.Chains.chain_indices(chn)

Obtain the indices of each chain in the `AbstractChains` object `chn`.

If there is no special numbering associated with chains, then this function can simply
return `1:nchains(chn)`. However, this function provides the flexibility to have
non-standard chain numbering (e.g. if chains are combined from multiple sources).

This function should return an `AbstractVector{<:Integer}`.
"""
function chain_indices end

"""
AbstractMCMC.Chains.niters(chn)

Obtain the number of iterations in the `AbstractChains` object `chn`.

The default implementation calculates the length of
`AbstractMCMC.Chains.iter_indices(chn)`. You can define your own method if you have a more
efficient way of obtaining this information.
"""
niters(c::AbstractChains) = length(iter_indices(c))

"""
AbstractMCMC.Chains.nchains(chn)

Obtain the number of chains in the `AbstractChains` object `chn`.

The default implementation calculates the length of
`AbstractMCMC.Chains.chain_indices(chn)`. You can define your own method if you have a more
efficient way of obtaining this information.
"""
nchains(c::AbstractChains) = length(chain_indices(c))

"""
AbstractMCMC.Chains.test_interface(chn)

Test that the `AbstractChains` object `chn` implements the required interface.
"""
function test_interface(chn::AbstractChains)
    # TODO: Test chainscat, chainsstack

    @testset "Base.size, AbstractMCMC.Chains.niters, AbstractMCMC.Chains.nchains" begin
        @test size(chn) isa NTuple{N,Int} where {N}
        @test AbstractMCMC.Chains.niters(chn) isa Int
        @test AbstractMCMC.Chains.nchains(chn) isa Int
    end

    @testset "Base.keys" begin
        @test collect(keys(chn)) isa AbstractVector
    end

    @testset "AbstractMCMC.Chains.get_data" begin
        for k in keys(chn)
            data = AbstractMCMC.Chains.get_data(chn, k)
            @test data isa AbstractMatrix
            @test size(data) ==
                (AbstractMCMC.Chains.niters(chn), AbstractMCMC.Chains.nchains(chn))
        end
    end

    @testset "AbstractMCMC.Chains.iter_indices" begin
        ii = AbstractMCMC.Chains.iter_indices(chn)
        @test ii isa AbstractVector{<:Integer}
        @test length(ii) == AbstractMCMC.Chains.niters(chn)
    end

    @testset "AbstractMCMC.Chains.chain_indices" begin
        ci = AbstractMCMC.Chains.chain_indices(chn)
        @test ci isa AbstractVector{<:Integer}
        @test length(ci) == AbstractMCMC.Chains.nchains(chn)
    end
end

# Plotting functions; to be extended by individual chain libraries
Review comment (Member):

Not sure we want to define the plotting APIs in AbstractMCMC. Let's focus on the chain type and core interface only.

Reply (Member Author):

The problem here is that if AbstractMCMC owns the chains type, then any plotting recipe that uses it also has to belong in AbstractMCMC. If you shifted the plotting functionality to a separate library, then defining something like `plot(chn::AbstractChains)` or `density(chn::AbstractChains)` would be type piracy, since `AbstractChains` is owned by AbstractMCMC.jl, and `plot` and `density` are owned by Plots.jl. This is the same issue that ChainsMakie has/had. If you wanted to move the plotting functionality to a separate library, then I would strongly suggest that everything should be in that separate library, i.e., just make AbstractChains.jl hold all of it.

Of course, there are differing levels of type piracy. Maybe it's not so bad for a package in TuringLang to pirate a type that is owned by a separate package in TuringLang. But I still think it's bad.
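
To make the type piracy concern concrete, here is a minimal sketch using three toy modules. `Owner`, `Plotting`, and `Recipes` are hypothetical stand-ins (for AbstractMCMC.jl, Plots.jl, and a separate plotting package, respectively); none of this is code from the PR.

```julia
module Owner            # stands in for AbstractMCMC.jl: owns the abstract type
    abstract type AbstractChains end
end

module Plotting         # stands in for Plots.jl: owns the generic function
    function density end
end

module Recipes          # stands in for a separate plotting package
    using ..Owner, ..Plotting

    struct MyChains <: Owner.AbstractChains end

    # Not piracy: Recipes owns MyChains, so it may add this method.
    Plotting.density(::MyChains) = "density of MyChains"

    # Type piracy: Recipes owns neither `density` nor `AbstractChains`,
    # yet this method changes behaviour for every subtype, in every package.
    Plotting.density(::Owner.AbstractChains) = "density of any chain"
end

println(Plotting.density(Recipes.MyChains()))
```

Julia accepts both definitions without complaint, which is why piracy is a matter of convention rather than something the compiler enforces. The second method is the one the comment above warns about: it would only be unproblematic if it lived in the module that owns `AbstractChains` or in the one that owns `density`.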

function autocorplot end
function autocorplot! end
function energyplot end
function energyplot! end
function forestplot end
function forestplot! end
function meanplot end
function meanplot! end
function mixeddensity end
function mixeddensity! end
function ppcplot end
function ppcplot! end
function ridgelineplot end
function ridgelineplot! end
function traceplot end
function traceplot! end
# Note that other functions are provided by other libraries. In particular:
# Plots.histogram
# Plots.density
# StatsPlots.cornerplot

end # AbstractMCMC.Chains
21 changes: 0 additions & 21 deletions src/interface.jl
@@ -1,24 +1,3 @@
"""
chainscat(c::AbstractChains...)

Concatenate multiple chains.

By default, the chains are concatenated along the third dimension by calling
`cat(c...; dims=3)`.
"""
chainscat(c::AbstractChains...) = cat(c...; dims=3)

"""
chainsstack(c::AbstractVector)

Stack chains in `c`.

By default, the vector of chains is returned unmodified. If `eltype(c) <: AbstractChains`,
then `reduce(chainscat, c)` is called.
"""
chainsstack(c) = c
chainsstack(c::AbstractVector{<:AbstractChains}) = reduce(chainscat, c)

"""
bundle_samples(samples, model, sampler, state, chain_type[; kwargs...])

26 changes: 26 additions & 0 deletions test/chains.jl
@@ -0,0 +1,26 @@
module AbstractMCMCChainsTests

using AbstractMCMC: AbstractMCMC
using Test

# This is a test mock: it minimally satisfies the AbstractChains interface. We use this to
# test our `test_interface` function, i.e., to ensure that something that satisfies the
# interface passes the test.
# See: https://invenia.github.io/blog/2020/11/06/interfacetesting/
struct MockChain <: AbstractMCMC.AbstractChains
    iter_indices::Vector{Int}
    chain_indices::Vector{Int}
    data::Dict{Symbol,Matrix{Float64}}
end
const MOCK = MockChain(1:10, 1:3, Dict(:param1 => rand(10, 3), :param2 => rand(10, 3)))
AbstractMCMC.Chains.iter_indices(c::MockChain) = c.iter_indices
AbstractMCMC.Chains.chain_indices(c::MockChain) = c.chain_indices
Base.size(c::MockChain) = (AbstractMCMC.Chains.niters(c), AbstractMCMC.Chains.nchains(c))
Base.keys(c::MockChain) = keys(c.data)
AbstractMCMC.Chains.get_data(c::MockChain, k::Symbol) = c.data[k]

@testset "AbstractChains interface" begin
AbstractMCMC.Chains.test_interface(MOCK)
end

end
1 change: 1 addition & 0 deletions test/runtests.jl
@@ -24,4 +24,5 @@ include("utils.jl")
include("stepper.jl")
include("transducer.jl")
include("logdensityproblems.jl")
include("chains.jl")
end