Commit d77bbac

update docs and bump version to 0.10.4
1 parent 8dfc800 commit d77bbac

3 files changed: +169 -43 lines changed

JuliaBUGS/History.md

Lines changed: 10 additions & 0 deletions
@@ -1,5 +1,15 @@
 # JuliaBUGS Changelog

+## 0.10.4
+
+- **DifferentiationInterface.jl integration**: JuliaBUGS now uses [DifferentiationInterface.jl](https://github.com/JuliaDiff/DifferentiationInterface.jl) for automatic differentiation, providing a unified interface to multiple AD backends.
+  - Add `adtype` parameter to `compile()` function for specifying AD backends
+  - Support convenient symbol shortcuts: `:ReverseDiff`, `:ForwardDiff`, `:Zygote`, `:Enzyme`
+  - Gradient computation is prepared during compilation for optimal performance
+  - Example: `model = compile(model_def, data; adtype=:ReverseDiff)`
+  - Full control available via explicit ADTypes: `adtype=AutoReverseDiff(compile=true)`
+  - Backward compatible: models without `adtype` work as before
+
 ## 0.10.1

 Expose docs for changes in [v0.10.0](https://github.com/TuringLang/JuliaBUGS.jl/releases/tag/JuliaBUGS-v0.10.0)

JuliaBUGS/Project.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 name = "JuliaBUGS"
 uuid = "ba9fb4c0-828e-4473-b6a1-cd2560fee5bf"
-version = "0.10.3"
+version = "0.10.4"

 [deps]
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

JuliaBUGS/docs/src/example.md

Lines changed: 158 additions & 42 deletions
@@ -190,40 +190,62 @@ initialize!(model, initializations)
 initialize!(model, rand(26))
 ```

-`LogDensityProblemsAD.jl` defined some extensions that support automatic differentiation packages.
-For example, with `ReverseDiff.jl`
+### Automatic Differentiation
+
+JuliaBUGS integrates with automatic differentiation (AD) through [DifferentiationInterface.jl](https://github.com/JuliaDiff/DifferentiationInterface.jl), enabling gradient-based inference methods like Hamiltonian Monte Carlo (HMC) and No-U-Turn Sampler (NUTS).
+
+#### Specifying an AD Backend
+
+To compile a model with gradient support, pass the `adtype` parameter to `compile`:

 ```julia
-using LogDensityProblemsAD, ReverseDiff
+# Using explicit ADType from ADTypes.jl
+using ADTypes
+model = compile(model_def, data; adtype=AutoReverseDiff(compile=true))
+
+# Using convenient symbol shortcuts
+model = compile(model_def, data; adtype=:ReverseDiff) # Equivalent to above
+```

-ad_model = ADgradient(:ReverseDiff, model; compile=Val(true))
+Available AD backends include:
+- `:ReverseDiff` - ReverseDiff with tape compilation (recommended for most models)
+- `:ForwardDiff` - ForwardDiff (efficient for models with few parameters)
+- `:Zygote` - Zygote (source-to-source AD)
+- `:Enzyme` - Enzyme (experimental, high-performance)
+
+For fine-grained control, use explicit `ADTypes` constructors:
+
+```julia
+# ReverseDiff without compilation
+model = compile(model_def, data; adtype=AutoReverseDiff(compile=false))
 ```

-Here `ad_model` will also implement all the interfaces of [`LogDensityProblems.jl`](https://www.tamaspapp.eu/LogDensityProblems.jl/dev/).
-`LogDensityProblemsAD.jl` will automatically add the interface function [`logdensity_and_gradient`](https://www.tamaspapp.eu/LogDensityProblems.jl/dev/#LogDensityProblems.logdensity_and_gradient) to the model, which will return the log density and gradient of the model.
-And `ad_model` can be used in the same way as `model` in the example below.
+The compiled model with gradient support implements the [`LogDensityProblems.jl`](https://www.tamaspapp.eu/LogDensityProblems.jl/dev/) interface, including [`logdensity_and_gradient`](https://www.tamaspapp.eu/LogDensityProblems.jl/dev/#LogDensityProblems.logdensity_and_gradient), which returns both the log density and its gradient.
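As a rough illustration of the `LogDensityProblems.jl` interface named above, the sketch below implements `logdensity_and_gradient` for a hypothetical toy target (`ToyNormal`, a standard normal, not a JuliaBUGS model) using ForwardDiff directly. It only shows the shape of the calls a gradient-based sampler relies on; it is an assumption-laden stand-in and does not reflect how JuliaBUGS wires its gradients internally.

```julia
using LogDensityProblems, ForwardDiff

# Hypothetical toy target: a standard normal in `dim` dimensions.
# This struct is illustrative only and is not part of JuliaBUGS.
struct ToyNormal
    dim::Int
end

# Minimal LogDensityProblems.jl interface.
LogDensityProblems.dimension(p::ToyNormal) = p.dim
LogDensityProblems.logdensity(::ToyNormal, x) = -sum(abs2, x) / 2
# Declare first-order capability, i.e. gradients are available.
LogDensityProblems.capabilities(::Type{ToyNormal}) =
    LogDensityProblems.LogDensityOrder{1}()

# Return the log density together with its gradient.
function LogDensityProblems.logdensity_and_gradient(p::ToyNormal, x)
    f = z -> LogDensityProblems.logdensity(p, z)
    return f(x), ForwardDiff.gradient(f, x)
end

target = ToyNormal(3)
x = [0.5, -1.0, 2.0]
logp, grad = LogDensityProblems.logdensity_and_gradient(target, x)
# For this target, logp is -sum(x .^ 2) / 2 and grad is -x.
```

A model compiled with an `adtype` is described above as exposing the same two entry points, so `LogDensityProblems.dimension(model)` and `LogDensityProblems.logdensity_and_gradient(model, θ)` would be the analogous calls.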

 ### Inference

-For a differentiable model, we can use [`AdvancedHMC.jl`](https://github.com/TuringLang/AdvancedHMC.jl) to perform inference.
-For instance,
+For gradient-based inference, we use [`AdvancedHMC.jl`](https://github.com/TuringLang/AdvancedHMC.jl) with models compiled with an `adtype`:

 ```julia
-using AdvancedHMC, AbstractMCMC, LogDensityProblems, MCMCChains
+using AdvancedHMC, AbstractMCMC, LogDensityProblems, MCMCChains, ReverseDiff
+
+# Compile with gradient support
+model = compile(model_def, data; adtype=:ReverseDiff)

 n_samples, n_adapts = 2000, 1000

 D = LogDensityProblems.dimension(model); initial_θ = rand(D)

 samples_and_stats = AbstractMCMC.sample(
-    ad_model,
+    model,
     NUTS(0.8),
     n_samples;
     chain_type = Chains,
     n_adapts = n_adapts,
     init_params = initial_θ,
     discard_initial = n_adapts
 )
+describe(samples_and_stats)
 ```

 This will return the MCMC Chain,
@@ -234,39 +256,72 @@ Chains MCMC chain (2000×40×1 Array{Real, 3}):
 Iterations = 1001:1:3000
 Number of chains = 1
 Samples per chain = 2000
-parameters = alpha0, alpha12, alpha1, alpha2, tau, b[16], b[12], b[10], b[14], b[13], b[7], b[6], b[20], b[1], b[4], b[5], b[2], b[18], b[8], b[3], b[9], b[21], b[17], b[15], b[11], b[19], sigma
+parameters = tau, alpha12, alpha2, alpha1, alpha0, b[1], b[2], b[3], b[4], b[5], b[6], b[7], b[8], b[9], b[10], b[11], b[12], b[13], b[14], b[15], b[16], b[17], b[18], b[19], b[20], b[21], sigma
 internals = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size, is_adapt

 Summary Statistics
-  parameters      mean       std     mcse   ess_bulk   ess_tail    rhat  ess_per_sec
-      Symbol   Float64   Float64  Float64       Real    Float64 Float64      Missing
-
-      alpha0   -0.5642    0.2320   0.0084   766.9305  1022.5211  1.0021      missing
-     alpha12   -0.8489    0.5247   0.0170   946.0418  1044.1109  1.0002      missing
-      alpha1    0.0587    0.3715   0.0119   966.4367  1233.2257  1.0007      missing
-      alpha2    1.3852    0.3410   0.0127   712.2978   974.1566  1.0002      missing
-         tau    1.8880    0.7705   0.0447   348.9331   338.3655  1.0030      missing
-       b[16]   -0.2445    0.4459   0.0132  1528.0578   843.8225  1.0003      missing
-       b[12]    0.2050    0.3602   0.0086  1868.6126  1202.1363  0.9996      missing
-       b[10]   -0.3500    0.2893   0.0090  1047.3119  1245.9358  1.0008      missing
-           ⋮         ⋮         ⋮        ⋮          ⋮          ⋮       ⋮            ⋮
-  19 rows omitted
+  parameters      mean       std     mcse   ess_bulk   ess_tail    rhat  ess_per_sec
+      Symbol   Float64   Float64  Float64       Real    Float64 Float64      Missing
+
+         tau   73.1490  193.8441  43.2582    56.3430    20.6688  1.0155      missing
+     alpha12   -0.8052    0.4392   0.0158   761.2180  1049.1664  1.0020      missing
+      alpha2    1.3428    0.2813   0.0140   422.8810  1013.2570  1.0061      missing
+      alpha1    0.0845    0.3126   0.0113   773.2202   981.8487  1.0051      missing
+      alpha0   -0.5480    0.1944   0.0087   537.6212  1156.2083  1.0014      missing
+        b[1]   -0.1905    0.2540   0.0129   374.3372   971.7526  1.0034      missing
+        b[2]    0.0161    0.2178   0.0056  1505.6353  1002.8787  1.0001      missing
+        b[3]   -0.1986    0.2375   0.0128   367.6766  1287.8215  1.0015      missing
+        b[4]    0.2792    0.2498   0.0163   201.1558  1168.7538  1.0068      missing
+        b[5]    0.1170    0.2397   0.0092   659.5422  1484.8584  1.0016      missing
+        b[6]    0.0667    0.2821   0.0074  1745.5567   902.1014  1.0067      missing
+        b[7]    0.0597    0.2218   0.0055  1589.5590  1145.6017  1.0065      missing
+        b[8]    0.1769    0.2316   0.0102   554.5974  1318.8089  1.0001      missing
+        b[9]   -0.1257    0.2233   0.0073   930.0346  1186.4283  1.0031      missing
+       b[10]   -0.2513    0.2392   0.0159   213.6323  1142.4487  1.0096      missing
+       b[11]    0.0768    0.2783   0.0081  1376.5999  1218.1537  1.0009      missing
+       b[12]    0.1171    0.2768   0.0079  1354.9409  1130.8217  1.0052      missing
+       b[13]   -0.0688    0.2433   0.0055  1895.0387  1527.7066  1.0010      missing
+       b[14]   -0.1363    0.2558   0.0075  1276.0992  1208.8587  1.0001      missing
+       b[15]    0.2334    0.2757   0.0135   439.2241   837.3396  1.0036      missing
+       b[16]   -0.1212    0.3024   0.0106  1093.4416   914.9457  0.9997      missing
+       b[17]   -0.2120    0.3142   0.0166   360.6420   702.4098  1.0009      missing
+       b[18]    0.0346    0.2282   0.0056  1665.0325  1281.7179  1.0011      missing
+       b[19]   -0.0244    0.2400   0.0052  2186.7638  1179.6971  1.0132      missing
+       b[20]    0.2108    0.2421   0.0131   349.7657  1263.5781  1.0016      missing
+       b[21]   -0.0509    0.2813   0.0061  2200.5614   916.6256  0.9998      missing
+       sigma    0.2797    0.1362   0.0168    56.3430    21.4971  1.0123      missing

 Quantiles
-  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
-      Symbol   Float64   Float64   Float64   Float64   Float64
-
-      alpha0   -1.0143   -0.7143   -0.5590   -0.4100   -0.1185
-     alpha12   -1.9063   -1.1812   -0.8296   -0.5153    0.1521
-      alpha1   -0.6550   -0.1822    0.0512    0.2885    0.8180
-      alpha2    0.7214    1.1663    1.3782    1.5998    2.0986
-         tau    0.5461    1.3941    1.8353    2.3115    3.6225
-       b[16]   -1.2359   -0.4836   -0.1909    0.0345    0.5070
-       b[12]   -0.4493   -0.0370    0.1910    0.4375    0.9828
-       b[10]   -0.9570   -0.5264   -0.3331   -0.1514    0.1613
-           ⋮         ⋮         ⋮         ⋮         ⋮         ⋮
-  19 rows omitted
-
+  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
+      Symbol   Float64   Float64   Float64   Float64   Float64
+
+         tau    3.1280    7.4608   13.0338   28.2289  929.6520
+     alpha12   -1.6645   -1.0887   -0.7952   -0.5635    0.1162
+      alpha2    0.8398    1.1494    1.3233    1.5337    1.9177
+      alpha1   -0.5796   -0.1059    0.1042    0.2883    0.6702
+      alpha0   -0.9340   -0.6751   -0.5463   -0.4086   -0.1752
+        b[1]   -0.7430   -0.3415   -0.1566   -0.0074    0.2535
+        b[2]   -0.4261   -0.1083    0.0192    0.1420    0.4810
+        b[3]   -0.7394   -0.3377   -0.1687   -0.0242    0.2041
+        b[4]   -0.1108    0.0873    0.2409    0.4375    0.8267
+        b[5]   -0.3141   -0.0458    0.0900    0.2563    0.6489
+        b[6]   -0.4679   -0.0896    0.0291    0.2202    0.7060
+        b[7]   -0.3861   -0.0685    0.0534    0.1847    0.5207
+        b[8]   -0.2326    0.0221    0.1505    0.3162    0.6861
+        b[9]   -0.6007   -0.2482   -0.0984    0.0057    0.2771
+       b[10]   -0.7936   -0.4108   -0.2255   -0.0617    0.1290
+       b[11]   -0.4381   -0.0796    0.0353    0.2178    0.7232
+       b[12]   -0.3806   -0.0451    0.0750    0.2671    0.7625
+       b[13]   -0.5841   -0.2135   -0.0443    0.0652    0.4055
+       b[14]   -0.6854   -0.2872   -0.1015    0.0147    0.3476
+       b[15]   -0.2054    0.0257    0.1898    0.4004    0.8660
+       b[16]   -0.8173   -0.2829   -0.0804    0.0532    0.4094
+       b[17]   -0.9071   -0.3911   -0.1595    0.0099    0.2864
+       b[18]   -0.4526   -0.0919    0.0140    0.1686    0.4985
+       b[19]   -0.5055   -0.1547   -0.0091    0.1134    0.4528
+       b[20]   -0.2120    0.0318    0.1788    0.3673    0.7416
+       b[21]   -0.6482   -0.2044   -0.0263    0.1051    0.5246
+       sigma    0.0328    0.1882    0.2770    0.3661    0.5654
 ```

 This is consistent with the result in the [OpenBUGS seeds example](https://chjackson.github.io/openbugsdoc/Examples/Seeds.html).
@@ -283,7 +338,7 @@ The model compilation code remains the same, and we can sample multiple chains i
 ```julia
 n_chains = 4
 samples_and_stats = AbstractMCMC.sample(
-    ad_model,
+    model,
     AdvancedHMC.NUTS(0.65),
     AbstractMCMC.MCMCThreads(),
     n_samples,
@@ -311,7 +366,7 @@ For example:

 ```julia
 @everywhere begin
-    using JuliaBUGS, LogDensityProblems, LogDensityProblemsAD, AbstractMCMC, AdvancedHMC, MCMCChains, ReverseDiff # also other packages one may need
+    using JuliaBUGS, LogDensityProblems, AbstractMCMC, AdvancedHMC, MCMCChains, ADTypes, ReverseDiff

     # Define the functions to use
     # Use `@bugs_primitive` to register the functions to use in the model
@@ -322,7 +377,7 @@ end

 n_chains = nprocs() - 1 # use all the processes except the parent process
 samples_and_stats = AbstractMCMC.sample(
-    ad_model,
+    model,
     AdvancedHMC.NUTS(0.65),
     AbstractMCMC.MCMCDistributed(),
     n_samples,
@@ -342,6 +397,67 @@ In this case, we pass two additional arguments to `AbstractMCMC.sample`:
 Note that the `init_params` argument is now a vector of initial parameters for each chain.
 Sometimes the progress logger can cause problems in distributed setting, so we can disable it by setting `progress = false`.

+## Choosing an Automatic Differentiation Backend
+
+JuliaBUGS integrates with multiple automatic differentiation (AD) backends through [DifferentiationInterface.jl](https://github.com/JuliaDiff/DifferentiationInterface.jl), providing flexibility to choose the most suitable backend for your model.
+
+### Available Backends
+
+The following AD backends are supported via convenient symbol shortcuts:
+
+- **`:ReverseDiff`** (recommended) — Tape-based reverse-mode AD, highly efficient for models with many parameters. Uses compilation by default for optimal performance.
+- **`:ForwardDiff`** — Forward-mode AD, efficient for models with few parameters (typically < 20).
+- **`:Zygote`** — Source-to-source reverse-mode AD, general-purpose but may be slower than ReverseDiff for many models.
+- **`:Enzyme`** — Experimental high-performance AD backend with LLVM-level transformations.
+
+### Usage Examples
+
+#### Basic Usage with Symbol Shortcuts
+
+The simplest way to specify an AD backend is using symbol shortcuts:
+
+```julia
+# ReverseDiff with tape compilation (recommended for most models)
+model = compile(model_def, data; adtype=:ReverseDiff)
+
+# ForwardDiff (good for small models with few parameters)
+model = compile(model_def, data; adtype=:ForwardDiff)
+
+# Zygote (source-to-source AD)
+model = compile(model_def, data; adtype=:Zygote)
+```
+
+#### Advanced Configuration
+
+For fine-grained control, use explicit `ADTypes` constructors:
+
+```julia
+using ADTypes
+
+# ReverseDiff without tape compilation
+model = compile(model_def, data; adtype=AutoReverseDiff(compile=false))
+
+# ReverseDiff with compilation (equivalent to :ReverseDiff)
+model = compile(model_def, data; adtype=AutoReverseDiff(compile=true))
+```
+
+### Performance Considerations
+
+- **ReverseDiff with compilation** (`:ReverseDiff`) is recommended for most models, especially those with many parameters. Compilation adds a one-time overhead but significantly speeds up subsequent gradient evaluations.
+
+- **ForwardDiff** (`:ForwardDiff`) is often faster for models with few parameters (< 20), as it avoids tape construction overhead.
+
+- **Tape compilation trade-off**: While `AutoReverseDiff(compile=true)` has higher initial compilation time, it provides faster gradient evaluations during sampling. For quick prototyping or models that will only be sampled a few times, `AutoReverseDiff(compile=false)` may be preferable.
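The trade-off described above can be explored outside JuliaBUGS by calling DifferentiationInterface.jl directly. The following is a minimal sketch, assuming a recent DifferentiationInterface.jl release in which the preparation object is passed as the second argument; `toy_logdensity` is a hypothetical stand-in for a model's log density, not a JuliaBUGS model. Each backend is prepared once and then reused, which mirrors the idea of preparing the gradient at compile time.

```julia
using ADTypes: AutoForwardDiff, AutoReverseDiff
using DifferentiationInterface: prepare_gradient, value_and_gradient
import ForwardDiff, ReverseDiff  # load the backend packages themselves

# Hypothetical stand-in for a log density (standard normal, up to a constant).
toy_logdensity(x) = -sum(abs2, x) / 2

x = randn(10)

backends = (
    AutoForwardDiff(),                # forward mode, no tape
    AutoReverseDiff(compile=false),   # reverse mode, tape not compiled
    AutoReverseDiff(compile=true),    # reverse mode, tape compiled during preparation
)

# Prepare once per backend, then evaluate the value and gradient at x.
grads = map(backends) do backend
    prep = prepare_gradient(toy_logdensity, backend, x)
    _, grad = value_and_gradient(toy_logdensity, prep, backend, x)
    grad
end

# All three backends should agree with the analytic gradient -x.
@assert all(g -> g ≈ -x, grads)
```

Wrapping the two calls in `@time` or a benchmarking macro shows where the cost lands for each backend. One caveat worth keeping in mind: a compiled ReverseDiff tape replays the operations recorded during preparation, so a function whose control flow depends on the input values may give incorrect gradients when the tape is reused at other inputs.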
+
+### Compatibility
+
+All models compiled with an `adtype` implement the full [`LogDensityProblems.jl`](https://www.tamaspapp.eu/LogDensityProblems.jl/dev/) interface, making them compatible with:
+
+- [`AdvancedHMC.jl`](https://github.com/TuringLang/AdvancedHMC.jl) — NUTS and HMC samplers
+- Any other sampler that works with the LogDensityProblems interface
+
+The gradient computation is prepared during model compilation for optimal performance during sampling.
+
 ## More Examples

 We have transcribed all the examples from the first volume of the BUGS Examples ([original](https://www.multibugs.org/examples/latest/VolumeI.html) and [transcribed](https://github.com/TuringLang/JuliaBUGS.jl/tree/main/JuliaBUGS/src/BUGSExamples/Volume_1)). All programs and data are included, and can be compiled using the steps described in the tutorial above.