Skip to content

Commit 1638f06

Browse files
authored
Remove explicit use of LogDensityModel (#76)
* Simplify README * Simplify tests
1 parent ac6bf80 commit 1638f06

File tree

2 files changed

+8
-8
lines changed

2 files changed

+8
-8
lines changed

README.md

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -70,10 +70,9 @@ Quantiles
7070

7171
### Usage with [`LogDensityProblems.jl`](https://github.com/tpapp/LogDensityProblems.jl)
7272

73-
It can also be used with models defining the [`LogDensityProblems.jl`](https://github.com/tpapp/LogDensityProblems.jl) interface by wrapping it in `AbstractMCMC.LogDensityModel` before passing it to `sample`:
73+
Alternatively, you can define your model with the [`LogDensityProblems.jl`](https://github.com/tpapp/LogDensityProblems.jl) interface:
7474

7575
``` julia
76-
using AbstractMCMC: LogDensityModel
7776
using LogDensityProblems
7877

7978
# Use a struct instead of `typeof(density)` for sake of readability.
@@ -83,7 +82,7 @@ LogDensityProblems.logdensity(p::LogTargetDensity, θ) = density(θ) # standard
8382
LogDensityProblems.dimension(p::LogTargetDensity) = 2
8483
LogDensityProblems.capabilities(::LogTargetDensity) = LogDensityProblems.LogDensityOrder{0}()
8584

86-
sample(LogDensityModel(LogTargetDensity()), spl, 100000; param_names=["μ", "σ"], chain_type=Chains)
85+
sample(LogTargetDensity(), spl, 100000; param_names=["μ", "σ"], chain_type=Chains)
8786
```
8887

8988
## Proposals
@@ -156,7 +155,6 @@ takes the gradient computed at the current sample. It is required to specify an
156155
using AdvancedMH
157156
using Distributions
158157
using MCMCChains
159-
using DiffResults
160158
using ForwardDiff
161159
using StructArrays
162160

@@ -181,12 +179,14 @@ spl = MALA(x -> MvNormal((σ² / 2) .* x, σ² * I))
181179
chain = sample(model, spl, 100000; init_params=ones(2), chain_type=StructArray, param_names=["μ", "σ"])
182180
```
183181

184-
### Usage with [`LogDensityProblemsAD.jl`](https://github.com/tpapp/LogDensityProblemsAD.jl)
182+
### Usage with [`LogDensityProblems.jl`](https://github.com/tpapp/LogDensityProblems.jl)
185183

186-
Using our implementation of the `LogDensityProblems.jl` interface from earlier, we can use [`LogDensityProblemsAD.jl`](https://github.com/tpapp/LogDensityProblemsAD.jl) to provide us with the gradient computation used in MALA:
184+
As above, we can define the model with the LogDensityProblems.jl interface.
185+
We can implement the gradient of the log density function manually, or use [`LogDensityProblemsAD.jl`](https://github.com/tpapp/LogDensityProblemsAD.jl) to provide us with the gradient computation used in MALA.
186+
Using our implementation of the `LogDensityProblems.jl` interface above:
187187

188188
```julia
189189
using LogDensityProblemsAD
190190
model_with_ad = LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), LogTargetDensity())
191-
sample(LogDensityModel(model_with_ad), spl, 100000; init_params=ones(2), chain_type=StructArray, param_names=["μ", "σ"])
191+
sample(model_with_ad, spl, 100000; init_params=ones(2), chain_type=StructArray, param_names=["μ", "σ"])
192192
```

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,7 @@ include("util.jl")
265265
@testset "LogDensityProblems interface" begin
266266
admodel = LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), density)
267267
chain2 = sample(
268-
AdvancedMH.AbstractMCMC.LogDensityModel(admodel),
268+
admodel,
269269
spl1,
270270
100000;
271271
init_params=ones(2),

0 commit comments

Comments
 (0)