Skip to content

Commit 8dfc800

Browse files
Simplify the workflow for computing model gradients
1 parent 53fd7bb commit 8dfc800

14 files changed

+397
-33
lines changed

.gitignore

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,4 +29,6 @@ Manifest.toml
2929
*.local.*
3030

3131
# gitingest generated files
32-
digest.txt
32+
digest.txt
33+
34+
tmp/

JuliaBUGS/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
99
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
1010
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
1111
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
12+
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
1213
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
1314
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
1415
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
@@ -52,6 +53,7 @@ AdvancedHMC = "0.6, 0.7, 0.8"
5253
AdvancedMH = "0.8"
5354
BangBang = "0.4.1"
5455
Bijectors = "0.13, 0.14, 0.15.5"
56+
DifferentiationInterface = "0.7"
5557
Distributions = "0.23.8, 0.24, 0.25"
5658
Documenter = "0.27, 1"
5759
GLMakie = "0.10, 0.11, 0.12, 0.13"

JuliaBUGS/ext/JuliaBUGSAdvancedHMCExt.jl

Lines changed: 42 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,12 @@ module JuliaBUGSAdvancedHMCExt
33
using AbstractMCMC
44
using AdvancedHMC
55
using ADTypes
6+
import DifferentiationInterface as DI
67
using JuliaBUGS
7-
using JuliaBUGS: BUGSModel, getparams, initialize!
8+
using JuliaBUGS: BUGSModel, BUGSModelWithGradient, getparams, initialize!
89
using JuliaBUGS.LogDensityProblems
910
using JuliaBUGS.LogDensityProblemsAD
11+
using JuliaBUGS.Model: _logdensity_switched
1012
using JuliaBUGS.Random
1113
using MCMCChains: Chains
1214

