Skip to content

Commit 658191d

Browse files
authored
Add extensions with weak dependencies (#81)
* Add extensions with weak dependencies * Fix test tolerance * Do not overwrite method * Fixes * Fix typo * Fix error hint
1 parent 1c922e4 commit 658191d

File tree

8 files changed

+78
-27
lines changed

8 files changed

+78
-27
lines changed

Project.toml

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "AdvancedMH"
22
uuid = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170"
3-
version = "0.7.2"
3+
version = "0.7.3"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
@@ -9,11 +9,26 @@ LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
99
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1010
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
1111

12+
[weakdeps]
13+
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
14+
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
15+
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
16+
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
17+
18+
[extensions]
19+
AdvancedMHForwardDiffExt = ["DiffResults", "ForwardDiff"]
20+
AdvancedMHMCMCChainsExt = "MCMCChains"
21+
AdvancedMHStructArraysExt = "StructArrays"
22+
1223
[compat]
1324
AbstractMCMC = "4"
25+
DiffResults = "1"
1426
Distributions = "0.20, 0.21, 0.22, 0.23, 0.24, 0.25"
27+
ForwardDiff = "0.10"
1528
LogDensityProblems = "2"
29+
MCMCChains = "5, 6"
1630
Requires = "1"
31+
StructArrays = "0.6"
1732
julia = "1.6"
1833

1934
[extras]

ext/AdvancedMHForwardDiffExt.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
module AdvancedMHForwardDiffExt
2+
3+
if isdefined(Base, :get_extension)
4+
import AdvancedMH
5+
import DiffResults
6+
import ForwardDiff
7+
else
8+
import ..AdvancedMH
9+
import ..DiffResults
10+
import ..ForwardDiff
11+
end
12+
13+
function AdvancedMH.logdensity_and_gradient(model::AdvancedMH.DensityModel, params)
14+
res = DiffResults.GradientResult(params)
15+
ForwardDiff.gradient!(res, model.logdensity, params)
16+
return DiffResults.value(res), DiffResults.gradient(res)
17+
end
18+
19+
20+
end # module

src/mcmcchains-connect.jl renamed to ext/AdvancedMHMCMCChainsExt.jl

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,12 @@
1-
import .MCMCChains: Chains
1+
module AdvancedMHMCMCChainsExt
2+
3+
if isdefined(Base, :get_extension)
4+
using AdvancedMH: AbstractMCMC, AbstractTransition, DensityModelOrLogDensityModel, Ensemble, MHSampler, Transition
5+
using MCMCChains: Chains
6+
else
7+
using ..AdvancedMH: AbstractMCMC, AbstractTransition, DensityModelOrLogDensityModel, Ensemble, MHSampler, Transition
8+
using ..MCMCChains: Chains
9+
end
210

311
# A basic chains constructor that works with the Transition struct we defined.
412
function AbstractMCMC.bundle_samples(
@@ -111,3 +119,5 @@ function AbstractMCMC.bundle_samples(
111119
start=discard_initial + 1, thin=thinning,
112120
)
113121
end
122+
123+
end # module

src/structarray-connect.jl renamed to ext/AdvancedMHStructArraysExt.jl

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,12 @@
1-
import .StructArrays: StructArray
1+
module AdvancedMHStructArraysExt
2+
3+
if isdefined(Base, :get_extension)
4+
using AdvancedMH: AbstractMCMC, AbstractTransition, DensityModelOrLogDensityModel, MHSampler
5+
using StructArrays: StructArray
6+
else
7+
using ..AdvancedMH: AbstractMCMC, AbstractTransition, DensityModelOrLogDensityModel, MHSampler
8+
using ..StructArrays: StructArray
9+
end
210

311
# A basic chains constructor that works with the Transition struct we defined.
412
function AbstractMCMC.bundle_samples(
@@ -17,3 +25,6 @@ function AbstractMCMC.bundle_samples(
1725
end
1826

1927
AbstractMCMC.chainscat(c::StructArray, cs::StructArray...) = vcat(c, cs...)
28+
29+
30+
end # module

src/AdvancedMH.jl

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ module AdvancedMH
33
# Import the relevant libraries.
44
using AbstractMCMC
55
using Distributions
6-
using Requires
76

87
using LogDensityProblems: LogDensityProblems
98

@@ -117,11 +116,23 @@ function AbstractMCMC.bundle_samples(
117116
return nts
118117
end
119118

119+
if !isdefined(Base, :get_extension)
120+
using Requires
121+
end
122+
120123
function __init__()
121-
@require MCMCChains="c7f686f2-ff18-58e9-bc7b-31028e88f75d" include("mcmcchains-connect.jl")
122-
@require StructArrays="09ab397b-f2b6-538f-b94a-2f83cf4a842a" include("structarray-connect.jl")
123-
@require DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" begin
124-
@require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" include("forwarddiff.jl")
124+
# Better error message if users forget to load ForwardDiff
125+
Base.Experimental.register_error_hint(MethodError) do io, exc, arg_types, kwargs
126+
if exc.f === logdensity_and_gradient && length(arg_types) == 2 && first(arg_types) <: DensityModel && isempty(kwargs)
127+
print(io, "\\nDid you forget to load ForwardDiff?")
128+
end
129+
end
130+
@static if !isdefined(Base, :get_extension)
131+
@require MCMCChains="c7f686f2-ff18-58e9-bc7b-31028e88f75d" include("../ext/AdvancedMHMCMCChainsExt.jl")
132+
@require StructArrays="09ab397b-f2b6-538f-b94a-2f83cf4a842a" include("../ext/AdvancedMHStructArraysExt.jl")
133+
@require DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" begin
134+
@require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" include("../ext/AdvancedMHForwardDiffExt.jl")
135+
end
125136
end
126137
end
127138

src/MALA.jl

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -78,19 +78,12 @@ function AbstractMCMC.step(
7878
end
7979

8080
"""
81-
logdensity_and_gradient(model::DensityModel, params)
81+
logdensity_and_gradient(model::AdvancedMH.DensityModelOrLogDensityModel, params)
8282
8383
Return the value and gradient of the log density of the parameters `params` for the `model`.
8484
"""
85-
function logdensity_and_gradient(model::DensityModel, params)
86-
error("no implementation exist for `DensityModel`; try importing ForwardDiff.jl")
87-
end
88-
89-
"""
90-
logdensity_and_gradient(model::AbstractMCMC.LogDensityModel, params)
85+
logdensity_and_gradient(::DensityModelOrLogDensityModel, ::Any)
9186

92-
Return the value and gradient of the log density of the parameters `params` for the `model`.
93-
"""
9487
function logdensity_and_gradient(model::AbstractMCMC.LogDensityModel, params)
9588
check_capabilities(model)
9689
return LogDensityProblems.logdensity_and_gradient(model.logdensity, params)

src/forwarddiff.jl

Lines changed: 0 additions & 9 deletions
This file was deleted.

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ include("util.jl")
9898
)
9999
@test chain1b isa Chains
100100
@test range(chain1b) == range(26; step=4, length=10_000)
101-
@test mean(chain1b["μ"]) 0.0 atol=0.1
101+
@test mean(chain1b["μ"]) 0.0 atol=0.15
102102
@test mean(chain1b["σ"]) 1.0 atol=0.1
103103

104104
# NamedTuple of parameters

0 commit comments

Comments
 (0)