@@ -6,6 +6,7 @@ using Accessors
6
6
using ADTypes
7
7
using BangBang
8
8
using Bijectors: Bijectors
9
+ using DifferentiationInterface
9
10
using Distributions
10
11
using Graphs, MetaGraphsNext
11
12
using LinearAlgebra
@@ -17,6 +18,7 @@ using Serialization: Serialization
17
18
using StaticArrays
18
19
19
20
import Base: == , hash, Symbol, size
21
+ import DifferentiationInterface as DI
20
22
import Distributions: truncated
21
23
22
24
export @bugs
@@ -239,20 +241,56 @@ function validate_bugs_expression(expr, line_num)
239
241
end
240
242
241
243
"""
242
- compile(model_def, data[, initial_params]; skip_validation=false)
244
+ compile(model_def, data[, initial_params]; skip_validation=false, adtype=nothing )
243
245
244
246
Compile the model with model definition and data. Optionally, initializations can be provided.
245
247
If initializations are not provided, values will be sampled from the prior distributions.
246
248
247
249
By default, validates that all functions in the model are in the BUGS allowlist (suitable for @bugs macro).
248
250
Set `skip_validation=true` to skip validation (for @model macro usage).
251
+
252
+ If `adtype` is provided, returns a `BUGSModelWithGradient` that supports gradient-based MCMC
253
+ samplers like HMC/NUTS. The gradient computation is prepared during compilation for optimal performance.
254
+
255
+ # Arguments
256
+ - `model_def::Expr`: Model definition from @bugs macro
257
+ - `data::NamedTuple`: Observed data
258
+ - `initial_params::NamedTuple=NamedTuple()`: Initial parameter values (optional)
259
+ - `skip_validation::Bool=false`: Skip function validation (for @model macro)
260
+ - `eval_module::Module=@__MODULE__`: Module for evaluation
261
+ - `adtype`: AD backend specification. Can be:
262
+ - `AutoReverseDiff(compile=true)` - ReverseDiff with tape compilation (fastest)
263
+ - `AutoReverseDiff(compile=false)` - ReverseDiff without compilation
264
+ - `:ReverseDiff` - Shorthand for `AutoReverseDiff(compile=true)`
265
+ - `:ForwardDiff` - Shorthand for `AutoForwardDiff()`
266
+ - `:Zygote` - Shorthand for `AutoZygote()`
267
+ - Any other `ADTypes.AbstractADType`
268
+
269
+ # Examples
270
+ ```julia
271
+ # Basic compilation
272
+ model = compile(model_def, data)
273
+
274
+ # With gradient support using explicit ADType
275
+ model = compile(model_def, data; adtype=AutoReverseDiff(compile=true))
276
+
277
+ # With gradient support using symbol shorthand
278
+ model = compile(model_def, data; adtype=:ReverseDiff) # Same as above
279
+
280
+ # Using ForwardDiff for small models
281
+ model = compile(model_def, data; adtype=:ForwardDiff)
282
+
283
+ # Sample with NUTS
284
+ chain = AbstractMCMC.sample(model, NUTS(0.8), 1000)
285
+ ```
249
286
"""
250
287
function compile (
251
288
model_def:: Expr ,
252
289
data:: NamedTuple ,
253
290
initial_params:: NamedTuple = NamedTuple ();
254
291
skip_validation:: Bool = false ,
255
292
eval_module:: Module = @__MODULE__ ,
293
+ adtype:: Union{Nothing,ADTypes.AbstractADType,Symbol} = nothing ,
256
294
)
257
295
# Validate functions by default (for @bugs macro usage)
258
296
# Skip validation only for @model macro
@@ -281,7 +319,65 @@ function compile(
281
319
values (eval_env),
282
320
),
283
321
)
284
- return BUGSModel (g, nonmissing_eval_env, model_def, data, initial_params)
322
+ base_model = BUGSModel (g, nonmissing_eval_env, model_def, data, initial_params)
323
+
324
+ # If adtype provided, wrap with gradient capabilities
325
+ if adtype != = nothing
326
+ # Convert symbol to ADType if needed
327
+ adtype_obj = _resolve_adtype (adtype)
328
+ return _wrap_with_gradient (base_model, adtype_obj)
329
+ end
330
+
331
+ return base_model
332
+ end
333
+
334
+ """
335
+ _resolve_adtype(adtype) -> ADTypes.AbstractADType
336
+
337
+ Convert symbol shortcuts to ADTypes, or return the ADType as-is.
338
+
339
+ Supported symbol shortcuts:
340
+ - `:ReverseDiff` -> `AutoReverseDiff(compile=true)`
341
+ - `:ForwardDiff` -> `AutoForwardDiff()`
342
+ - `:Zygote` -> `AutoZygote()`
343
+ - `:Enzyme` -> `AutoEnzyme()`
344
+ """
345
+ function _resolve_adtype (adtype:: Symbol )
346
+ if adtype === :ReverseDiff
347
+ return ADTypes. AutoReverseDiff (compile= true )
348
+ elseif adtype === :ForwardDiff
349
+ return ADTypes. AutoForwardDiff ()
350
+ elseif adtype === :Zygote
351
+ return ADTypes. AutoZygote ()
352
+ elseif adtype === :Enzyme
353
+ return ADTypes. AutoEnzyme ()
354
+ else
355
+ error (" Unknown AD backend symbol: $adtype . " *
356
+ " Supported symbols: :ReverseDiff, :ForwardDiff, :Zygote, :Enzyme. " *
357
+ " Or use an ADTypes object like AutoReverseDiff(compile=true)." )
358
+ end
359
+ end
360
+
361
+ # Pass through ADTypes objects unchanged
362
+ _resolve_adtype (adtype:: ADTypes.AbstractADType ) = adtype
363
+
364
+ # Helper function to prepare gradient - separated to handle world age issues
365
+ function _wrap_with_gradient (base_model:: Model.BUGSModel , adtype:: ADTypes.AbstractADType )
366
+ # Get initial parameters for preparation
367
+ # Use invokelatest to handle world age issues with generated functions
368
+ x = Base. invokelatest (getparams, base_model)
369
+
370
+ # Prepare gradient using DifferentiationInterface
371
+ # Use invokelatest to handle world age issues when calling logdensity during preparation
372
+ prep = Base. invokelatest (
373
+ DI. prepare_gradient,
374
+ Model. _logdensity_switched,
375
+ adtype,
376
+ x,
377
+ DI. Constant (base_model)
378
+ )
379
+
380
+ return Model. BUGSModelWithGradient (adtype, prep, base_model)
285
381
end
286
382
# function compile(
287
383
# model_str::String,
0 commit comments