@@ -40,10 +42,13 @@ end
4042
function _gibbs_internal_hmc(
4143
rng::Random.AbstractRNG, cond_model::BUGSModel, sampler, ad_backend, state
4244
)
43-
# Wrap model with AD gradient computation
44-
logdensitymodel = AbstractMCMC.LogDensityModel(
45-
LogDensityProblemsAD.ADgradient(ad_backend, cond_model)
45+
# Create gradient model on-the-fly using DifferentiationInterface
46+
x = getparams(cond_model)
47+
prep = DI.prepare_gradient(
48+
_logdensity_switched, ad_backend, x, DI.Constant(cond_model)
4649
)
50+
ad_model = BUGSModelWithGradient(ad_backend, prep, cond_model)
51+
logdensitymodel = AbstractMCMC.LogDensityModel(ad_model)
4752

4853
# Take HMC/NUTS step
4954
if isnothing(state)
@@ -53,7 +58,7 @@ function _gibbs_internal_hmc(
5358
logdensitymodel,
5459
sampler;
5560
n_adapts=0, # Disable adaptation within Gibbs
56-
initial_params=getparams(cond_model),
61+
initial_params=x,
5762
)
5863
else
5964
# Use existing state for subsequent steps
@@ -67,7 +72,7 @@ end
6772

6873
function AbstractMCMC.bundle_samples(
6974
ts::Vector{<:AdvancedHMC.Transition},
70-
logdensitymodel::AbstractMCMC.LogDensityModel{<:LogDensityProblemsAD.ADGradientWrapper},
75+
logdensitymodel::AbstractMCMC.LogDensityModel{<:BUGSModelWithGradient},
7176
sampler::AdvancedHMC.AbstractHMCSampler,
7277
state,
7378
chain_type::Type{Chains};
@@ -98,4 +103,35 @@ function AbstractMCMC.bundle_samples(
98103
)
99104
end
100105

106+
# Keep backward compatibility with LogDensityProblemsAD wrapper
107+
function AbstractMCMC.bundle_samples(
108+
ts::Vector{<:AdvancedHMC.Transition},
109+
logdensitymodel::AbstractMCMC.LogDensityModel{<:LogDensityProblemsAD.ADGradientWrapper},
110+
sampler::AdvancedHMC.AbstractHMCSampler,
111+
state,
112+
chain_type::Type{Chains};
113+
discard_initial=0,
114+
thinning=1,
115+
kwargs...,
116+
)
117+
param_samples = [t.z.θ for t in ts]
118+
119+
stats_names = collect(keys(merge((; lp=ts[1].z.ℓπ.value), AdvancedHMC.stat(ts[1]))))
120+
stats_values = [
121+
vcat(ts[i].z.ℓπ.value, collect(values(AdvancedHMC.stat(ts[i])))) for
122+
i in eachindex(ts)
123+
]
124+
125+
# Delegate to gen_chains for proper parameter naming from BUGSModel
126+
return JuliaBUGS.gen_chains(
127+
logdensitymodel,
128+
param_samples,
129+
stats_names,
130+
stats_values;
131+
discard_initial=discard_initial,
132+
thinning=thinning,
133+
kwargs...,
134+
)
135+
end
136+
101137
end

JuliaBUGS/ext/JuliaBUGSAdvancedMHExt.jl

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,12 @@ module JuliaBUGSAdvancedMHExt
33
using AbstractMCMC
44
using AdvancedMH
55
using ADTypes
6+
import DifferentiationInterface as DI
67
using JuliaBUGS
7-
using JuliaBUGS: BUGSModel, getparams, initialize!
8+
using JuliaBUGS: BUGSModel, BUGSModelWithGradient, getparams, initialize!
89
using JuliaBUGS.LogDensityProblems
910
using JuliaBUGS.LogDensityProblemsAD
11+
using JuliaBUGS.Model: _logdensity_switched
1012
using JuliaBUGS.Random
1113
using MCMCChains: Chains
1214

@@ -52,10 +54,13 @@ end
5254
function _gibbs_internal_mh(
5355
rng::Random.AbstractRNG, cond_model::BUGSModel, sampler, ad_backend, state
5456
)
55-
# Wrap model with AD gradient computation for gradient-based proposals
56-
logdensitymodel = AbstractMCMC.LogDensityModel(
57-
LogDensityProblemsAD.ADgradient(ad_backend, cond_model)
57+
# Create gradient model on-the-fly using DifferentiationInterface
58+
x = getparams(cond_model)
59+
prep = DI.prepare_gradient(
60+
_logdensity_switched, ad_backend, x, DI.Constant(cond_model)
5861
)
62+
ad_model = BUGSModelWithGradient(ad_backend, prep, cond_model)
63+
logdensitymodel = AbstractMCMC.LogDensityModel(ad_model)
5964

6065
# Take MH step with gradient information
6166
if isnothing(state)
@@ -64,7 +69,7 @@ function _gibbs_internal_mh(
6469
logdensitymodel,
6570
sampler;
6671
n_adapts=0, # Disable adaptation within Gibbs
67-
initial_params=getparams(cond_model),
72+
initial_params=x,
6873
)
6974
else
7075
t, s = AbstractMCMC.step(rng, logdensitymodel, sampler, state; n_adapts=0)
@@ -103,6 +108,32 @@ function AbstractMCMC.bundle_samples(
103108
)
104109
end
105110

111+
function AbstractMCMC.bundle_samples(
112+
ts::Vector{<:AdvancedMH.Transition},
113+
logdensitymodel::AbstractMCMC.LogDensityModel{<:BUGSModelWithGradient},
114+
sampler::AdvancedMH.MHSampler,
115+
state,
116+
chain_type::Type{Chains};
117+
discard_initial=0,
118+
thinning=1,
119+
kwargs...,
120+
)
121+
param_samples = [t.params for t in ts]
122+
stats_names = [:lp]
123+
stats_values = [[t.lp] for t in ts]
124+
125+
return JuliaBUGS.gen_chains(
126+
logdensitymodel,
127+
param_samples,
128+
stats_names,
129+
stats_values;
130+
discard_initial=discard_initial,
131+
thinning=thinning,
132+
kwargs...,
133+
)
134+
end
135+
136+
# Keep backward compatibility with LogDensityProblemsAD wrapper
106137
function AbstractMCMC.bundle_samples(
107138
ts::Vector{<:AdvancedMH.Transition},
108139
logdensitymodel::AbstractMCMC.LogDensityModel{<:LogDensityProblemsAD.ADGradientWrapper},

JuliaBUGS/ext/JuliaBUGSMCMCChainsExt.jl

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ module JuliaBUGSMCMCChainsExt
22

33
using AbstractMCMC
44
using JuliaBUGS
5-
using JuliaBUGS: BUGSModel, find_generated_quantities_variables, evaluate!!, getparams
5+
using JuliaBUGS: BUGSModel, BUGSModelWithGradient, find_generated_quantities_variables, evaluate!!, getparams
66
using JuliaBUGS.AbstractPPL
77
using JuliaBUGS.Accessors
88
using JuliaBUGS.LogDensityProblemsAD
@@ -21,6 +21,20 @@ function JuliaBUGS.gen_chains(
2121
)
2222
end
2323

24+
function JuliaBUGS.gen_chains(
25+
model::AbstractMCMC.LogDensityModel{<:BUGSModelWithGradient},
26+
samples,
27+
stats_names,
28+
stats_values;
29+
kwargs...,
30+
)
31+
# Extract BUGSModel from gradient wrapper
32+
bugs_model = model.logdensity.base_model
33+
34+
return JuliaBUGS.gen_chains(bugs_model, samples, stats_names, stats_values; kwargs...)
35+
end
36+
37+
# Keep backward compatibility with LogDensityProblemsAD wrapper
2438
function JuliaBUGS.gen_chains(
2539
model::AbstractMCMC.LogDensityModel{<:LogDensityProblemsAD.ADGradientWrapper},
2640
samples,

JuliaBUGS/src/JuliaBUGS.jl

Lines changed: 98 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ using Accessors
66
using ADTypes
77
using BangBang
88
using Bijectors: Bijectors
9+
using DifferentiationInterface
910
using Distributions
1011
using Graphs, MetaGraphsNext
1112
using LinearAlgebra
@@ -17,6 +18,7 @@ using Serialization: Serialization
1718
using StaticArrays
1819

1920
import Base: ==, hash, Symbol, size
21+
import DifferentiationInterface as DI
2022
import Distributions: truncated
2123

2224
export @bugs
@@ -239,20 +241,56 @@ function validate_bugs_expression(expr, line_num)
239241
end
240242

241243
"""
242-
compile(model_def, data[, initial_params]; skip_validation=false)
244+
compile(model_def, data[, initial_params]; skip_validation=false, adtype=nothing)
243245
244246
Compile the model with model definition and data. Optionally, initializations can be provided.
245247
If initializations are not provided, values will be sampled from the prior distributions.
246248
247249
By default, validates that all functions in the model are in the BUGS allowlist (suitable for @bugs macro).
248250
Set `skip_validation=true` to skip validation (for @model macro usage).
251+
252+
If `adtype` is provided, returns a `BUGSModelWithGradient` that supports gradient-based MCMC
253+
samplers like HMC/NUTS. The gradient computation is prepared during compilation for optimal performance.
254+
255+
# Arguments
256+
- `model_def::Expr`: Model definition from @bugs macro
257+
- `data::NamedTuple`: Observed data
258+
- `initial_params::NamedTuple=NamedTuple()`: Initial parameter values (optional)
259+
- `skip_validation::Bool=false`: Skip function validation (for @model macro)
260+
- `eval_module::Module=@__MODULE__`: Module for evaluation
261+
- `adtype`: AD backend specification. Can be:
262+
- `AutoReverseDiff(compile=true)` - ReverseDiff with tape compilation (fastest)
263+
- `AutoReverseDiff(compile=false)` - ReverseDiff without compilation
264+
- `:ReverseDiff` - Shorthand for `AutoReverseDiff(compile=true)`
265+
- `:ForwardDiff` - Shorthand for `AutoForwardDiff()`
266+
- `:Zygote` - Shorthand for `AutoZygote()`
267+
- Any other `ADTypes.AbstractADType`
268+
269+
# Examples
270+
```julia
271+
# Basic compilation
272+
model = compile(model_def, data)
273+
274+
# With gradient support using explicit ADType
275+
model = compile(model_def, data; adtype=AutoReverseDiff(compile=true))
276+
277+
# With gradient support using symbol shorthand
278+
model = compile(model_def, data; adtype=:ReverseDiff) # Same as above
279+
280+
# Using ForwardDiff for small models
281+
model = compile(model_def, data; adtype=:ForwardDiff)
282+
283+
# Sample with NUTS
284+
chain = AbstractMCMC.sample(model, NUTS(0.8), 1000)
285+
```
249286
"""
250287
function compile(
251288
model_def::Expr,
252289
data::NamedTuple,
253290
initial_params::NamedTuple=NamedTuple();
254291
skip_validation::Bool=false,
255292
eval_module::Module=@__MODULE__,
293+
adtype::Union{Nothing,ADTypes.AbstractADType,Symbol}=nothing,
256294
)
257295
# Validate functions by default (for @bugs macro usage)
258296
# Skip validation only for @model macro
@@ -281,7 +319,65 @@ function compile(
281319
values(eval_env),
282320
),
283321
)
284-
return BUGSModel(g, nonmissing_eval_env, model_def, data, initial_params)
322+
base_model = BUGSModel(g, nonmissing_eval_env, model_def, data, initial_params)
323+
324+
# If adtype provided, wrap with gradient capabilities
325+
if adtype !== nothing
326+
# Convert symbol to ADType if needed
327+
adtype_obj = _resolve_adtype(adtype)
328+
return _wrap_with_gradient(base_model, adtype_obj)
329+
end
330+
331+
return base_model
332+
end
333+
334+
"""
335+
_resolve_adtype(adtype) -> ADTypes.AbstractADType
336+
337+
Convert symbol shortcuts to ADTypes, or return the ADType as-is.
338+
339+
Supported symbol shortcuts:
340+
- `:ReverseDiff` -> `AutoReverseDiff(compile=true)`
341+
- `:ForwardDiff` -> `AutoForwardDiff()`
342+
- `:Zygote` -> `AutoZygote()`
343+
- `:Enzyme` -> `AutoEnzyme()`
344+
"""
345+
function _resolve_adtype(adtype::Symbol)
346+
if adtype === :ReverseDiff
347+
return ADTypes.AutoReverseDiff(compile=true)
348+
elseif adtype === :ForwardDiff
349+
return ADTypes.AutoForwardDiff()
350+
elseif adtype === :Zygote
351+
return ADTypes.AutoZygote()
352+
elseif adtype === :Enzyme
353+
return ADTypes.AutoEnzyme()
354+
else
355+
error("Unknown AD backend symbol: $adtype. " *
356+
"Supported symbols: :ReverseDiff, :ForwardDiff, :Zygote, :Enzyme. " *
357+
"Or use an ADTypes object like AutoReverseDiff(compile=true).")
358+
end
359+
end
360+
361+
# Pass through ADTypes objects unchanged
362+
_resolve_adtype(adtype::ADTypes.AbstractADType) = adtype
363+
364+
# Helper function to prepare gradient - separated to handle world age issues
365+
function _wrap_with_gradient(base_model::Model.BUGSModel, adtype::ADTypes.AbstractADType)
366+
# Get initial parameters for preparation
367+
# Use invokelatest to handle world age issues with generated functions
368+
x = Base.invokelatest(getparams, base_model)
369+
370+
# Prepare gradient using DifferentiationInterface
371+
# Use invokelatest to handle world age issues when calling logdensity during preparation
372+
prep = Base.invokelatest(
373+
DI.prepare_gradient,
374+
Model._logdensity_switched,
375+
adtype,
376+
x,
377+
DI.Constant(base_model)
378+
)
379+
380+
return Model.BUGSModelWithGradient(adtype, prep, base_model)
285381
end
286382
# function compile(
287383
# model_str::String,

JuliaBUGS/src/model/Model.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@ module Model
22

33
using Accessors
44
using AbstractPPL
5+
using ADTypes
56
using BangBang
67
using Bijectors
8+
import DifferentiationInterface as DI
79
using Distributions
810
using Graphs
911
using LinearAlgebra
@@ -21,5 +23,6 @@ include("logdensityproblems.jl")
2123
export parameters, variables, initialize!, getparams, settrans, set_evaluation_mode
2224
export regenerate_log_density_function, set_observed_values!
2325
export evaluate_with_rng!!, evaluate_with_env!!, evaluate_with_values!!
26+
export BUGSModelWithGradient, _logdensity_switched
2427

2528
end # Model

0 commit comments

Comments
 (0)