Skip to content

Commit a16fa77

Browse files
committed
Remove MCMCCHains by default
1 parent c8946a4 commit a16fa77

File tree

2 files changed

+64
-18
lines changed

2 files changed

+64
-18
lines changed

src/AdvancedMH.jl

Lines changed: 34 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
module AdvancedMH
22

33
# Import the relevant libraries.
4-
using Reexport
54
using AbstractMCMC
6-
using Distributions
75
using Random
6+
using Requires
7+
using Distributions
88

99
# Import specific functions and types to use or overload.
10-
import MCMCChains: Chains
1110
import AbstractMCMC: step!, AbstractSampler, AbstractTransition, transition_type, bundle_samples
1211

1312
# Exports
@@ -22,7 +21,7 @@ abstract type ProposalStyle end
2221
"""
2322
DensityModel{F<:Function} <: AbstractModel
2423
25-
`DensityModel` wraps around a self-contained log-liklihood function `ℓπ`.
24+
`DensityModel` wraps around a self-contained log-liklihood function `logdensity`.
2625
2726
Example:
2827
@@ -32,52 +31,69 @@ DensityModel
3231
```
3332
"""
3433
struct DensityModel{F<:Function} <: AbstractModel
35-
ℓπ :: F
34+
logdensity :: F
3635
end
3736

3837
# Create a very basic Transition type, only stores the
3938
# parameter draws and the log probability of the draw.
40-
struct Transition{T<:Union{Vector{<:Real}, <:Real}, L<:Real} <: AbstractTransition
41-
θ :: T
39+
struct Transition{T<:Union{Vector, Real, NamedTuple}, L<:Real} <: AbstractTransition
40+
params :: T
4241
lp :: L
4342
end
4443

4544
# Store the new draw and its log density.
46-
Transition(model::M, θ::T) where {M<:DensityModel, T} = Transition(θ, ℓπ(model, θ))
45+
Transition(model::M, params::T) where {M<:DensityModel, T} = Transition(params, logdensity(model, params))
4746

4847
# Tell the interface what transition type we would like to use.
49-
transition_type(model::DensityModel, spl::Metropolis) = typeof(Transition(spl.init_θ, ℓπ(model, spl.init_θ)))
48+
transition_type(model::DensityModel, spl::Metropolis) = typeof(Transition(spl.init_params, logdensity(model, spl.init_params)))
5049

5150
# Calculate the density of the model given some parameterization.
52-
ℓπ(model::DensityModel, θ::T) where T = model.ℓπ)
53-
ℓπ(model::DensityModel, t::Transition) = t.lp
51+
logdensity(model::DensityModel, params) = model.logdensity(params)
52+
logdensity(model::DensityModel, t::Transition) = t.lp
5453

5554
# A basic chains constructor that works with the Transition struct we defined.
5655
function bundle_samples(
5756
rng::AbstractRNG,
58-
::DensityModel,
57+
model::DensityModel,
5958
s::Metropolis,
6059
N::Integer,
6160
ts::Vector{T};
6261
param_names=missing,
6362
kwargs...
6463
) where {ModelType<:AbstractModel, T<:AbstractTransition}
65-
# Turn all the transitions into a vector-of-vectors.
66-
vals = copy(reduce(hcat,[vcat(t.θ, t.lp) for t in ts])')
64+
return ts
65+
end
6766

67+
function bundle_samples(
68+
rng::AbstractRNG,
69+
model::DensityModel,
70+
s::Metropolis,
71+
N::Integer,
72+
ts::Vector{T},
73+
chain_type::Type{NamedTuple};
74+
param_names=missing,
75+
kwargs...
76+
) where {ModelType<:AbstractModel, T<:AbstractTransition}
6877
# Check if we received any parameter names.
6978
if ismissing(param_names)
70-
param_names = ["Parameter $i" for i in 1:length(s.init_θ)]
79+
param_names = ["param_$i" for i in 1:length(s.init_params)]
7180
else
7281
# Deepcopy to be thread safe.
7382
param_names = deepcopy(param_names)
7483
end
7584

76-
# Add the log density field to the parameter names.
7785
push!(param_names, "lp")
7886

79-
# Bundle everything up and return a Chains struct.
80-
return Chains(vals, param_names, (internals=["lp"],))
87+
# Turn all the transitions into a vector-of-NamedTuple.
88+
keys = tuple(Symbol.(param_names)...)
89+
nts = [NamedTuple{keys}(tuple(t.params..., t.lp)) for t in ts]
90+
91+
return nts
92+
end
93+
94+
function __init__()
95+
@require MCMCChains="c7f686f2-ff18-58e9-bc7b-31028e88f75d" include("mcmcchains-connect.jl")
96+
@require StructArrays="09ab397b-f2b6-538f-b94a-2f83cf4a842a" include("structarray-connect.jl")
8197
end
8298

8399
# Include inference methods.

src/mcmcchains-connect.jl

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import MCMCChains: Chains
2+
3+
# A basic chains constructor that works with the Transition struct we defined.
4+
function bundle_samples(
5+
rng::AbstractRNG,
6+
model::DensityModel,
7+
s::Metropolis,
8+
N::Integer,
9+
ts::Vector{T},
10+
chain_type::Type{Chains};
11+
param_names=missing,
12+
kwargs...
13+
) where {ModelType<:AbstractModel, T<:AbstractTransition}
14+
# Turn all the transitions into a vector-of-vectors.
15+
vals = copy(reduce(hcat,[vcat(t.params, t.lp) for t in ts])')
16+
17+
# Check if we received any parameter names.
18+
if ismissing(param_names)
19+
param_names = ["Parameter $i" for i in 1:length(s.init_params)]
20+
else
21+
# Deepcopy to be thread safe.
22+
param_names = deepcopy(param_names)
23+
end
24+
25+
# Add the log density field to the parameter names.
26+
push!(param_names, "lp")
27+
28+
# Bundle everything up and return a Chains struct.
29+
return Chains(vals, param_names, (internals=["lp"],))
30+
end

0 commit comments

Comments
 (0)