1
1
module AdvancedMH
2
2
3
3
# Import the relevant libraries.
4
- using Reexport
5
4
using AbstractMCMC
6
- using Distributions
7
5
using Random
6
+ using Requires
7
+ using Distributions
8
8
9
9
# Import specific functions and types to use or overload.
10
- import MCMCChains: Chains
11
10
import AbstractMCMC: step!, AbstractSampler, AbstractTransition, transition_type, bundle_samples
12
11
13
12
# Exports
@@ -22,7 +21,7 @@ abstract type ProposalStyle end
22
21
"""
23
22
DensityModel{F<:Function} <: AbstractModel
24
23
25
- `DensityModel` wraps around a self-contained log-liklihood function `ℓπ `.
24
+ `DensityModel` wraps around a self-contained log-liklihood function `logdensity `.
26
25
27
26
Example:
28
27
@@ -32,52 +31,69 @@ DensityModel
32
31
```
33
32
"""
34
33
struct DensityModel{F<: Function } <: AbstractModel
35
- ℓπ :: F
34
+ logdensity :: F
36
35
end
37
36
38
37
# Create a very basic Transition type, only stores the
39
38
# parameter draws and the log probability of the draw.
40
- struct Transition{T<: Union{Vector{<:Real}, <: Real} , L<: Real } <: AbstractTransition
41
- θ :: T
39
+ struct Transition{T<: Union{Vector, Real, NamedTuple } , L<: Real } <: AbstractTransition
40
+ params :: T
42
41
lp :: L
43
42
end
44
43
45
44
# Store the new draw and its log density.
46
- Transition (model:: M , θ :: T ) where {M<: DensityModel , T} = Transition (θ, ℓπ (model, θ ))
45
+ Transition (model:: M , params :: T ) where {M<: DensityModel , T} = Transition (params, logdensity (model, params ))
47
46
48
47
# Tell the interface what transition type we would like to use.
49
- transition_type (model:: DensityModel , spl:: Metropolis ) = typeof (Transition (spl. init_θ, ℓπ (model, spl. init_θ )))
48
+ transition_type (model:: DensityModel , spl:: Metropolis ) = typeof (Transition (spl. init_params, logdensity (model, spl. init_params )))
50
49
51
50
# Calculate the density of the model given some parameterization.
52
- ℓπ (model:: DensityModel , θ :: T ) where T = model. ℓπ (θ )
53
- ℓπ (model:: DensityModel , t:: Transition ) = t. lp
51
+ logdensity (model:: DensityModel , params) = model. logdensity (params )
52
+ logdensity (model:: DensityModel , t:: Transition ) = t. lp
54
53
55
54
# A basic chains constructor that works with the Transition struct we defined.
56
55
function bundle_samples (
57
56
rng:: AbstractRNG ,
58
- ℓ :: DensityModel ,
57
+ model :: DensityModel ,
59
58
s:: Metropolis ,
60
59
N:: Integer ,
61
60
ts:: Vector{T} ;
62
61
param_names= missing ,
63
62
kwargs...
64
63
) where {ModelType<: AbstractModel , T<: AbstractTransition }
65
- # Turn all the transitions into a vector-of-vectors.
66
- vals = copy ( reduce (hcat,[ vcat (t . θ, t . lp) for t in ts]) ' )
64
+ return ts
65
+ end
67
66
67
+ function bundle_samples (
68
+ rng:: AbstractRNG ,
69
+ model:: DensityModel ,
70
+ s:: Metropolis ,
71
+ N:: Integer ,
72
+ ts:: Vector{T} ,
73
+ chain_type:: Type{NamedTuple} ;
74
+ param_names= missing ,
75
+ kwargs...
76
+ ) where {ModelType<: AbstractModel , T<: AbstractTransition }
68
77
# Check if we received any parameter names.
69
78
if ismissing (param_names)
70
- param_names = [" Parameter $i " for i in 1 : length (s. init_θ )]
79
+ param_names = [" param_ $i " for i in 1 : length (s. init_params )]
71
80
else
72
81
# Deepcopy to be thread safe.
73
82
param_names = deepcopy (param_names)
74
83
end
75
84
76
- # Add the log density field to the parameter names.
77
85
push! (param_names, " lp" )
78
86
79
- # Bundle everything up and return a Chains struct.
80
- return Chains (vals, param_names, (internals= [" lp" ],))
87
+ # Turn all the transitions into a vector-of-NamedTuple.
88
+ keys = tuple (Symbol .(param_names)... )
89
+ nts = [NamedTuple {keys} (tuple (t. params... , t. lp)) for t in ts]
90
+
91
+ return nts
92
+ end
93
+
94
+ function __init__ ()
95
+ @require MCMCChains= " c7f686f2-ff18-58e9-bc7b-31028e88f75d" include (" mcmcchains-connect.jl" )
96
+ @require StructArrays= " 09ab397b-f2b6-538f-b94a-2f83cf4a842a" include (" structarray-connect.jl" )
81
97
end
82
98
83
99
# Include inference methods.
0 commit comments