Skip to content

Commit e174117

Browse files
authored
Merge pull request #70 from TuringLang/torfjelde/logdensitymodel
Added support for models using the LogDensityProblems.jl interface
2 parents 5950f4a + f58c636 commit e174117

File tree

10 files changed

+122
-35
lines changed

10 files changed

+122
-35
lines changed

Project.toml

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
11
name = "AdvancedMH"
22
uuid = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170"
3-
version = "0.6.8"
3+
version = "0.7.0"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
77
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
8+
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
89
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
910
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
1011

1112
[compat]
12-
AbstractMCMC = "2, 3.0, 4"
13+
AbstractMCMC = "4"
1314
Distributions = "0.20, 0.21, 0.22, 0.23, 0.24, 0.25"
1415
Requires = "1"
1516
julia = "1"
@@ -18,9 +19,11 @@ julia = "1"
1819
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
1920
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
2021
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
22+
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
23+
LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
2124
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
2225
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
2326
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2427

2528
[targets]
26-
test = ["DiffResults", "ForwardDiff", "LinearAlgebra", "MCMCChains", "StructArrays", "Test"]
29+
test = ["DiffResults", "ForwardDiff", "LinearAlgebra", "LogDensityProblems", "LogDensityProblemsAD", "MCMCChains", "StructArrays", "Test"]

README.md

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,24 @@ Quantiles
6868

