Skip to content

Commit 7c322eb

Browse files
committed
added example of using the LogDensityInterface to the README
1 parent 75bffde commit 7c322eb

File tree

1 file changed

+28
-0
lines changed

1 file changed

+28
-0
lines changed

README.md

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,24 @@ Quantiles
6868

6969
```
7070

71+
### Usage with [`LogDensityProblems.jl`](https://github.com/tpapp/LogDensityProblems.jl)
72+
73+
It can also be used with models defining the [`LogDensityProblems.jl`](https://github.com/tpapp/LogDensityProblems.jl) interface by wrapping it in `AbstractMCMC.LogDensityModel` before passing it to `sample`:
74+
75+
``` julia
76+
using AbstractMCMC: LogDensityModel
77+
using LogDensityProblems
78+
79+
# Use a struct instead of `typeof(density)` for sake of readability.
80+
struct LogTargetDensity end
81+
82+
LogDensityProblems.logdensity(p::LogTargetDensity, θ) = density(θ) # standard multivariate normal
83+
LogDensityProblems.dimension(p::LogTargetDensity) = 2
84+
LogDensityProblems.capabilities(::LogTargetDensity) = LogDensityProblems.LogDensityOrder{0}()
85+
86+
sample(LogDensityModel(LogTargetDensity()), spl, 100000; param_names=["μ", "σ"], chain_type=Chains)
87+
```
88+
7189
## Proposals
7290

7391
AdvancedMH offers various methods of defining your inference problem. Behind the scenes, a `MetropolisHastings` sampler simply holds
@@ -157,3 +175,13 @@ spl = MALA(x -> MvNormal((σ² / 2) .* x, σ² * I))
157175
# Sample from the posterior.
158176
chain = sample(model, spl, 100000; init_params=ones(2), chain_type=StructArray, param_names=["μ", "σ"])
159177
```
178+
179+
### Usage with [`LogDensityProblemsAD.jl`](https://github.com/tpapp/LogDensityProblemsAD.jl)
180+
181+
Using our implementation of the `LogDensityProblems.jl` interface from earlier, we can use [`LogDensityProblemsAD.jl`](https://github.com/tpapp/LogDensityProblemsAD.jl) to provide us with the gradient computation used in MALA:
182+
183+
```julia
184+
using LogDensityProblemsAD
185+
model_with_ad = LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), LogTargetDensity())
186+
sample(LogDensityModel(model_with_ad), spl, 100000; init_params=ones(2), chain_type=StructArray, param_names=["μ", "σ"])
187+
```

0 commit comments

Comments
 (0)