Skip to content

Commit 5538397

Browse files
authored
Add examples for AD backend switching (#423)
* Add examples for AD backend switching * Apply reviews * Fix typo
1 parent ce8fa6f commit 5538397

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

docs/src/autodiff.md

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,11 @@
11
# Gradient in AdvancedHMC.jl
22

3-
AdvancedHMC.jl supports automatic differentiation using [`LogDensityProblemsAD`](https://github.com/tpapp/LogDensityProblemsAD.jl) across various AD backends and allows user-specified gradients. While the default AD backend for AdvancedHMC.jl is ForwardDiff.jl, we can seamlessly change to other backend like Mooncake.jl using various syntax like `Hamiltonian(metric, ℓπ, AutoZygote())`. Different AD backend can also be pluged in using `Hamiltonian(metric, ℓπ, Zygote)`, `Hamiltonian(metric, ℓπ, Val(:Zygote))` but we recommend using ADTypes since that would allow you to have more freedom for specifying the AD backend.
3+
AdvancedHMC.jl supports automatic differentiation using [`LogDensityProblemsAD`](https://github.com/tpapp/LogDensityProblemsAD.jl) across various AD backends and allows user-specified gradients. While the default AD backend for AdvancedHMC.jl is ForwardDiff.jl, we can seamlessly change to other backend like Mooncake.jl using various syntax like `Hamiltonian(metric, ℓπ, AutoMooncake(; config = nothing))`. While some AD backends support syntax like `Hamiltonian(metric, ℓπ, Zygote)`, `Hamiltonian(metric, ℓπ, Val(:Zygote))`, we recommend using ADTypes since that would allow you to have more freedom for specifying the AD backend:
4+
5+
```julia
6+
using AdvancedHMC, ADTypes, DifferentiationInterface, Mooncake, Zygote
7+
hamiltonian = Hamiltonian(metric, ℓπ, AutoMooncake(; config=nothing))
8+
hamiltonian = Hamiltonian(metric, ℓπ, AutoZygote())
9+
```
410

511
In order to use user-specified gradients, please replace ForwardDiff.jl with `ℓπ_grad` in the `Hamiltonian` constructor as `Hamiltonian(metric, ℓπ, ℓπ_grad)`, where the gradient function `ℓπ_grad` should return a tuple containing both the log-posterior and its gradient, for example `ℓπ_grad(x) = (log_posterior, grad)`.

0 commit comments

Comments
 (0)