6969
```
7070

71+
### Usage with [`LogDensityProblems.jl`](https://github.com/tpapp/LogDensityProblems.jl)
72+
73+
It can also be used with models defining the [`LogDensityProblems.jl`](https://github.com/tpapp/LogDensityProblems.jl) interface by wrapping it in `AbstractMCMC.LogDensityModel` before passing it to `sample`:
74+
75+
``` julia
76+
using AbstractMCMC: LogDensityModel
77+
using LogDensityProblems
78+
79+
# Use a struct instead of `typeof(density)` for sake of readability.
80+
struct LogTargetDensity end
81+
82+
LogDensityProblems.logdensity(p::LogTargetDensity, θ) = density(θ) # standard multivariate normal
83+
LogDensityProblems.dimension(p::LogTargetDensity) = 2
84+
LogDensityProblems.capabilities(::LogTargetDensity) = LogDensityProblems.LogDensityOrder{0}()
85+
86+
sample(LogDensityModel(LogTargetDensity()), spl, 100000; param_names=["μ", "σ"], chain_type=Chains)
87+
```
88+
7189
## Proposals
7290

7391
AdvancedMH offers various methods of defining your inference problem. Behind the scenes, a `MetropolisHastings` sampler simply holds
@@ -162,3 +180,13 @@ spl = MALA(x -> MvNormal((σ² / 2) .* x, σ² * I))
162180
# Sample from the posterior.
163181
chain = sample(model, spl, 100000; init_params=ones(2), chain_type=StructArray, param_names=["μ", "σ"])
164182
```
183+
184+
### Usage with [`LogDensityProblemsAD.jl`](https://github.com/tpapp/LogDensityProblemsAD.jl)
185+
186+
Using our implementation of the `LogDensityProblems.jl` interface from earlier, we can use [`LogDensityProblemsAD.jl`](https://github.com/tpapp/LogDensityProblemsAD.jl) to provide us with the gradient computation used in MALA:
187+
188+
```julia
189+
using LogDensityProblemsAD
190+
model_with_ad = LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), LogTargetDensity())
191+
sample(LogDensityModel(model_with_ad), spl, 100000; init_params=ones(2), chain_type=StructArray, param_names=["μ", "σ"])
192+
```

src/AdvancedMH.jl

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ using AbstractMCMC
55
using Distributions
66
using Requires
77

8+
using LogDensityProblems: LogDensityProblems
9+
810
import Random
911

1012
# Exports
@@ -48,6 +50,8 @@ struct DensityModel{F} <: AbstractMCMC.AbstractModel
4850
logdensity :: F
4951
end
5052

53+
const DensityModelOrLogDensityModel = Union{<:DensityModel,<:AbstractMCMC.LogDensityModel}
54+
5155
# Create a very basic Transition type, only stores the
5256
# parameter draws and the log probability of the draw.
5357
struct Transition{T,L<:Real} <: AbstractTransition
@@ -56,16 +60,22 @@ struct Transition{T,L<:Real} <: AbstractTransition
5660
end
5761

5862
# Store the new draw and its log density.
59-
Transition(model::DensityModel, params) = Transition(params, logdensity(model, params))
63+
Transition(model::DensityModelOrLogDensityModel, params) = Transition(params, logdensity(model, params))
64+
function Transition(model::AbstractMCMC.LogDensityModel, params)
65+
return Transition(params, LogDensityProblems.logdensity(model.logdensity, params))
66+
end
6067

6168
# Calculate the density of the model given some parameterization.
62-
logdensity(model::DensityModel, params) = model.logdensity(params)
63-
logdensity(model::DensityModel, t::Transition) = t.lp
69+
logdensity(model::DensityModelOrLogDensityModel, params) = model.logdensity(params)
70+
logdensity(model::DensityModelOrLogDensityModel, t::Transition) = t.lp
71+
72+
logdensity(model::AbstractMCMC.LogDensityModel, params) = LogDensityProblems.logdensity(model.logdensity, params)
73+
logdensity(model::AbstractMCMC.LogDensityModel, t::Transition) = t.lp
6474

6575
# A basic chains constructor that works with the Transition struct we defined.
6676
function AbstractMCMC.bundle_samples(
6777
ts::Vector{<:AbstractTransition},
68-
model::DensityModel,
78+
model::Union{<:DensityModelOrLogDensityModel,<:AbstractMCMC.LogDensityModel},
6979
sampler::MHSampler,
7080
state,
7181
chain_type::Type{Vector{NamedTuple}};
@@ -91,7 +101,7 @@ end
91101

92102
function AbstractMCMC.bundle_samples(
93103
ts::Vector{<:Transition{<:NamedTuple}},
94-
model::DensityModel,
104+
model::Union{<:DensityModelOrLogDensityModel,<:AbstractMCMC.LogDensityModel},
95105
sampler::MHSampler,
96106
state,
97107
chain_type::Type{Vector{NamedTuple}};

src/MALA.jl

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,20 +19,34 @@ struct GradientTransition{T<:Union{Vector, Real, NamedTuple}, L<:Real, G<:Union{
1919
gradient::G
2020
end
2121

22-
logdensity(model::DensityModel, t::GradientTransition) = t.lp
22+
logdensity(model::DensityModelOrLogDensityModel, t::GradientTransition) = t.lp
2323

2424
propose(rng::Random.AbstractRNG, ::MALA, model) = error("please specify initial parameters")
25-
function transition(sampler::MALA, model::DensityModel, params)
25+
function transition(sampler::MALA, model::DensityModelOrLogDensityModel, params)
2626
return GradientTransition(params, logdensity_and_gradient(model, params)...)
2727
end
2828

29+
check_capabilities(model::DensityModelOrLogDensityModel) = nothing
30+
function check_capabilities(model::AbstractMCMC.LogDensityModel)
31+
cap = LogDensityProblems.capabilities(model.logdensity)
32+
if cap === nothing
33+
throw(ArgumentError("The log density function does not support the LogDensityProblems.jl interface"))
34+
end
35+
36+
if cap === LogDensityProblems.LogDensityOrder{0}()
37+
throw(ArgumentError("The gradient of the log density function is not defined: Implement `LogDensityProblems.logdensity_and_gradient` or use automatic differentiation provided by LogDensityProblemsAD.jl"))
38+
end
39+
end
40+
2941
function AbstractMCMC.step(
3042
rng::Random.AbstractRNG,
31-
model::DensityModel,
43+
model::DensityModelOrLogDensityModel,
3244
sampler::MALA,
3345
transition_prev::GradientTransition;
3446
kwargs...
3547
)
48+
check_capabilities(model)
49+
3650
# Extract value and gradient of the log density of the current state.
3751
state = transition_prev.params
3852
logdensity_state = transition_prev.lp
@@ -76,3 +90,13 @@ function logdensity_and_gradient(model::DensityModel, params)
7690
return value(res), gradient(res)
7791
end
7892

93+
"""
94+
logdensity_and_gradient(model::AbstractMCMC.LogDensityModel, params)
95+
96+
Return the value and gradient of the log density of the parameters `params` for the `model`.
97+
"""
98+
function logdensity_and_gradient(model::AbstractMCMC.LogDensityModel, params)
99+
return LogDensityProblems.logdensity_and_gradient(model.logdensity, params)
100+
end
101+
102+

src/emcee.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ struct Ensemble{D} <: MHSampler
33
proposal::D
44
end
55

6-
function transition(sampler::Ensemble, model::DensityModel, params)
6+
function transition(sampler::Ensemble, model::DensityModelOrLogDensityModel, params)
77
return [Transition(model, x) for x in params]
88
end
99

@@ -13,7 +13,7 @@ end
1313
# (if accepted) or the previous proposal (if not accepted).
1414
function AbstractMCMC.step(
1515
rng::Random.AbstractRNG,
16-
model::DensityModel,
16+
model::DensityModelOrLogDensityModel,
1717
spl::Ensemble,
1818
params_prev::Vector{<:Transition};
1919
kwargs...,
@@ -26,7 +26,7 @@ end
2626
#
2727
# Initial proposal
2828
#
29-
function propose(rng::Random.AbstractRNG, spl::Ensemble, model::DensityModel)
29+
function propose(rng::Random.AbstractRNG, spl::Ensemble, model::DensityModelOrLogDensityModel)
3030
# Make the first proposal with a static draw from the prior.
3131
static_prop = StaticProposal(spl.proposal.proposal)
3232
mh_spl = MetropolisHastings(static_prop)
@@ -39,7 +39,7 @@ end
3939
function propose(
4040
rng::Random.AbstractRNG,
4141
spl::Ensemble,
42-
model::DensityModel,
42+
model::DensityModelOrLogDensityModel,
4343
walkers::Vector{<:Transition},
4444
)
4545
new_walkers = similar(walkers)
@@ -68,7 +68,7 @@ StretchProposal(p) = StretchProposal(p, 2.0)
6868
function move(
6969
rng::Random.AbstractRNG,
7070
spl::Ensemble{<:StretchProposal},
71-
model::DensityModel,
71+
model::DensityModelOrLogDensityModel,
7272
walker::Transition,
7373
other_walker::Transition,
7474
)

src/mcmcchains-connect.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ import .MCMCChains: Chains
33
# A basic chains constructor that works with the Transition struct we defined.
44
function AbstractMCMC.bundle_samples(
55
ts::Vector{<:AbstractTransition},
6-
model::DensityModel,
6+
model::DensityModelOrLogDensityModel,
77
sampler::MHSampler,
88
state,
99
chain_type::Type{Chains};
@@ -32,7 +32,7 @@ end
3232

3333
function AbstractMCMC.bundle_samples(
3434
ts::Vector{<:Transition{<:NamedTuple}},
35-
model::DensityModel,
35+
model::DensityModelOrLogDensityModel,
3636
sampler::MHSampler,
3737
state,
3838
chain_type::Type{Chains};
@@ -71,7 +71,7 @@ end
7171

7272
function AbstractMCMC.bundle_samples(
7373
ts::Vector{<:Vector{<:AbstractTransition}},
74-
model::DensityModel,
74+
model::DensityModelOrLogDensityModel,
7575
sampler::Ensemble,
7676
state,
7777
chain_type::Type{Chains};

src/mh-core.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -48,23 +48,23 @@ end
4848
StaticMH(d) = MetropolisHastings(StaticProposal(d))
4949
RWMH(d) = MetropolisHastings(RandomWalkProposal(d))
5050

51-
function propose(rng::Random.AbstractRNG, sampler::MHSampler, model::DensityModel)
51+
function propose(rng::Random.AbstractRNG, sampler::MHSampler, model::DensityModelOrLogDensityModel)
5252
return propose(rng, sampler.proposal, model)
5353
end
5454
function propose(
5555
rng::Random.AbstractRNG,
5656
sampler::MHSampler,
57-
model::DensityModel,
57+
model::DensityModelOrLogDensityModel,
5858
transition_prev::Transition,
5959
)
6060
return propose(rng, sampler.proposal, model, transition_prev.params)
6161
end
6262

63-
function transition(sampler::MHSampler, model::DensityModel, params)
63+
function transition(sampler::MHSampler, model::DensityModelOrLogDensityModel, params)
6464
logdensity = AdvancedMH.logdensity(model, params)
6565
return transition(sampler, model, params, logdensity)
6666
end
67-
function transition(sampler::MHSampler, model::DensityModel, params, logdensity::Real)
67+
function transition(sampler::MHSampler, model::DensityModelOrLogDensityModel, params, logdensity::Real)
6868
return Transition(params, logdensity)
6969
end
7070

@@ -73,7 +73,7 @@ end
7373
# In this case they are identical.
7474
function AbstractMCMC.step(
7575
rng::Random.AbstractRNG,
76-
model::DensityModel,
76+
model::DensityModelOrLogDensityModel,
7777
sampler::MHSampler;
7878
init_params=nothing,
7979
kwargs...
@@ -89,7 +89,7 @@ end
8989
# or the previous proposal (if not accepted).
9090
function AbstractMCMC.step(
9191
rng::Random.AbstractRNG,
92-
model::DensityModel,
92+
model::DensityModelOrLogDensityModel,
9393
sampler::MHSampler,
9494
transition_prev::AbstractTransition;
9595
kwargs...

src/proposal.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -41,15 +41,15 @@ end
4141
function propose(
4242
rng::Random.AbstractRNG,
4343
proposal::RandomWalkProposal{issymmetric,<:Union{Distribution,AbstractArray}},
44-
::DensityModel
44+
::DensityModelOrLogDensityModel
4545
) where {issymmetric}
4646
return rand(rng, proposal)
4747
end
4848

4949
function propose(
5050
rng::Random.AbstractRNG,
5151
proposal::RandomWalkProposal{issymmetric,<:Union{Distribution,AbstractArray}},
52-
model::DensityModel,
52+
model::DensityModelOrLogDensityModel,
5353
t
5454
) where {issymmetric}
5555
return t + rand(rng, proposal)
@@ -70,7 +70,7 @@ end
7070
function propose(
7171
rng::Random.AbstractRNG,
7272
proposal::StaticProposal{issymmetric,<:Union{Distribution,AbstractArray}},
73-
model::DensityModel,
73+
model::DensityModelOrLogDensityModel,
7474
t=nothing
7575
) where {issymmetric}
7676
return rand(rng, proposal)
@@ -103,15 +103,15 @@ end
103103
function propose(
104104
rng::Random.AbstractRNG,
105105
proposal::Proposal{<:Function},
106-
model::DensityModel
106+
model::DensityModelOrLogDensityModel
107107
)
108108
return propose(rng, proposal(), model)
109109
end
110110

111111
function propose(
112112
rng::Random.AbstractRNG,
113113
proposal::Proposal{<:Function},
114-
model::DensityModel,
114+
model::DensityModelOrLogDensityModel,
115115
t
116116
)
117117
return propose(rng, proposal(t), model)
@@ -132,7 +132,7 @@ end
132132
function propose(
133133
rng::Random.AbstractRNG,
134134
proposals::AbstractArray{<:Proposal},
135-
model::DensityModel,
135+
model::DensityModelOrLogDensityModel,
136136
)
137137
return map(proposals) do proposal
138138
return propose(rng, proposal, model)
@@ -141,7 +141,7 @@ end
141141
function propose(
142142
rng::Random.AbstractRNG,
143143
proposals::AbstractArray{<:Proposal},
144-
model::DensityModel,
144+
model::DensityModelOrLogDensityModel,
145145
ts,
146146
)
147147
return map(proposals, ts) do proposal, t
@@ -152,7 +152,7 @@ end
152152
@generated function propose(
153153
rng::Random.AbstractRNG,
154154
proposals::NamedTuple{names},
155-
model::DensityModel,
155+
model::DensityModelOrLogDensityModel,
156156
) where {names}
157157
isempty(names) && return :(NamedTuple())
158158
expr = Expr(:tuple)
@@ -163,7 +163,7 @@ end
163163
@generated function propose(
164164
rng::Random.AbstractRNG,
165165
proposals::NamedTuple{names},
166-
model::DensityModel,
166+
model::DensityModelOrLogDensityModel,
167167
ts,
168168
) where {names}
169169
isempty(names) && return :(NamedTuple())

src/structarray-connect.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ import .StructArrays: StructArray
33
# A basic chains constructor that works with the Transition struct we defined.
44
function AbstractMCMC.bundle_samples(
55
ts::Vector{<:AbstractTransition},
6-
model::DensityModel,
6+
model::DensityModelOrLogDensityModel,
77
sampler::MHSampler,
88
state,
99
chain_type::Type{StructArray};

0 commit comments

Comments
 (0)