@@ -320,14 +320,14 @@ function compile(
320
320
),
321
321
)
322
322
base_model = BUGSModel (g, nonmissing_eval_env, model_def, data, initial_params)
323
-
323
+
324
324
# If adtype provided, wrap with gradient capabilities
325
325
if adtype != = nothing
326
326
# Convert symbol to ADType if needed
327
327
adtype_obj = _resolve_adtype (adtype)
328
328
return _wrap_with_gradient (base_model, adtype_obj)
329
329
end
330
-
330
+
331
331
return base_model
332
332
end
333
333
@@ -344,17 +344,19 @@ Supported symbol shortcuts:
344
344
"""
345
345
function _resolve_adtype (adtype:: Symbol )
346
346
if adtype === :ReverseDiff
347
- return ADTypes. AutoReverseDiff (compile= true )
347
+ return ADTypes. AutoReverseDiff (; compile= true )
348
348
elseif adtype === :ForwardDiff
349
349
return ADTypes. AutoForwardDiff ()
350
350
elseif adtype === :Zygote
351
351
return ADTypes. AutoZygote ()
352
352
elseif adtype === :Enzyme
353
353
return ADTypes. AutoEnzyme ()
354
354
else
355
- error (" Unknown AD backend symbol: $adtype . " *
356
- " Supported symbols: :ReverseDiff, :ForwardDiff, :Zygote, :Enzyme. " *
357
- " Or use an ADTypes object like AutoReverseDiff(compile=true)." )
355
+ error (
356
+ " Unknown AD backend symbol: $adtype . " *
357
+ " Supported symbols: :ReverseDiff, :ForwardDiff, :Zygote, :Enzyme. " *
358
+ " Or use an ADTypes object like AutoReverseDiff(compile=true)." ,
359
+ )
358
360
end
359
361
end
360
362
@@ -366,17 +368,13 @@ function _wrap_with_gradient(base_model::Model.BUGSModel, adtype::ADTypes.Abstra
366
368
# Get initial parameters for preparation
367
369
# Use invokelatest to handle world age issues with generated functions
368
370
x = Base. invokelatest (getparams, base_model)
369
-
371
+
370
372
# Prepare gradient using DifferentiationInterface
371
373
# Use invokelatest to handle world age issues when calling logdensity during preparation
372
374
prep = Base. invokelatest (
373
- DI. prepare_gradient,
374
- Model. _logdensity_switched,
375
- adtype,
376
- x,
377
- DI. Constant (base_model)
375
+ DI. prepare_gradient, Model. _logdensity_switched, adtype, x, DI. Constant (base_model)
378
376
)
379
-
377
+
380
378
return Model. BUGSModelWithGradient (adtype, prep, base_model)
381
379
end
382
380
# function compile(
0 commit comments