Skip to content

Commit 1c922e4

Browse files
authored
Some clean up + making MALA available without ForwardDiff (#78)
* removed unnecessary unions and double definitions * make MALA available even when ForwardDiff isn't exportedc * forgot to add src/forwarddiff.jl * version bump * fixed method ambiguity * fix constructor for MALA * added check for capabilties of LogDensityModel
1 parent 1638f06 commit 1c922e4

File tree

4 files changed

+24
-18
lines changed

4 files changed

+24
-18
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "AdvancedMH"
22
uuid = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170"
3-
version = "0.7.1"
3+
version = "0.7.2"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

src/AdvancedMH.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -66,16 +66,15 @@ function Transition(model::AbstractMCMC.LogDensityModel, params)
6666
end
6767

6868
# Calculate the density of the model given some parameterization.
69-
logdensity(model::DensityModelOrLogDensityModel, params) = model.logdensity(params)
70-
logdensity(model::DensityModelOrLogDensityModel, t::Transition) = t.lp
71-
69+
logdensity(model::DensityModel, params) = model.logdensity(params)
70+
logdensity(::DensityModel, t::Transition) = t.lp
7271
logdensity(model::AbstractMCMC.LogDensityModel, params) = LogDensityProblems.logdensity(model.logdensity, params)
73-
logdensity(model::AbstractMCMC.LogDensityModel, t::Transition) = t.lp
72+
logdensity(::AbstractMCMC.LogDensityModel, t::Transition) = t.lp
7473

7574
# A basic chains constructor that works with the Transition struct we defined.
7675
function AbstractMCMC.bundle_samples(
7776
ts::Vector{<:AbstractTransition},
78-
model::Union{<:DensityModelOrLogDensityModel,<:AbstractMCMC.LogDensityModel},
77+
model::DensityModelOrLogDensityModel,
7978
sampler::MHSampler,
8079
state,
8180
chain_type::Type{Vector{NamedTuple}};
@@ -101,7 +100,7 @@ end
101100

102101
function AbstractMCMC.bundle_samples(
103102
ts::Vector{<:Transition{<:NamedTuple}},
104-
model::Union{<:DensityModelOrLogDensityModel,<:AbstractMCMC.LogDensityModel},
103+
model::DensityModelOrLogDensityModel,
105104
sampler::MHSampler,
106105
state,
107106
chain_type::Type{Vector{NamedTuple}};
@@ -122,13 +121,14 @@ function __init__()
122121
@require MCMCChains="c7f686f2-ff18-58e9-bc7b-31028e88f75d" include("mcmcchains-connect.jl")
123122
@require StructArrays="09ab397b-f2b6-538f-b94a-2f83cf4a842a" include("structarray-connect.jl")
124123
@require DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" begin
125-
@require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" include("MALA.jl")
124+
@require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" include("forwarddiff.jl")
126125
end
127126
end
128127

129128
# Include inference methods.
130129
include("proposal.jl")
131130
include("mh-core.jl")
132131
include("emcee.jl")
132+
include("MALA.jl")
133133

134134
end # module AdvancedMH

src/MALA.jl

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,15 @@
1-
using .ForwardDiff: gradient!
2-
using .DiffResults: GradientResult, value, gradient
3-
41
struct MALA{D} <: MHSampler
52
proposal::D
3+
4+
MALA{D}(proposal::D) where {D} = new{D}(proposal)
65
end
76

7+
# If we were given a RandomWalkProposal, just use that instead.
8+
MALA(d::RandomWalkProposal) = MALA{typeof(d)}(d)
89

910
# Create a RandomWalkProposal if we weren't given one already.
1011
MALA(d) = MALA(RandomWalkProposal(d))
1112

12-
# If we were given a RandomWalkProposal, just use that instead.
13-
MALA(d::RandomWalkProposal) = MALA{typeof(d)}(d)
14-
1513

1614
struct GradientTransition{T<:Union{Vector, Real, NamedTuple}, L<:Real, G<:Union{Vector, Real, NamedTuple}} <: AbstractTransition
1715
params::T
@@ -85,9 +83,7 @@ end
8583
Return the value and gradient of the log density of the parameters `params` for the `model`.
8684
"""
8785
function logdensity_and_gradient(model::DensityModel, params)
88-
res = GradientResult(params)
89-
gradient!(res, model.logdensity, params)
90-
return value(res), gradient(res)
86+
error("no implementation exist for `DensityModel`; try importing ForwardDiff.jl")
9187
end
9288

9389
"""
@@ -96,7 +92,8 @@ end
9692
Return the value and gradient of the log density of the parameters `params` for the `model`.
9793
"""
9894
function logdensity_and_gradient(model::AbstractMCMC.LogDensityModel, params)
99-
return LogDensityProblems.logdensity_and_gradient(model.logdensity, params)
95+
check_capabilities(model)
96+
return LogDensityProblems.logdensity_and_gradient(model.logdensity, params)
10097
end
10198

10299

src/forwarddiff.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
using .ForwardDiff: gradient!
2+
using .DiffResults: GradientResult, value, gradient
3+
using .AdvancedMH: AdvancedMH
4+
5+
function AdvancedMH.logdensity_and_gradient(model::AdvancedMH.DensityModel, params)
6+
res = GradientResult(params)
7+
gradient!(res, model.logdensity, params)
8+
return value(res), gradient(res)
9+
end

0 commit comments

Comments
 (0)