Skip to content

Commit c09ba13

Browse files
committed
Make FlexiChains the default chain type
1 parent 8b84230 commit c09ba13

25 files changed

+316
-298
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
2020
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
2121
DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"
2222
EllipticalSliceSampling = "cad2338a-1db2-11e9-3401-43bc07c9ede2"
23+
FlexiChains = "4a37a8b9-6e57-4b92-8664-298d46e639f7"
2324
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
2425
Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
2526
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -64,6 +65,7 @@ DocStringExtensions = "0.8, 0.9"
6465
DynamicHMC = "3.4"
6566
DynamicPPL = "0.39.1"
6667
EllipticalSliceSampling = "0.5, 1, 2"
68+
FlexiChains = "0.3.1"
6769
ForwardDiff = "0.10.3, 1"
6870
Libtask = "0.9.3"
6971
LinearAlgebra = "1"

docs/make.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ links = InterLinks(
1010
"AbstractMCMC" => "https://turinglang.org/AbstractMCMC.jl/stable/",
1111
"ADTypes" => "https://sciml.github.io/ADTypes.jl/stable/",
1212
"AdvancedVI" => "https://turinglang.org/AdvancedVI.jl/stable/",
13+
"FlexiChains" => "https://pysm.dev/FlexiChains.jl/stable/",
1314
"DistributionsAD" => "https://turinglang.org/DistributionsAD.jl/stable/",
1415
"OrderedCollections" => "https://juliacollections.github.io/OrderedCollections.jl/stable/",
1516
"Distributions" => "https://juliastats.org/Distributions.jl/stable/",

docs/src/api.md

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,8 @@
22

33
## Module-wide re-exports
44

5-
Turing.jl directly re-exports the entire public API of the following packages:
6-
7-
- [Distributions.jl](https://juliastats.org/Distributions.jl)
8-
- [MCMCChains.jl](https://turinglang.org/MCMCChains.jl)
9-
10-
Please see the individual packages for their documentation.
5+
Turing.jl directly re-exports the entire public API of [Distributions.jl](https://juliastats.org/Distributions.jl).
6+
Please see its documentation for more details.
117

128
## Individual exports and re-exports
139

@@ -47,13 +43,14 @@ even though [`Prior()`](@ref) is actually defined in the `Turing.Inference` modu
4743

4844
### Inference
4945

50-
| Exported symbol | Documentation | Description |
51-
|:----------------- |:------------------------------------------------------------------------- |:----------------------------------------- |
52-
| `sample` | [`StatsBase.sample`](https://turinglang.org/docs/usage/sampling-options/) | Sample from a model |
53-
| `MCMCThreads` | [`AbstractMCMC.MCMCThreads`](@extref) | Run MCMC using multiple threads |
54-
| `MCMCDistributed` | [`AbstractMCMC.MCMCDistributed`](@extref) | Run MCMC using multiple processes |
55-
| `MCMCSerial` | [`AbstractMCMC.MCMCSerial`](@extref) | Run MCMC using without parallelism |
56-
| `loadstate` | [`Turing.Inference.loadstate`](@ref) | Load saved state from `MCMCChains.Chains` |
46+
| Exported symbol | Documentation | Description |
47+
|:----------------- |:------------------------------------------------------------------------- |:----------------------------------- |
48+
| `sample` | [`StatsBase.sample`](https://turinglang.org/docs/usage/sampling-options/) | Sample from a model |
49+
| `MCMCThreads` | [`AbstractMCMC.MCMCThreads`](@extref) | Run MCMC using multiple threads |
50+
| `MCMCDistributed` | [`AbstractMCMC.MCMCDistributed`](@extref) | Run MCMC using multiple processes |
51+
| `MCMCSerial` | [`AbstractMCMC.MCMCSerial`](@extref) | Run MCMC using without parallelism |
52+
| `loadstate` | [`Turing.Inference.loadstate`](@ref) | Load saved state from an MCMC chain |
53+
| `VNChain` | n/a | Alias for `FlexiChain{VarName}` |
5754

5855
### Samplers
5956

src/Turing.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ using Reexport, ForwardDiff
44
using DistributionsAD, Bijectors, StatsFuns, SpecialFunctions
55
using Statistics, LinearAlgebra
66
using Libtask
7-
@reexport using Distributions, MCMCChains
7+
@reexport using Distributions
88
using Compat: pkgversion
99

1010
using AdvancedVI: AdvancedVI
@@ -16,6 +16,7 @@ using Accessors: Accessors
1616
using StatsAPI: StatsAPI
1717
using StatsBase: StatsBase
1818
using AbstractMCMC
19+
using FlexiChains
1920

2021
using Accessors: Accessors
2122

@@ -172,6 +173,8 @@ export
172173
MAP,
173174
MLE,
174175
# Chain save/resume
175-
loadstate
176+
loadstate,
177+
# FlexiChains re-export
178+
VNChain
176179

177180
end

src/mcmc/Inference.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ using DynamicPPL:
2525
DefaultContext
2626
using Distributions, Libtask, Bijectors
2727
using DistributionsAD: VectorOfMultivariate
28+
using FlexiChains: FlexiChains, VNChain
2829
using LinearAlgebra
2930
using ..Turing: PROGRESS, Turing
3031
using StatsFuns: logsumexp
@@ -46,7 +47,6 @@ import Accessors
4647
import EllipticalSliceSampling
4748
import LogDensityProblems
4849
import Random
49-
import MCMCChains
5050
import StatsBase: predict
5151

5252
export Hamiltonian,
@@ -78,7 +78,7 @@ export Hamiltonian,
7878
# Generic AbstractMCMC methods dispatch #
7979
#########################################
8080

81-
const DEFAULT_CHAIN_TYPE = MCMCChains.Chains
81+
const DEFAULT_CHAIN_TYPE = VNChain
8282
include("abstractmcmc.jl")
8383

8484
####################

src/mcmc/abstractmcmc.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
using MCMCChains: MCMCChains
2+
13
# TODO: Implement additional checks for certain samplers, e.g.
24
# HMC not supporting discrete parameters.
35
function _check_model(model::DynamicPPL.Model)

src/mcmc/emcee.jl

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -99,20 +99,27 @@ function AbstractMCMC.step(
9999
return transition, newstate
100100
end
101101

102-
function AbstractMCMC.bundle_samples(
103-
samples::Vector{<:Vector},
104-
model::AbstractModel,
105-
spl::Emcee,
106-
state::EmceeState,
107-
chain_type::Type{MCMCChains.Chains};
108-
kwargs...,
109-
)
110-
n_walkers = _get_n_walkers(spl)
111-
chains = map(1:n_walkers) do i
112-
this_walker_samples = [s[i] for s in samples]
113-
AbstractMCMC.bundle_samples(
114-
this_walker_samples, model, spl, state, chain_type; kwargs...
102+
# Have to define methods for both to avoid method ambiguities (as opposed to a single
103+
# `::Type{T<:AbstractMCMC.AbstractChains})` since default `bundle_samples` takes
104+
# `samples::AbstractVector`.
105+
for Tchain in (:(MCMCChains.Chains), :(FlexiChains.VNChain))
106+
@eval begin
107+
function AbstractMCMC.bundle_samples(
108+
samples::Vector{<:Vector},
109+
model::DynamicPPL.Model,
110+
spl::Emcee,
111+
state::EmceeState,
112+
::Type{$Tchain};
113+
kwargs...,
115114
)
115+
n_walkers = _get_n_walkers(spl)
116+
chains = map(1:n_walkers) do i
117+
this_walker_samples = [s[i] for s in samples]
118+
AbstractMCMC.bundle_samples(
119+
this_walker_samples, model, spl, state, $Tchain; kwargs...
120+
)
121+
end
122+
return AbstractMCMC.chainscat(chains...)
123+
end
116124
end
117-
return AbstractMCMC.chainscat(chains...)
118125
end

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
1515
DynamicHMC = "bbc10e6e-7c05-544b-b16e-64fede858acb"
1616
DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"
1717
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
18+
FlexiChains = "4a37a8b9-6e57-4b92-8664-298d46e639f7"
1819
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1920
HypothesisTests = "09f84164-cd44-5f33-b23f-e6b0d136a0d5"
2021
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

0 commit comments

Comments
 (0)