From 8dfc80005b7c238475bad2dd00330438268d486a Mon Sep 17 00:00:00 2001 From: Shravan Goswami Date: Thu, 2 Oct 2025 22:38:52 +0530 Subject: [PATCH 1/7] Simplify the workflow for computing model gradients --- .gitignore | 4 +- JuliaBUGS/Project.toml | 2 + JuliaBUGS/ext/JuliaBUGSAdvancedHMCExt.jl | 48 +++++++-- JuliaBUGS/ext/JuliaBUGSAdvancedMHExt.jl | 41 ++++++- JuliaBUGS/ext/JuliaBUGSMCMCChainsExt.jl | 16 ++- JuliaBUGS/src/JuliaBUGS.jl | 100 +++++++++++++++++- JuliaBUGS/src/model/Model.jl | 3 + JuliaBUGS/src/model/logdensityproblems.jl | 82 ++++++++++++++ .../test/BUGSPrimitives/distributions.jl | 3 +- JuliaBUGS/test/Project.toml | 2 + JuliaBUGS/test/ext/JuliaBUGSAdvancedHMCExt.jl | 62 +++++++++-- JuliaBUGS/test/ext/JuliaBUGSMCMCChainsExt.jl | 10 +- JuliaBUGS/test/model/bugsmodel.jl | 52 +++++++++ JuliaBUGS/test/parallel_sampling.jl | 5 +- 14 files changed, 397 insertions(+), 33 deletions(-) diff --git a/.gitignore b/.gitignore index 779fd4d05..e5c898ac2 100644 --- a/.gitignore +++ b/.gitignore @@ -29,4 +29,6 @@ Manifest.toml *.local.* # gitingest generated files -digest.txt \ No newline at end of file +digest.txt + +tmp/ \ No newline at end of file diff --git a/JuliaBUGS/Project.toml b/JuliaBUGS/Project.toml index 8077d3517..5cfec993d 100644 --- a/JuliaBUGS/Project.toml +++ b/JuliaBUGS/Project.toml @@ -9,6 +9,7 @@ AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" +DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" @@ -52,6 +53,7 @@ AdvancedHMC = "0.6, 0.7, 0.8" AdvancedMH = "0.8" BangBang = "0.4.1" Bijectors = "0.13, 0.14, 0.15.5" +DifferentiationInterface = "0.7" Distributions = "0.23.8, 0.24, 0.25" Documenter = "0.27, 1" GLMakie = "0.10, 0.11, 0.12, 0.13" diff --git a/JuliaBUGS/ext/JuliaBUGSAdvancedHMCExt.jl b/JuliaBUGS/ext/JuliaBUGSAdvancedHMCExt.jl index 179ef02b5..bf052e3ed 100644 --- a/JuliaBUGS/ext/JuliaBUGSAdvancedHMCExt.jl +++ b/JuliaBUGS/ext/JuliaBUGSAdvancedHMCExt.jl @@ -3,10 +3,12 @@ module JuliaBUGSAdvancedHMCExt using AbstractMCMC using AdvancedHMC using ADTypes +import DifferentiationInterface as DI using JuliaBUGS -using JuliaBUGS: BUGSModel, getparams, initialize! +using JuliaBUGS: BUGSModel, BUGSModelWithGradient, getparams, initialize! 
using JuliaBUGS.LogDensityProblems using JuliaBUGS.LogDensityProblemsAD +using JuliaBUGS.Model: _logdensity_switched using JuliaBUGS.Random using MCMCChains: Chains @@ -40,10 +42,13 @@ end function _gibbs_internal_hmc( rng::Random.AbstractRNG, cond_model::BUGSModel, sampler, ad_backend, state ) - # Wrap model with AD gradient computation - logdensitymodel = AbstractMCMC.LogDensityModel( - LogDensityProblemsAD.ADgradient(ad_backend, cond_model) + # Create gradient model on-the-fly using DifferentiationInterface + x = getparams(cond_model) + prep = DI.prepare_gradient( + _logdensity_switched, ad_backend, x, DI.Constant(cond_model) ) + ad_model = BUGSModelWithGradient(ad_backend, prep, cond_model) + logdensitymodel = AbstractMCMC.LogDensityModel(ad_model) # Take HMC/NUTS step if isnothing(state) @@ -53,7 +58,7 @@ function _gibbs_internal_hmc( logdensitymodel, sampler; n_adapts=0, # Disable adaptation within Gibbs - initial_params=getparams(cond_model), + initial_params=x, ) else # Use existing state for subsequent steps @@ -67,7 +72,7 @@ end function AbstractMCMC.bundle_samples( ts::Vector{<:AdvancedHMC.Transition}, - logdensitymodel::AbstractMCMC.LogDensityModel{<:LogDensityProblemsAD.ADGradientWrapper}, + logdensitymodel::AbstractMCMC.LogDensityModel{<:BUGSModelWithGradient}, sampler::AdvancedHMC.AbstractHMCSampler, state, chain_type::Type{Chains}; @@ -98,4 +103,35 @@ function AbstractMCMC.bundle_samples( ) end +# Keep backward compatibility with LogDensityProblemsAD wrapper +function AbstractMCMC.bundle_samples( + ts::Vector{<:AdvancedHMC.Transition}, + logdensitymodel::AbstractMCMC.LogDensityModel{<:LogDensityProblemsAD.ADGradientWrapper}, + sampler::AdvancedHMC.AbstractHMCSampler, + state, + chain_type::Type{Chains}; + discard_initial=0, + thinning=1, + kwargs..., +) + param_samples = [t.z.θ for t in ts] + + stats_names = collect(keys(merge((; lp=ts[1].z.ℓπ.value), AdvancedHMC.stat(ts[1])))) + stats_values = [ + vcat(ts[i].z.ℓπ.value, collect(values(AdvancedHMC.stat(ts[i])))) for + i in eachindex(ts) + ] + + # Delegate to gen_chains for proper parameter naming from BUGSModel + return JuliaBUGS.gen_chains( + logdensitymodel, + param_samples, + stats_names, + stats_values; + discard_initial=discard_initial, + thinning=thinning, + kwargs..., + ) +end + end diff --git a/JuliaBUGS/ext/JuliaBUGSAdvancedMHExt.jl b/JuliaBUGS/ext/JuliaBUGSAdvancedMHExt.jl index ca30555be..1d07ade50 100644 --- a/JuliaBUGS/ext/JuliaBUGSAdvancedMHExt.jl +++ b/JuliaBUGS/ext/JuliaBUGSAdvancedMHExt.jl @@ -3,10 +3,12 @@ module JuliaBUGSAdvancedMHExt using AbstractMCMC using AdvancedMH using ADTypes +import DifferentiationInterface as DI using JuliaBUGS -using JuliaBUGS: BUGSModel, getparams, initialize! +using JuliaBUGS: BUGSModel, BUGSModelWithGradient, getparams, initialize! 
using JuliaBUGS.LogDensityProblems using JuliaBUGS.LogDensityProblemsAD +using JuliaBUGS.Model: _logdensity_switched using JuliaBUGS.Random using MCMCChains: Chains @@ -52,10 +54,13 @@ end function _gibbs_internal_mh( rng::Random.AbstractRNG, cond_model::BUGSModel, sampler, ad_backend, state ) - # Wrap model with AD gradient computation for gradient-based proposals - logdensitymodel = AbstractMCMC.LogDensityModel( - LogDensityProblemsAD.ADgradient(ad_backend, cond_model) + # Create gradient model on-the-fly using DifferentiationInterface + x = getparams(cond_model) + prep = DI.prepare_gradient( + _logdensity_switched, ad_backend, x, DI.Constant(cond_model) ) + ad_model = BUGSModelWithGradient(ad_backend, prep, cond_model) + logdensitymodel = AbstractMCMC.LogDensityModel(ad_model) # Take MH step with gradient information if isnothing(state) @@ -64,7 +69,7 @@ function _gibbs_internal_mh( logdensitymodel, sampler; n_adapts=0, # Disable adaptation within Gibbs - initial_params=getparams(cond_model), + initial_params=x, ) else t, s = AbstractMCMC.step(rng, logdensitymodel, sampler, state; n_adapts=0) @@ -103,6 +108,32 @@ function AbstractMCMC.bundle_samples( ) end +function AbstractMCMC.bundle_samples( + ts::Vector{<:AdvancedMH.Transition}, + logdensitymodel::AbstractMCMC.LogDensityModel{<:BUGSModelWithGradient}, + sampler::AdvancedMH.MHSampler, + state, + chain_type::Type{Chains}; + discard_initial=0, + thinning=1, + kwargs..., +) + param_samples = [t.params for t in ts] + stats_names = [:lp] + stats_values = [[t.lp] for t in ts] + + return JuliaBUGS.gen_chains( + logdensitymodel, + param_samples, + stats_names, + stats_values; + discard_initial=discard_initial, + thinning=thinning, + kwargs..., + ) +end + +# Keep backward compatibility with LogDensityProblemsAD wrapper function AbstractMCMC.bundle_samples( ts::Vector{<:AdvancedMH.Transition}, logdensitymodel::AbstractMCMC.LogDensityModel{<:LogDensityProblemsAD.ADGradientWrapper}, diff --git a/JuliaBUGS/ext/JuliaBUGSMCMCChainsExt.jl b/JuliaBUGS/ext/JuliaBUGSMCMCChainsExt.jl index eec864093..40d77e848 100644 --- a/JuliaBUGS/ext/JuliaBUGSMCMCChainsExt.jl +++ b/JuliaBUGS/ext/JuliaBUGSMCMCChainsExt.jl @@ -2,7 +2,7 @@ module JuliaBUGSMCMCChainsExt using AbstractMCMC using JuliaBUGS -using JuliaBUGS: BUGSModel, find_generated_quantities_variables, evaluate!!, getparams +using JuliaBUGS: BUGSModel, BUGSModelWithGradient, find_generated_quantities_variables, evaluate!!, getparams using JuliaBUGS.AbstractPPL using JuliaBUGS.Accessors using JuliaBUGS.LogDensityProblemsAD @@ -21,6 +21,20 @@ function JuliaBUGS.gen_chains( ) end +function JuliaBUGS.gen_chains( + model::AbstractMCMC.LogDensityModel{<:BUGSModelWithGradient}, + samples, + stats_names, + stats_values; + kwargs..., +) + # Extract BUGSModel from gradient wrapper + bugs_model = model.logdensity.base_model + + return JuliaBUGS.gen_chains(bugs_model, samples, stats_names, stats_values; kwargs...) 
+end + +# Keep backward compatibility with LogDensityProblemsAD wrapper function JuliaBUGS.gen_chains( model::AbstractMCMC.LogDensityModel{<:LogDensityProblemsAD.ADGradientWrapper}, samples, diff --git a/JuliaBUGS/src/JuliaBUGS.jl b/JuliaBUGS/src/JuliaBUGS.jl index e340d8470..3e780011c 100644 --- a/JuliaBUGS/src/JuliaBUGS.jl +++ b/JuliaBUGS/src/JuliaBUGS.jl @@ -6,6 +6,7 @@ using Accessors using ADTypes using BangBang using Bijectors: Bijectors +using DifferentiationInterface using Distributions using Graphs, MetaGraphsNext using LinearAlgebra @@ -17,6 +18,7 @@ using Serialization: Serialization using StaticArrays import Base: ==, hash, Symbol, size +import DifferentiationInterface as DI import Distributions: truncated export @bugs @@ -239,13 +241,48 @@ function validate_bugs_expression(expr, line_num) end """ - compile(model_def, data[, initial_params]; skip_validation=false) + compile(model_def, data[, initial_params]; skip_validation=false, adtype=nothing) Compile the model with model definition and data. Optionally, initializations can be provided. If initializations are not provided, values will be sampled from the prior distributions. By default, validates that all functions in the model are in the BUGS allowlist (suitable for @bugs macro). Set `skip_validation=true` to skip validation (for @model macro usage). + +If `adtype` is provided, returns a `BUGSModelWithGradient` that supports gradient-based MCMC +samplers like HMC/NUTS. The gradient computation is prepared during compilation for optimal performance. + +# Arguments +- `model_def::Expr`: Model definition from @bugs macro +- `data::NamedTuple`: Observed data +- `initial_params::NamedTuple=NamedTuple()`: Initial parameter values (optional) +- `skip_validation::Bool=false`: Skip function validation (for @model macro) +- `eval_module::Module=@__MODULE__`: Module for evaluation +- `adtype`: AD backend specification. 
Can be: + - `AutoReverseDiff(compile=true)` - ReverseDiff with tape compilation (fastest) + - `AutoReverseDiff(compile=false)` - ReverseDiff without compilation + - `:ReverseDiff` - Shorthand for `AutoReverseDiff(compile=true)` + - `:ForwardDiff` - Shorthand for `AutoForwardDiff()` + - `:Zygote` - Shorthand for `AutoZygote()` + - Any other `ADTypes.AbstractADType` + +# Examples +```julia +# Basic compilation +model = compile(model_def, data) + +# With gradient support using explicit ADType +model = compile(model_def, data; adtype=AutoReverseDiff(compile=true)) + +# With gradient support using symbol shorthand +model = compile(model_def, data; adtype=:ReverseDiff) # Same as above + +# Using ForwardDiff for small models +model = compile(model_def, data; adtype=:ForwardDiff) + +# Sample with NUTS +chain = AbstractMCMC.sample(model, NUTS(0.8), 1000) +``` """ function compile( model_def::Expr, @@ -253,6 +290,7 @@ function compile( initial_params::NamedTuple=NamedTuple(); skip_validation::Bool=false, eval_module::Module=@__MODULE__, + adtype::Union{Nothing,ADTypes.AbstractADType,Symbol}=nothing, ) # Validate functions by default (for @bugs macro usage) # Skip validation only for @model macro @@ -281,7 +319,65 @@ function compile( values(eval_env), ), ) - return BUGSModel(g, nonmissing_eval_env, model_def, data, initial_params) + base_model = BUGSModel(g, nonmissing_eval_env, model_def, data, initial_params) + + # If adtype provided, wrap with gradient capabilities + if adtype !== nothing + # Convert symbol to ADType if needed + adtype_obj = _resolve_adtype(adtype) + return _wrap_with_gradient(base_model, adtype_obj) + end + + return base_model +end + +""" + _resolve_adtype(adtype) -> ADTypes.AbstractADType + +Convert symbol shortcuts to ADTypes, or return the ADType as-is. + +Supported symbol shortcuts: +- `:ReverseDiff` -> `AutoReverseDiff(compile=true)` +- `:ForwardDiff` -> `AutoForwardDiff()` +- `:Zygote` -> `AutoZygote()` +- `:Enzyme` -> `AutoEnzyme()` +""" +function _resolve_adtype(adtype::Symbol) + if adtype === :ReverseDiff + return ADTypes.AutoReverseDiff(compile=true) + elseif adtype === :ForwardDiff + return ADTypes.AutoForwardDiff() + elseif adtype === :Zygote + return ADTypes.AutoZygote() + elseif adtype === :Enzyme + return ADTypes.AutoEnzyme() + else + error("Unknown AD backend symbol: $adtype. " * + "Supported symbols: :ReverseDiff, :ForwardDiff, :Zygote, :Enzyme. 
" * + "Or use an ADTypes object like AutoReverseDiff(compile=true).") + end +end + +# Pass through ADTypes objects unchanged +_resolve_adtype(adtype::ADTypes.AbstractADType) = adtype + +# Helper function to prepare gradient - separated to handle world age issues +function _wrap_with_gradient(base_model::Model.BUGSModel, adtype::ADTypes.AbstractADType) + # Get initial parameters for preparation + # Use invokelatest to handle world age issues with generated functions + x = Base.invokelatest(getparams, base_model) + + # Prepare gradient using DifferentiationInterface + # Use invokelatest to handle world age issues when calling logdensity during preparation + prep = Base.invokelatest( + DI.prepare_gradient, + Model._logdensity_switched, + adtype, + x, + DI.Constant(base_model) + ) + + return Model.BUGSModelWithGradient(adtype, prep, base_model) end # function compile( # model_str::String, diff --git a/JuliaBUGS/src/model/Model.jl b/JuliaBUGS/src/model/Model.jl index 37ca24aa8..2efe7adb1 100644 --- a/JuliaBUGS/src/model/Model.jl +++ b/JuliaBUGS/src/model/Model.jl @@ -2,8 +2,10 @@ module Model using Accessors using AbstractPPL +using ADTypes using BangBang using Bijectors +import DifferentiationInterface as DI using Distributions using Graphs using LinearAlgebra @@ -21,5 +23,6 @@ include("logdensityproblems.jl") export parameters, variables, initialize!, getparams, settrans, set_evaluation_mode export regenerate_log_density_function, set_observed_values! export evaluate_with_rng!!, evaluate_with_env!!, evaluate_with_values!! +export BUGSModelWithGradient, _logdensity_switched end # Model diff --git a/JuliaBUGS/src/model/logdensityproblems.jl b/JuliaBUGS/src/model/logdensityproblems.jl index 07d82b018..1b0381cae 100644 --- a/JuliaBUGS/src/model/logdensityproblems.jl +++ b/JuliaBUGS/src/model/logdensityproblems.jl @@ -24,3 +24,85 @@ end function LogDensityProblems.capabilities(::AbstractBUGSModel) return LogDensityProblems.LogDensityOrder{0}() end + +""" + BUGSModelWithGradient{B,P,M} + +Wraps a BUGSModel with automatic differentiation capabilities using DifferentiationInterface. +Implements both `LogDensityProblems.logdensity` and `LogDensityProblems.logdensity_and_gradient`. + +# Fields +- `backend::B`: ADTypes backend (e.g., AutoReverseDiff()) +- `prep::P`: Prepared gradient from DifferentiationInterface (can be nothing) +- `base_model::M`: The underlying BUGSModel + +# Example +```julia +model_def = @bugs begin + x ~ dnorm(0, 1) +end +data = NamedTuple() + +# Create model with gradient capabilities +model = compile(model_def, data; adtype=AutoReverseDiff(compile=true)) + +# Use with gradient-based MCMC +chain = AbstractMCMC.sample(rng, model, NUTS(0.8), 1000) +``` +""" +struct BUGSModelWithGradient{B<:ADTypes.AbstractADType,P,M<:BUGSModel} + backend::B + prep::P + base_model::M +end + +# Forward base BUGSModel interface +function LogDensityProblems.logdensity(model::BUGSModelWithGradient, x::AbstractVector) + return LogDensityProblems.logdensity(model.base_model, x) +end + +function LogDensityProblems.dimension(model::BUGSModelWithGradient) + return LogDensityProblems.dimension(model.base_model) +end + +function LogDensityProblems.capabilities(::Type{<:BUGSModelWithGradient}) + return LogDensityProblems.LogDensityOrder{1}() # Gradient available +end + +""" + _logdensity_switched(x, base_model_constant) + +Helper function that switches argument order for DifferentiationInterface compatibility. +DI expects the active argument (to differentiate w.r.t.) to come first. 
+""" +function _logdensity_switched(x::AbstractVector, base_model_constant::DI.Constant) + base_model = DI.unwrap(base_model_constant) + return LogDensityProblems.logdensity(base_model, x) +end + +# Fallback for testing during preparation (when DI calls without Constant wrapper) +function _logdensity_switched(x::AbstractVector, base_model::BUGSModel) + return LogDensityProblems.logdensity(base_model, x) +end + +""" + LogDensityProblems.logdensity_and_gradient(model::BUGSModelWithGradient, x) + +Compute log density and its gradient using DifferentiationInterface. +Uses prepared gradient if available, otherwise falls back to unprepared computation. +""" +function LogDensityProblems.logdensity_and_gradient( + model::BUGSModelWithGradient, x::AbstractVector +) + # Active argument (x) comes first for DI + # Base model passed as Constant context + if model.prep === nothing + return DI.value_and_gradient( + _logdensity_switched, model.backend, x, DI.Constant(model.base_model) + ) + else + return DI.value_and_gradient( + _logdensity_switched, model.prep, model.backend, x, DI.Constant(model.base_model) + ) + end +end diff --git a/JuliaBUGS/test/BUGSPrimitives/distributions.jl b/JuliaBUGS/test/BUGSPrimitives/distributions.jl index 69505e2f7..0262b051d 100644 --- a/JuliaBUGS/test/BUGSPrimitives/distributions.jl +++ b/JuliaBUGS/test/BUGSPrimitives/distributions.jl @@ -15,9 +15,8 @@ end A[1:2, 1:2] ~ dwish(B[:, :], 2) C[1:2] ~ dmnorm(mu[:], A[:, :]) end - model = compile(model_def, (mu=[0, 0], B=[1 0; 0 1]), (A=[1 0; 0 1],)) + ad_model = compile(model_def, (mu=[0, 0], B=[1 0; 0 1]), (A=[1 0; 0 1],); adtype=AutoReverseDiff()) - ad_model = ADgradient(:ReverseDiff, model) theta = [ 0.7931743744870574, 0.5151017206811268, diff --git a/JuliaBUGS/test/Project.toml b/JuliaBUGS/test/Project.toml index 9c02130ee..056862146 100644 --- a/JuliaBUGS/test/Project.toml +++ b/JuliaBUGS/test/Project.toml @@ -7,6 +7,7 @@ AdvancedMH = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170" BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" +DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" @@ -36,6 +37,7 @@ AdvancedHMC = "0.6, 0.7" AdvancedMH = "0.8" BangBang = "0.4.1" ChainRules = "1" +DifferentiationInterface = "0.7" Distributions = "0.23.8, 0.24, 0.25" Documenter = "0.27, 1" Graphs = "1" diff --git a/JuliaBUGS/test/ext/JuliaBUGSAdvancedHMCExt.jl b/JuliaBUGS/test/ext/JuliaBUGSAdvancedHMCExt.jl index 3eafb4736..cb6a62c68 100644 --- a/JuliaBUGS/test/ext/JuliaBUGSAdvancedHMCExt.jl +++ b/JuliaBUGS/test/ext/JuliaBUGSAdvancedHMCExt.jl @@ -6,10 +6,9 @@ y = x[1] + x[3] end data = (mu=[0, 0], sigma=[1 0; 0 1]) - model = compile(model_def, data) - ad_model = ADgradient(:ReverseDiff, model; compile=Val(true)) + ad_model = compile(model_def, data; adtype=AutoReverseDiff(compile=true)) n_samples, n_adapts = 10, 0 - D = LogDensityProblems.dimension(model) + D = LogDensityProblems.dimension(ad_model) initial_θ = rand(D) samples_and_stats = AbstractMCMC.sample( StableRNG(1234), @@ -27,18 +26,67 @@ [Symbol("x[3]"), Symbol("x[1:2][1]"), Symbol("x[1:2][2]"), :y] end + @testset "Symbol AD backend shortcuts" begin + model_def = @bugs begin + mu ~ dnorm(0, 1) + for i in 1:N + y[i] ~ dnorm(mu, 1) + end + end + data = (N=5, y=[1.0, 2.0, 1.5, 2.5, 1.8]) + + # Test that symbol shortcut works + 
ad_model_symbol = compile(model_def, data; adtype=:ReverseDiff) + ad_model_explicit = compile(model_def, data; adtype=AutoReverseDiff(compile=true)) + + @test ad_model_symbol isa JuliaBUGS.Model.BUGSModelWithGradient + @test ad_model_explicit isa JuliaBUGS.Model.BUGSModelWithGradient + + # Test that both produce equivalent results + n_samples, n_adapts = 100, 100 + D = LogDensityProblems.dimension(ad_model_symbol) + initial_θ = rand(StableRNG(123), D) + + samples_symbol = AbstractMCMC.sample( + StableRNG(1234), + ad_model_symbol, + NUTS(0.8), + n_samples; + progress=false, + chain_type=Chains, + n_adapts=n_adapts, + init_params=initial_θ, + discard_initial=n_adapts, + ) + + samples_explicit = AbstractMCMC.sample( + StableRNG(1234), + ad_model_explicit, + NUTS(0.8), + n_samples; + progress=false, + chain_type=Chains, + n_adapts=n_adapts, + init_params=initial_θ, + discard_initial=n_adapts, + ) + + # Results should be very similar (same RNG seed) + @test summarize(samples_symbol)[:mu].nt.mean[1] ≈ + summarize(samples_explicit)[:mu].nt.mean[1] rtol=0.1 + end + @testset "Inference results on examples: $example" for example in [:seeds, :rats, :stacks] (; model_def, data, inits, reference_results) = Base.getfield( JuliaBUGS.BUGSExamples, example ) - model = JuliaBUGS.compile(model_def, data, inits) - ad_model = ADgradient(:ReverseDiff, model; compile=Val(true)) + ad_model = JuliaBUGS.compile(model_def, data, inits; adtype=AutoReverseDiff(compile=true)) n_samples, n_adapts = 1000, 1000 - D = LogDensityProblems.dimension(model) - initial_θ = JuliaBUGS.getparams(model) + D = LogDensityProblems.dimension(ad_model) + initial_θ = JuliaBUGS.getparams(ad_model.base_model) samples_and_stats = AbstractMCMC.sample( StableRNG(1234), diff --git a/JuliaBUGS/test/ext/JuliaBUGSMCMCChainsExt.jl b/JuliaBUGS/test/ext/JuliaBUGSMCMCChainsExt.jl index b6b512980..7d6f35832 100644 --- a/JuliaBUGS/test/ext/JuliaBUGSMCMCChainsExt.jl +++ b/JuliaBUGS/test/ext/JuliaBUGSMCMCChainsExt.jl @@ -26,11 +26,10 @@ y=[1.58, 4.80, 7.10, 8.86, 11.73, 14.52, 18.22, 18.73, 21.04, 22.93], ) - model = compile(model_def, data, (;)) - ad_model = ADgradient(:ReverseDiff, model; compile=Val(true)) + ad_model = compile(model_def, data, (;); adtype=AutoReverseDiff(compile=true)) n_samples, n_adapts = 2000, 1000 - D = LogDensityProblems.dimension(model) + D = LogDensityProblems.dimension(ad_model) initial_θ = rand(D) hmc_chain = AbstractMCMC.sample( @@ -73,7 +72,7 @@ n_samples, n_adapts = 20000, 5000 mh_chain = AbstractMCMC.sample( - model, + ad_model.base_model, RWMH(MvNormal(zeros(D), I)), n_samples; progress=false, @@ -107,8 +106,7 @@ sigma[2] ~ InverseGamma(2, 3) sigma[3] ~ InverseGamma(2, 3) end - model = compile(model_def, (;)) - ad_model = ADgradient(:ReverseDiff, model; compile=Val(true)) + ad_model = compile(model_def, (;); adtype=AutoReverseDiff(compile=true)) hmc_chain = AbstractMCMC.sample( ad_model, NUTS(0.8), 10; progress=false, chain_type=Chains ) diff --git a/JuliaBUGS/test/model/bugsmodel.jl b/JuliaBUGS/test/model/bugsmodel.jl index 3f602ad72..ddfe835c5 100644 --- a/JuliaBUGS/test/model/bugsmodel.jl +++ b/JuliaBUGS/test/model/bugsmodel.jl @@ -402,4 +402,56 @@ end @test occursin("Variable sizes and types:", output) end end + + @testset "AD Type Parameter" begin + model_def = @bugs begin + mu ~ dnorm(0, 1) + y ~ dnorm(mu, 1) + end + data = (y=1.5,) + + @testset "Symbol shortcuts" begin + # Test :ReverseDiff shortcut + model_rd = compile(model_def, data; adtype=:ReverseDiff) + @test model_rd isa JuliaBUGS.Model.BUGSModelWithGradient 
+ + # Test equivalence with explicit ADType + model_explicit = compile(model_def, data; adtype=AutoReverseDiff(compile=true)) + @test model_explicit isa JuliaBUGS.Model.BUGSModelWithGradient + + # Test that unknown symbol throws error + @test_throws ErrorException compile(model_def, data; adtype=:UnknownBackend) + end + + @testset "Explicit ADTypes" begin + # Test with compile=true + model_compile = compile(model_def, data; adtype=AutoReverseDiff(compile=true)) + @test model_compile isa JuliaBUGS.Model.BUGSModelWithGradient + + # Test with compile=false + model_nocompile = compile(model_def, data; adtype=AutoReverseDiff(compile=false)) + @test model_nocompile isa JuliaBUGS.Model.BUGSModelWithGradient + end + + @testset "Default behavior (no adtype)" begin + # Without adtype, should return regular BUGSModel + model_default = compile(model_def, data) + @test model_default isa BUGSModel + @test !(model_default isa JuliaBUGS.Model.BUGSModelWithGradient) + end + + @testset "Gradient computation" begin + model = compile(model_def, data; adtype=:ReverseDiff) + test_point = [0.0] + + # Test that gradient can be computed + ℓ, grad = LogDensityProblems.logdensity_and_gradient(model, test_point) + + @test ℓ isa Real + @test grad isa Vector + @test length(grad) == 1 + @test isfinite(ℓ) + @test all(isfinite, grad) + end + end end diff --git a/JuliaBUGS/test/parallel_sampling.jl b/JuliaBUGS/test/parallel_sampling.jl index 7871aca7f..8b857e4a0 100644 --- a/JuliaBUGS/test/parallel_sampling.jl +++ b/JuliaBUGS/test/parallel_sampling.jl @@ -19,9 +19,8 @@ data = (N=N, x=x_data) inits = (mu=0.0, tau=1.0) - model = compile(model_def, data, inits) - # Use compile=Val(false) for thread safety with ReverseDiff - ad_model = ADgradient(:ReverseDiff, model; compile=Val(false)) + # Use compile=false for thread safety with ReverseDiff + ad_model = compile(model_def, data, inits; adtype=AutoReverseDiff(compile=false)) # Single chain reference n_samples = 200 From d77bbac62ee735aa50e96a02d5d6d0971d31412b Mon Sep 17 00:00:00 2001 From: Shravan Goswami Date: Thu, 2 Oct 2025 23:33:18 +0530 Subject: [PATCH 2/7] update docs and bump version to 0.10.4 --- JuliaBUGS/History.md | 10 ++ JuliaBUGS/Project.toml | 2 +- JuliaBUGS/docs/src/example.md | 200 +++++++++++++++++++++++++++------- 3 files changed, 169 insertions(+), 43 deletions(-) diff --git a/JuliaBUGS/History.md b/JuliaBUGS/History.md index cb5567431..066d5ad3a 100644 --- a/JuliaBUGS/History.md +++ b/JuliaBUGS/History.md @@ -1,5 +1,15 @@ # JuliaBUGS Changelog +## 0.10.4 + +- **DifferentiationInterface.jl integration**: JuliaBUGS now uses [DifferentiationInterface.jl](https://github.com/JuliaDiff/DifferentiationInterface.jl) for automatic differentiation, providing a unified interface to multiple AD backends. 
- Add `adtype` parameter to `compile()` function for specifying AD backends
+  - Support convenient symbol shortcuts: `:ReverseDiff`, `:ForwardDiff`, `:Zygote`, `:Enzyme`
+  - Gradient computation is prepared during compilation for optimal performance
+  - Example: `model = compile(model_def, data; adtype=:ReverseDiff)`
+  - Full control available via explicit ADTypes: `adtype=AutoReverseDiff(compile=true)`
+  - Backward compatible: models without `adtype` work as before
+
 ## 0.10.1
 
 Expose docs for changes in [v0.10.0](https://github.com/TuringLang/JuliaBUGS.jl/releases/tag/JuliaBUGS-v0.10.0)
diff --git a/JuliaBUGS/Project.toml b/JuliaBUGS/Project.toml
index 5cfec993d..65bd31780 100644
--- a/JuliaBUGS/Project.toml
+++ b/JuliaBUGS/Project.toml
@@ -1,6 +1,6 @@
 name = "JuliaBUGS"
 uuid = "ba9fb4c0-828e-4473-b6a1-cd2560fee5bf"
-version = "0.10.3"
+version = "0.10.4"
diff --git a/JuliaBUGS/docs/src/example.md b/JuliaBUGS/docs/src/example.md
index eaba1a01e..3cf73a4cc 100644
--- a/JuliaBUGS/docs/src/example.md
+++ b/JuliaBUGS/docs/src/example.md
@@ -190,33 +190,54 @@ initialize!(model, initializations)
 initialize!(model, rand(26))
 ```
 
-`LogDensityProblemsAD.jl` defined some extensions that support automatic differentiation packages.
-For example, with `ReverseDiff.jl`
+### Automatic Differentiation
+
+JuliaBUGS integrates with automatic differentiation (AD) through [DifferentiationInterface.jl](https://github.com/JuliaDiff/DifferentiationInterface.jl), enabling gradient-based inference methods such as Hamiltonian Monte Carlo (HMC) and the No-U-Turn Sampler (NUTS).
+
+#### Specifying an AD Backend
+
+To compile a model with gradient support, pass the `adtype` parameter to `compile`:
 
 ```julia
-using LogDensityProblemsAD, ReverseDiff
+# Using an explicit ADType from ADTypes.jl
+using ADTypes
+model = compile(model_def, data; adtype=AutoReverseDiff(compile=true))
+
+# Using convenient symbol shortcuts
+model = compile(model_def, data; adtype=:ReverseDiff) # Equivalent to above
+```
 
-ad_model = ADgradient(:ReverseDiff, model; compile=Val(true))
+Available AD backends include:
+- `:ReverseDiff` - ReverseDiff with tape compilation (recommended for most models)
+- `:ForwardDiff` - ForwardDiff (efficient for models with few parameters)
+- `:Zygote` - Zygote (source-to-source AD)
+- `:Enzyme` - Enzyme (experimental, high-performance)
+
+For fine-grained control, use explicit `ADTypes` constructors:
+
+```julia
+# ReverseDiff without compilation
+model = compile(model_def, data; adtype=AutoReverseDiff(compile=false))
 ```
 
-Here `ad_model` will also implement all the interfaces of [`LogDensityProblems.jl`](https://www.tamaspapp.eu/LogDensityProblems.jl/dev/).
-`LogDensityProblemsAD.jl` will automatically add the interface function [`logdensity_and_gradient`](https://www.tamaspapp.eu/LogDensityProblems.jl/dev/#LogDensityProblems.logdensity_and_gradient) to the model, which will return the log density and gradient of the model.
-And `ad_model` can be used in the same way as `model` in the example below.
+The compiled model with gradient support implements the [`LogDensityProblems.jl`](https://www.tamaspapp.eu/LogDensityProblems.jl/dev/) interface, including [`logdensity_and_gradient`](https://www.tamaspapp.eu/LogDensityProblems.jl/dev/#LogDensityProblems.logdensity_and_gradient), which returns both the log density and its gradient.
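+
+As a quick sanity check, the log density and gradient can also be evaluated directly through
+this interface (a minimal sketch; `model` is the gradient-enabled model compiled above):
+
+```julia
+using LogDensityProblems
+θ = rand(LogDensityProblems.dimension(model))
+ℓ, ∇ℓ = LogDensityProblems.logdensity_and_gradient(model, θ)
+```
 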
### Inference -For a differentiable model, we can use [`AdvancedHMC.jl`](https://github.com/TuringLang/AdvancedHMC.jl) to perform inference. -For instance, +For gradient-based inference, we use [`AdvancedHMC.jl`](https://github.com/TuringLang/AdvancedHMC.jl) with models compiled with an `adtype`: ```julia -using AdvancedHMC, AbstractMCMC, LogDensityProblems, MCMCChains +using AdvancedHMC, AbstractMCMC, LogDensityProblems, MCMCChains, ReverseDiff + +# Compile with gradient support +model = compile(model_def, data; adtype=:ReverseDiff) n_samples, n_adapts = 2000, 1000 D = LogDensityProblems.dimension(model); initial_θ = rand(D) samples_and_stats = AbstractMCMC.sample( - ad_model, + model, NUTS(0.8), n_samples; chain_type = Chains, @@ -224,6 +245,7 @@ samples_and_stats = AbstractMCMC.sample( init_params = initial_θ, discard_initial = n_adapts ) +describe(samples_and_stats) ``` This will return the MCMC Chain, @@ -234,39 +256,72 @@ Chains MCMC chain (2000×40×1 Array{Real, 3}): Iterations = 1001:1:3000 Number of chains = 1 Samples per chain = 2000 -parameters = alpha0, alpha12, alpha1, alpha2, tau, b[16], b[12], b[10], b[14], b[13], b[7], b[6], b[20], b[1], b[4], b[5], b[2], b[18], b[8], b[3], b[9], b[21], b[17], b[15], b[11], b[19], sigma +parameters = tau, alpha12, alpha2, alpha1, alpha0, b[1], b[2], b[3], b[4], b[5], b[6], b[7], b[8], b[9], b[10], b[11], b[12], b[13], b[14], b[15], b[16], b[17], b[18], b[19], b[20], b[21], sigma internals = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size, is_adapt Summary Statistics - parameters mean std mcse ess_bulk ess_tail rhat ess_per_sec - Symbol Float64 Float64 Float64 Real Float64 Float64 Missing - - alpha0 -0.5642 0.2320 0.0084 766.9305 1022.5211 1.0021 missing - alpha12 -0.8489 0.5247 0.0170 946.0418 1044.1109 1.0002 missing - alpha1 0.0587 0.3715 0.0119 966.4367 1233.2257 1.0007 missing - alpha2 1.3852 0.3410 0.0127 712.2978 974.1566 1.0002 missing - tau 1.8880 0.7705 0.0447 348.9331 338.3655 1.0030 missing - b[16] -0.2445 0.4459 0.0132 1528.0578 843.8225 1.0003 missing - b[12] 0.2050 0.3602 0.0086 1868.6126 1202.1363 0.9996 missing - b[10] -0.3500 0.2893 0.0090 1047.3119 1245.9358 1.0008 missing - ⋮ ⋮ ⋮ ⋮ ⋮ ⋮ ⋮ ⋮ - 19 rows omitted + parameters mean std mcse ess_bulk ess_tail rhat ess_per_sec + Symbol Float64 Float64 Float64 Real Float64 Float64 Missing + + tau 73.1490 193.8441 43.2582 56.3430 20.6688 1.0155 missing + alpha12 -0.8052 0.4392 0.0158 761.2180 1049.1664 1.0020 missing + alpha2 1.3428 0.2813 0.0140 422.8810 1013.2570 1.0061 missing + alpha1 0.0845 0.3126 0.0113 773.2202 981.8487 1.0051 missing + alpha0 -0.5480 0.1944 0.0087 537.6212 1156.2083 1.0014 missing + b[1] -0.1905 0.2540 0.0129 374.3372 971.7526 1.0034 missing + b[2] 0.0161 0.2178 0.0056 1505.6353 1002.8787 1.0001 missing + b[3] -0.1986 0.2375 0.0128 367.6766 1287.8215 1.0015 missing + b[4] 0.2792 0.2498 0.0163 201.1558 1168.7538 1.0068 missing + b[5] 0.1170 0.2397 0.0092 659.5422 1484.8584 1.0016 missing + b[6] 0.0667 0.2821 0.0074 1745.5567 902.1014 1.0067 missing + b[7] 0.0597 0.2218 0.0055 1589.5590 1145.6017 1.0065 missing + b[8] 0.1769 0.2316 0.0102 554.5974 1318.8089 1.0001 missing + b[9] -0.1257 0.2233 0.0073 930.0346 1186.4283 1.0031 missing + b[10] -0.2513 0.2392 0.0159 213.6323 1142.4487 1.0096 missing + b[11] 0.0768 0.2783 0.0081 1376.5999 1218.1537 1.0009 missing + b[12] 0.1171 0.2768 0.0079 1354.9409 1130.8217 1.0052 missing + 
b[13] -0.0688 0.2433 0.0055 1895.0387 1527.7066 1.0010 missing + b[14] -0.1363 0.2558 0.0075 1276.0992 1208.8587 1.0001 missing + b[15] 0.2334 0.2757 0.0135 439.2241 837.3396 1.0036 missing + b[16] -0.1212 0.3024 0.0106 1093.4416 914.9457 0.9997 missing + b[17] -0.2120 0.3142 0.0166 360.6420 702.4098 1.0009 missing + b[18] 0.0346 0.2282 0.0056 1665.0325 1281.7179 1.0011 missing + b[19] -0.0244 0.2400 0.0052 2186.7638 1179.6971 1.0132 missing + b[20] 0.2108 0.2421 0.0131 349.7657 1263.5781 1.0016 missing + b[21] -0.0509 0.2813 0.0061 2200.5614 916.6256 0.9998 missing + sigma 0.2797 0.1362 0.0168 56.3430 21.4971 1.0123 missing Quantiles - parameters 2.5% 25.0% 50.0% 75.0% 97.5% - Symbol Float64 Float64 Float64 Float64 Float64 - - alpha0 -1.0143 -0.7143 -0.5590 -0.4100 -0.1185 - alpha12 -1.9063 -1.1812 -0.8296 -0.5153 0.1521 - alpha1 -0.6550 -0.1822 0.0512 0.2885 0.8180 - alpha2 0.7214 1.1663 1.3782 1.5998 2.0986 - tau 0.5461 1.3941 1.8353 2.3115 3.6225 - b[16] -1.2359 -0.4836 -0.1909 0.0345 0.5070 - b[12] -0.4493 -0.0370 0.1910 0.4375 0.9828 - b[10] -0.9570 -0.5264 -0.3331 -0.1514 0.1613 - ⋮ ⋮ ⋮ ⋮ ⋮ ⋮ - 19 rows omitted - + parameters 2.5% 25.0% 50.0% 75.0% 97.5% + Symbol Float64 Float64 Float64 Float64 Float64 + + tau 3.1280 7.4608 13.0338 28.2289 929.6520 + alpha12 -1.6645 -1.0887 -0.7952 -0.5635 0.1162 + alpha2 0.8398 1.1494 1.3233 1.5337 1.9177 + alpha1 -0.5796 -0.1059 0.1042 0.2883 0.6702 + alpha0 -0.9340 -0.6751 -0.5463 -0.4086 -0.1752 + b[1] -0.7430 -0.3415 -0.1566 -0.0074 0.2535 + b[2] -0.4261 -0.1083 0.0192 0.1420 0.4810 + b[3] -0.7394 -0.3377 -0.1687 -0.0242 0.2041 + b[4] -0.1108 0.0873 0.2409 0.4375 0.8267 + b[5] -0.3141 -0.0458 0.0900 0.2563 0.6489 + b[6] -0.4679 -0.0896 0.0291 0.2202 0.7060 + b[7] -0.3861 -0.0685 0.0534 0.1847 0.5207 + b[8] -0.2326 0.0221 0.1505 0.3162 0.6861 + b[9] -0.6007 -0.2482 -0.0984 0.0057 0.2771 + b[10] -0.7936 -0.4108 -0.2255 -0.0617 0.1290 + b[11] -0.4381 -0.0796 0.0353 0.2178 0.7232 + b[12] -0.3806 -0.0451 0.0750 0.2671 0.7625 + b[13] -0.5841 -0.2135 -0.0443 0.0652 0.4055 + b[14] -0.6854 -0.2872 -0.1015 0.0147 0.3476 + b[15] -0.2054 0.0257 0.1898 0.4004 0.8660 + b[16] -0.8173 -0.2829 -0.0804 0.0532 0.4094 + b[17] -0.9071 -0.3911 -0.1595 0.0099 0.2864 + b[18] -0.4526 -0.0919 0.0140 0.1686 0.4985 + b[19] -0.5055 -0.1547 -0.0091 0.1134 0.4528 + b[20] -0.2120 0.0318 0.1788 0.3673 0.7416 + b[21] -0.6482 -0.2044 -0.0263 0.1051 0.5246 + sigma 0.0328 0.1882 0.2770 0.3661 0.5654 ``` This is consistent with the result in the [OpenBUGS seeds example](https://chjackson.github.io/openbugsdoc/Examples/Seeds.html). 
@@ -283,7 +338,7 @@ The model compilation code remains the same, and we can sample multiple chains i ```julia n_chains = 4 samples_and_stats = AbstractMCMC.sample( - ad_model, + model, AdvancedHMC.NUTS(0.65), AbstractMCMC.MCMCThreads(), n_samples, @@ -311,7 +366,7 @@ For example: ```julia @everywhere begin - using JuliaBUGS, LogDensityProblems, LogDensityProblemsAD, AbstractMCMC, AdvancedHMC, MCMCChains, ReverseDiff # also other packages one may need + using JuliaBUGS, LogDensityProblems, AbstractMCMC, AdvancedHMC, MCMCChains, ADTypes, ReverseDiff # Define the functions to use # Use `@bugs_primitive` to register the functions to use in the model @@ -322,7 +377,7 @@ end n_chains = nprocs() - 1 # use all the processes except the parent process samples_and_stats = AbstractMCMC.sample( - ad_model, + model, AdvancedHMC.NUTS(0.65), AbstractMCMC.MCMCDistributed(), n_samples, @@ -342,6 +397,67 @@ In this case, we pass two additional arguments to `AbstractMCMC.sample`: Note that the `init_params` argument is now a vector of initial parameters for each chain. Sometimes the progress logger can cause problems in distributed setting, so we can disable it by setting `progress = false`. +## Choosing an Automatic Differentiation Backend + +JuliaBUGS integrates with multiple automatic differentiation (AD) backends through [DifferentiationInterface.jl](https://github.com/JuliaDiff/DifferentiationInterface.jl), providing flexibility to choose the most suitable backend for your model. + +### Available Backends + +The following AD backends are supported via convenient symbol shortcuts: + +- **`:ReverseDiff`** (recommended) — Tape-based reverse-mode AD, highly efficient for models with many parameters. Uses compilation by default for optimal performance. +- **`:ForwardDiff`** — Forward-mode AD, efficient for models with few parameters (typically < 20). +- **`:Zygote`** — Source-to-source reverse-mode AD, general-purpose but may be slower than ReverseDiff for many models. +- **`:Enzyme`** — Experimental high-performance AD backend with LLVM-level transformations. + +### Usage Examples + +#### Basic Usage with Symbol Shortcuts + +The simplest way to specify an AD backend is using symbol shortcuts: + +```julia +# ReverseDiff with tape compilation (recommended for most models) +model = compile(model_def, data; adtype=:ReverseDiff) + +# ForwardDiff (good for small models with few parameters) +model = compile(model_def, data; adtype=:ForwardDiff) + +# Zygote (source-to-source AD) +model = compile(model_def, data; adtype=:Zygote) +``` + +#### Advanced Configuration + +For fine-grained control, use explicit `ADTypes` constructors: + +```julia +using ADTypes + +# ReverseDiff without tape compilation +model = compile(model_def, data; adtype=AutoReverseDiff(compile=false)) + +# ReverseDiff with compilation (equivalent to :ReverseDiff) +model = compile(model_def, data; adtype=AutoReverseDiff(compile=true)) +``` + +### Performance Considerations + +- **ReverseDiff with compilation** (`:ReverseDiff`) is recommended for most models, especially those with many parameters. Compilation adds a one-time overhead but significantly speeds up subsequent gradient evaluations. + +- **ForwardDiff** (`:ForwardDiff`) is often faster for models with few parameters (< 20), as it avoids tape construction overhead. + +- **Tape compilation trade-off**: While `AutoReverseDiff(compile=true)` has higher initial compilation time, it provides faster gradient evaluations during sampling. 
For quick prototyping or models that will only be sampled a few times, `AutoReverseDiff(compile=false)` may be preferable. + +### Compatibility + +All models compiled with an `adtype` implement the full [`LogDensityProblems.jl`](https://www.tamaspapp.eu/LogDensityProblems.jl/dev/) interface, making them compatible with: + +- [`AdvancedHMC.jl`](https://github.com/TuringLang/AdvancedHMC.jl) — NUTS and HMC samplers +- Any other sampler that works with the LogDensityProblems interface + +The gradient computation is prepared during model compilation for optimal performance during sampling. + ## More Examples We have transcribed all the examples from the first volume of the BUGS Examples ([original](https://www.multibugs.org/examples/latest/VolumeI.html) and [transcribed](https://github.com/TuringLang/JuliaBUGS.jl/tree/main/JuliaBUGS/src/BUGSExamples/Volume_1)). All programs and data are included, and can be compiled using the steps described in the tutorial above. From 199ba715c61fe3006d209b5a4f85257f412435a2 Mon Sep 17 00:00:00 2001 From: Shravan Goswami Date: Thu, 2 Oct 2025 23:56:45 +0530 Subject: [PATCH 3/7] run JuliaFormatter --- JuliaBUGS/ext/JuliaBUGSAdvancedHMCExt.jl | 4 +--- JuliaBUGS/ext/JuliaBUGSAdvancedMHExt.jl | 4 +--- JuliaBUGS/ext/JuliaBUGSMCMCChainsExt.jl | 7 +++++- JuliaBUGS/src/JuliaBUGS.jl | 24 +++++++++---------- JuliaBUGS/src/model/logdensityproblems.jl | 6 ++++- .../test/BUGSPrimitives/distributions.jl | 4 +++- JuliaBUGS/test/ext/JuliaBUGSAdvancedHMCExt.jl | 24 ++++++++++--------- JuliaBUGS/test/ext/JuliaBUGSMCMCChainsExt.jl | 4 ++-- JuliaBUGS/test/model/bugsmodel.jl | 20 +++++++++------- JuliaBUGS/test/parallel_sampling.jl | 2 +- 10 files changed, 55 insertions(+), 44 deletions(-) diff --git a/JuliaBUGS/ext/JuliaBUGSAdvancedHMCExt.jl b/JuliaBUGS/ext/JuliaBUGSAdvancedHMCExt.jl index bf052e3ed..88f7a1d9b 100644 --- a/JuliaBUGS/ext/JuliaBUGSAdvancedHMCExt.jl +++ b/JuliaBUGS/ext/JuliaBUGSAdvancedHMCExt.jl @@ -44,9 +44,7 @@ function _gibbs_internal_hmc( ) # Create gradient model on-the-fly using DifferentiationInterface x = getparams(cond_model) - prep = DI.prepare_gradient( - _logdensity_switched, ad_backend, x, DI.Constant(cond_model) - ) + prep = DI.prepare_gradient(_logdensity_switched, ad_backend, x, DI.Constant(cond_model)) ad_model = BUGSModelWithGradient(ad_backend, prep, cond_model) logdensitymodel = AbstractMCMC.LogDensityModel(ad_model) diff --git a/JuliaBUGS/ext/JuliaBUGSAdvancedMHExt.jl b/JuliaBUGS/ext/JuliaBUGSAdvancedMHExt.jl index 1d07ade50..88c43249b 100644 --- a/JuliaBUGS/ext/JuliaBUGSAdvancedMHExt.jl +++ b/JuliaBUGS/ext/JuliaBUGSAdvancedMHExt.jl @@ -56,9 +56,7 @@ function _gibbs_internal_mh( ) # Create gradient model on-the-fly using DifferentiationInterface x = getparams(cond_model) - prep = DI.prepare_gradient( - _logdensity_switched, ad_backend, x, DI.Constant(cond_model) - ) + prep = DI.prepare_gradient(_logdensity_switched, ad_backend, x, DI.Constant(cond_model)) ad_model = BUGSModelWithGradient(ad_backend, prep, cond_model) logdensitymodel = AbstractMCMC.LogDensityModel(ad_model) diff --git a/JuliaBUGS/ext/JuliaBUGSMCMCChainsExt.jl b/JuliaBUGS/ext/JuliaBUGSMCMCChainsExt.jl index 40d77e848..a69007b74 100644 --- a/JuliaBUGS/ext/JuliaBUGSMCMCChainsExt.jl +++ b/JuliaBUGS/ext/JuliaBUGSMCMCChainsExt.jl @@ -2,7 +2,12 @@ module JuliaBUGSMCMCChainsExt using AbstractMCMC using JuliaBUGS -using JuliaBUGS: BUGSModel, BUGSModelWithGradient, find_generated_quantities_variables, evaluate!!, getparams +using JuliaBUGS: + BUGSModel, + BUGSModelWithGradient, + 
find_generated_quantities_variables, + evaluate!!, + getparams using JuliaBUGS.AbstractPPL using JuliaBUGS.Accessors using JuliaBUGS.LogDensityProblemsAD diff --git a/JuliaBUGS/src/JuliaBUGS.jl b/JuliaBUGS/src/JuliaBUGS.jl index 3e780011c..de31e43a7 100644 --- a/JuliaBUGS/src/JuliaBUGS.jl +++ b/JuliaBUGS/src/JuliaBUGS.jl @@ -320,14 +320,14 @@ function compile( ), ) base_model = BUGSModel(g, nonmissing_eval_env, model_def, data, initial_params) - + # If adtype provided, wrap with gradient capabilities if adtype !== nothing # Convert symbol to ADType if needed adtype_obj = _resolve_adtype(adtype) return _wrap_with_gradient(base_model, adtype_obj) end - + return base_model end @@ -344,7 +344,7 @@ Supported symbol shortcuts: """ function _resolve_adtype(adtype::Symbol) if adtype === :ReverseDiff - return ADTypes.AutoReverseDiff(compile=true) + return ADTypes.AutoReverseDiff(; compile=true) elseif adtype === :ForwardDiff return ADTypes.AutoForwardDiff() elseif adtype === :Zygote @@ -352,9 +352,11 @@ function _resolve_adtype(adtype::Symbol) elseif adtype === :Enzyme return ADTypes.AutoEnzyme() else - error("Unknown AD backend symbol: $adtype. " * - "Supported symbols: :ReverseDiff, :ForwardDiff, :Zygote, :Enzyme. " * - "Or use an ADTypes object like AutoReverseDiff(compile=true).") + error( + "Unknown AD backend symbol: $adtype. " * + "Supported symbols: :ReverseDiff, :ForwardDiff, :Zygote, :Enzyme. " * + "Or use an ADTypes object like AutoReverseDiff(compile=true).", + ) end end @@ -366,17 +368,13 @@ function _wrap_with_gradient(base_model::Model.BUGSModel, adtype::ADTypes.Abstra # Get initial parameters for preparation # Use invokelatest to handle world age issues with generated functions x = Base.invokelatest(getparams, base_model) - + # Prepare gradient using DifferentiationInterface # Use invokelatest to handle world age issues when calling logdensity during preparation prep = Base.invokelatest( - DI.prepare_gradient, - Model._logdensity_switched, - adtype, - x, - DI.Constant(base_model) + DI.prepare_gradient, Model._logdensity_switched, adtype, x, DI.Constant(base_model) ) - + return Model.BUGSModelWithGradient(adtype, prep, base_model) end # function compile( diff --git a/JuliaBUGS/src/model/logdensityproblems.jl b/JuliaBUGS/src/model/logdensityproblems.jl index 1b0381cae..2b97c5c7d 100644 --- a/JuliaBUGS/src/model/logdensityproblems.jl +++ b/JuliaBUGS/src/model/logdensityproblems.jl @@ -102,7 +102,11 @@ function LogDensityProblems.logdensity_and_gradient( ) else return DI.value_and_gradient( - _logdensity_switched, model.prep, model.backend, x, DI.Constant(model.base_model) + _logdensity_switched, + model.prep, + model.backend, + x, + DI.Constant(model.base_model), ) end end diff --git a/JuliaBUGS/test/BUGSPrimitives/distributions.jl b/JuliaBUGS/test/BUGSPrimitives/distributions.jl index 0262b051d..82c4f04af 100644 --- a/JuliaBUGS/test/BUGSPrimitives/distributions.jl +++ b/JuliaBUGS/test/BUGSPrimitives/distributions.jl @@ -15,7 +15,9 @@ end A[1:2, 1:2] ~ dwish(B[:, :], 2) C[1:2] ~ dmnorm(mu[:], A[:, :]) end - ad_model = compile(model_def, (mu=[0, 0], B=[1 0; 0 1]), (A=[1 0; 0 1],); adtype=AutoReverseDiff()) + ad_model = compile( + model_def, (mu=[0, 0], B=[1 0; 0 1]), (A=[1 0; 0 1],); adtype=AutoReverseDiff() + ) theta = [ 0.7931743744870574, diff --git a/JuliaBUGS/test/ext/JuliaBUGSAdvancedHMCExt.jl b/JuliaBUGS/test/ext/JuliaBUGSAdvancedHMCExt.jl index cb6a62c68..5877aa298 100644 --- a/JuliaBUGS/test/ext/JuliaBUGSAdvancedHMCExt.jl +++ b/JuliaBUGS/test/ext/JuliaBUGSAdvancedHMCExt.jl @@ 
-6,7 +6,7 @@ y = x[1] + x[3] end data = (mu=[0, 0], sigma=[1 0; 0 1]) - ad_model = compile(model_def, data; adtype=AutoReverseDiff(compile=true)) + ad_model = compile(model_def, data; adtype=AutoReverseDiff(; compile=true)) n_samples, n_adapts = 10, 0 D = LogDensityProblems.dimension(ad_model) initial_θ = rand(D) @@ -34,19 +34,19 @@ end end data = (N=5, y=[1.0, 2.0, 1.5, 2.5, 1.8]) - + # Test that symbol shortcut works ad_model_symbol = compile(model_def, data; adtype=:ReverseDiff) - ad_model_explicit = compile(model_def, data; adtype=AutoReverseDiff(compile=true)) - + ad_model_explicit = compile(model_def, data; adtype=AutoReverseDiff(; compile=true)) + @test ad_model_symbol isa JuliaBUGS.Model.BUGSModelWithGradient @test ad_model_explicit isa JuliaBUGS.Model.BUGSModelWithGradient - + # Test that both produce equivalent results n_samples, n_adapts = 100, 100 D = LogDensityProblems.dimension(ad_model_symbol) initial_θ = rand(StableRNG(123), D) - + samples_symbol = AbstractMCMC.sample( StableRNG(1234), ad_model_symbol, @@ -58,7 +58,7 @@ init_params=initial_θ, discard_initial=n_adapts, ) - + samples_explicit = AbstractMCMC.sample( StableRNG(1234), ad_model_explicit, @@ -70,10 +70,10 @@ init_params=initial_θ, discard_initial=n_adapts, ) - + # Results should be very similar (same RNG seed) - @test summarize(samples_symbol)[:mu].nt.mean[1] ≈ - summarize(samples_explicit)[:mu].nt.mean[1] rtol=0.1 + @test summarize(samples_symbol)[:mu].nt.mean[1] ≈ + summarize(samples_explicit)[:mu].nt.mean[1] rtol = 0.1 end @testset "Inference results on examples: $example" for example in @@ -81,7 +81,9 @@ (; model_def, data, inits, reference_results) = Base.getfield( JuliaBUGS.BUGSExamples, example ) - ad_model = JuliaBUGS.compile(model_def, data, inits; adtype=AutoReverseDiff(compile=true)) + ad_model = JuliaBUGS.compile( + model_def, data, inits; adtype=AutoReverseDiff(; compile=true) + ) n_samples, n_adapts = 1000, 1000 diff --git a/JuliaBUGS/test/ext/JuliaBUGSMCMCChainsExt.jl b/JuliaBUGS/test/ext/JuliaBUGSMCMCChainsExt.jl index 7d6f35832..2e7c16367 100644 --- a/JuliaBUGS/test/ext/JuliaBUGSMCMCChainsExt.jl +++ b/JuliaBUGS/test/ext/JuliaBUGSMCMCChainsExt.jl @@ -26,7 +26,7 @@ y=[1.58, 4.80, 7.10, 8.86, 11.73, 14.52, 18.22, 18.73, 21.04, 22.93], ) - ad_model = compile(model_def, data, (;); adtype=AutoReverseDiff(compile=true)) + ad_model = compile(model_def, data, (;); adtype=AutoReverseDiff(; compile=true)) n_samples, n_adapts = 2000, 1000 D = LogDensityProblems.dimension(ad_model) @@ -106,7 +106,7 @@ sigma[2] ~ InverseGamma(2, 3) sigma[3] ~ InverseGamma(2, 3) end - ad_model = compile(model_def, (;); adtype=AutoReverseDiff(compile=true)) + ad_model = compile(model_def, (;); adtype=AutoReverseDiff(; compile=true)) hmc_chain = AbstractMCMC.sample( ad_model, NUTS(0.8), 10; progress=false, chain_type=Chains ) diff --git a/JuliaBUGS/test/model/bugsmodel.jl b/JuliaBUGS/test/model/bugsmodel.jl index ddfe835c5..7a68d6264 100644 --- a/JuliaBUGS/test/model/bugsmodel.jl +++ b/JuliaBUGS/test/model/bugsmodel.jl @@ -414,22 +414,26 @@ end # Test :ReverseDiff shortcut model_rd = compile(model_def, data; adtype=:ReverseDiff) @test model_rd isa JuliaBUGS.Model.BUGSModelWithGradient - + # Test equivalence with explicit ADType - model_explicit = compile(model_def, data; adtype=AutoReverseDiff(compile=true)) + model_explicit = compile( + model_def, data; adtype=AutoReverseDiff(; compile=true) + ) @test model_explicit isa JuliaBUGS.Model.BUGSModelWithGradient - + # Test that unknown symbol throws error @test_throws ErrorException 
compile(model_def, data; adtype=:UnknownBackend) end @testset "Explicit ADTypes" begin # Test with compile=true - model_compile = compile(model_def, data; adtype=AutoReverseDiff(compile=true)) + model_compile = compile(model_def, data; adtype=AutoReverseDiff(; compile=true)) @test model_compile isa JuliaBUGS.Model.BUGSModelWithGradient - + # Test with compile=false - model_nocompile = compile(model_def, data; adtype=AutoReverseDiff(compile=false)) + model_nocompile = compile( + model_def, data; adtype=AutoReverseDiff(; compile=false) + ) @test model_nocompile isa JuliaBUGS.Model.BUGSModelWithGradient end @@ -443,10 +447,10 @@ end @testset "Gradient computation" begin model = compile(model_def, data; adtype=:ReverseDiff) test_point = [0.0] - + # Test that gradient can be computed ℓ, grad = LogDensityProblems.logdensity_and_gradient(model, test_point) - + @test ℓ isa Real @test grad isa Vector @test length(grad) == 1 diff --git a/JuliaBUGS/test/parallel_sampling.jl b/JuliaBUGS/test/parallel_sampling.jl index 8b857e4a0..23dfa4ebf 100644 --- a/JuliaBUGS/test/parallel_sampling.jl +++ b/JuliaBUGS/test/parallel_sampling.jl @@ -20,7 +20,7 @@ inits = (mu=0.0, tau=1.0) # Use compile=false for thread safety with ReverseDiff - ad_model = compile(model_def, data, inits; adtype=AutoReverseDiff(compile=false)) + ad_model = compile(model_def, data, inits; adtype=AutoReverseDiff(; compile=false)) # Single chain reference n_samples = 200 From 1e515d1e1edb3a2bb24c66e3b273c2f76035548b Mon Sep 17 00:00:00 2001 From: Shravan Goswami Date: Fri, 3 Oct 2025 00:26:53 +0530 Subject: [PATCH 4/7] try to fix benchmark failures --- JuliaBUGS/benchmark/benchmark.jl | 9 +++------ JuliaBUGS/benchmark/run_benchmarks.jl | 4 ++-- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/JuliaBUGS/benchmark/benchmark.jl b/JuliaBUGS/benchmark/benchmark.jl index 1c558fd16..14f86dd32 100644 --- a/JuliaBUGS/benchmark/benchmark.jl +++ b/JuliaBUGS/benchmark/benchmark.jl @@ -83,16 +83,13 @@ function _create_results_dataframe(results::OrderedDict{Symbol,BenchmarkResult}) ), ) end + DataFrames.rename!(df, :Density_Time => "Density Time (µs)", :Density_Gradient_Time => "Density+Gradient Time (µs)") return df end function _print_results_table( - results::OrderedDict{Symbol,BenchmarkResult}; backend=Val(:text) + results::OrderedDict{Symbol,BenchmarkResult}; backend=:text ) df = _create_results_dataframe(results) - return pretty_table( - df; - header=["Model", "Parameters", "Density Time (µs)", "Density+Gradient Time (µs)"], - backend=backend, - ) + return pretty_table(df; backend=backend) end diff --git a/JuliaBUGS/benchmark/run_benchmarks.jl b/JuliaBUGS/benchmark/run_benchmarks.jl index 3194239f3..88459701c 100644 --- a/JuliaBUGS/benchmark/run_benchmarks.jl +++ b/JuliaBUGS/benchmark/run_benchmarks.jl @@ -45,7 +45,7 @@ for (model_name, model) in zip(examples_to_benchmark, juliabugs_models) end println("### Stan results:") -_print_results_table(stan_results; backend=Val(:markdown)) +_print_results_table(stan_results; backend=:markdown) println("### JuliaBUGS Mooncake results:") -_print_results_table(juliabugs_results; backend=Val(:markdown)) +_print_results_table(juliabugs_results; backend=:markdown) From f0135a8df6b142d9272fc8dbce5d63af21b65f60 Mon Sep 17 00:00:00 2001 From: Shravan Goswami Date: Fri, 3 Oct 2025 00:45:37 +0530 Subject: [PATCH 5/7] format --- JuliaBUGS/benchmark/benchmark.jl | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/JuliaBUGS/benchmark/benchmark.jl b/JuliaBUGS/benchmark/benchmark.jl 
index 14f86dd32..30d5985a6 100644 --- a/JuliaBUGS/benchmark/benchmark.jl +++ b/JuliaBUGS/benchmark/benchmark.jl @@ -83,13 +83,15 @@ function _create_results_dataframe(results::OrderedDict{Symbol,BenchmarkResult}) ), ) end - DataFrames.rename!(df, :Density_Time => "Density Time (µs)", :Density_Gradient_Time => "Density+Gradient Time (µs)") + DataFrames.rename!( + df, + :Density_Time => "Density Time (µs)", + :Density_Gradient_Time => "Density+Gradient Time (µs)", + ) return df end -function _print_results_table( - results::OrderedDict{Symbol,BenchmarkResult}; backend=:text -) +function _print_results_table(results::OrderedDict{Symbol,BenchmarkResult}; backend=:text) df = _create_results_dataframe(results) return pretty_table(df; backend=backend) end From c80dc9f73267431246b03cf24fad040867e3411b Mon Sep 17 00:00:00 2001 From: Shravan Goswami Date: Mon, 6 Oct 2025 15:38:44 +0530 Subject: [PATCH 6/7] add Mooncake and update docs --- JuliaBUGS/History.md | 2 +- JuliaBUGS/docs/src/example.md | 4 ++++ JuliaBUGS/src/JuliaBUGS.jl | 7 ++++++- 3 files changed, 11 insertions(+), 2 deletions(-) diff --git a/JuliaBUGS/History.md b/JuliaBUGS/History.md index 066d5ad3a..65f850dd2 100644 --- a/JuliaBUGS/History.md +++ b/JuliaBUGS/History.md @@ -4,7 +4,7 @@ - **DifferentiationInterface.jl integration**: JuliaBUGS now uses [DifferentiationInterface.jl](https://github.com/JuliaDiff/DifferentiationInterface.jl) for automatic differentiation, providing a unified interface to multiple AD backends. - Add `adtype` parameter to `compile()` function for specifying AD backends - - Support convenient symbol shortcuts: `:ReverseDiff`, `:ForwardDiff`, `:Zygote`, `:Enzyme` + - Support convenient symbol shortcuts: `:ReverseDiff`, `:ForwardDiff`, `:Zygote`, `:Enzyme`, `:Mooncake` - Gradient computation is prepared during compilation for optimal performance - Example: `model = compile(model_def, data; adtype=:ReverseDiff)` - Full control available via explicit ADTypes: `adtype=AutoReverseDiff(compile=true)` diff --git a/JuliaBUGS/docs/src/example.md b/JuliaBUGS/docs/src/example.md index 3cf73a4cc..2fc947740 100644 --- a/JuliaBUGS/docs/src/example.md +++ b/JuliaBUGS/docs/src/example.md @@ -409,6 +409,7 @@ The following AD backends are supported via convenient symbol shortcuts: - **`:ForwardDiff`** — Forward-mode AD, efficient for models with few parameters (typically < 20). - **`:Zygote`** — Source-to-source reverse-mode AD, general-purpose but may be slower than ReverseDiff for many models. - **`:Enzyme`** — Experimental high-performance AD backend with LLVM-level transformations. +- **`:Mooncake`** — High-performance reverse-mode AD with advanced optimizations. ### Usage Examples @@ -449,6 +450,9 @@ model = compile(model_def, data; adtype=AutoReverseDiff(compile=true)) - **Tape compilation trade-off**: While `AutoReverseDiff(compile=true)` has higher initial compilation time, it provides faster gradient evaluations during sampling. For quick prototyping or models that will only be sampled a few times, `AutoReverseDiff(compile=false)` may be preferable. +!!! warning "Compiled tapes and control flow" + Compiled ReverseDiff tapes cannot handle value-dependent control flow (e.g., `if x[1] > 0`). If your model has such control flow, use `AutoReverseDiff(compile=false)` or a different backend like `:ForwardDiff` or `:Mooncake`. See the [ReverseDiff documentation](https://juliadiff.org/ReverseDiff.jl/stable/api/#The-AbstractTape-API) for details. 
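+
+For such models, a sketch of the safer configurations (reusing the `model_def` and `data` from above):
+
+```julia
+# ReverseDiff without tape compilation
+model = compile(model_def, data; adtype=AutoReverseDiff(compile=false))
+
+# or a backend without the compiled-tape limitation
+model = compile(model_def, data; adtype=:Mooncake)
+```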
+ ### Compatibility All models compiled with an `adtype` implement the full [`LogDensityProblems.jl`](https://www.tamaspapp.eu/LogDensityProblems.jl/dev/) interface, making them compatible with: diff --git a/JuliaBUGS/src/JuliaBUGS.jl b/JuliaBUGS/src/JuliaBUGS.jl index de31e43a7..ef7dfd106 100644 --- a/JuliaBUGS/src/JuliaBUGS.jl +++ b/JuliaBUGS/src/JuliaBUGS.jl @@ -264,6 +264,8 @@ samplers like HMC/NUTS. The gradient computation is prepared during compilation - `:ReverseDiff` - Shorthand for `AutoReverseDiff(compile=true)` - `:ForwardDiff` - Shorthand for `AutoForwardDiff()` - `:Zygote` - Shorthand for `AutoZygote()` + - `:Enzyme` - Shorthand for `AutoEnzyme()` + - `:Mooncake` - Shorthand for `AutoMooncake()` - Any other `ADTypes.AbstractADType` # Examples @@ -341,6 +343,7 @@ Supported symbol shortcuts: - `:ForwardDiff` -> `AutoForwardDiff()` - `:Zygote` -> `AutoZygote()` - `:Enzyme` -> `AutoEnzyme()` +- `:Mooncake` -> `AutoMooncake()` """ function _resolve_adtype(adtype::Symbol) if adtype === :ReverseDiff @@ -351,10 +354,12 @@ function _resolve_adtype(adtype::Symbol) return ADTypes.AutoZygote() elseif adtype === :Enzyme return ADTypes.AutoEnzyme() + elseif adtype === :Mooncake + return ADTypes.AutoMooncake(; config=nothing) else error( "Unknown AD backend symbol: $adtype. " * - "Supported symbols: :ReverseDiff, :ForwardDiff, :Zygote, :Enzyme. " * + "Supported symbols: :ReverseDiff, :ForwardDiff, :Zygote, :Enzyme, :Mooncake. " * "Or use an ADTypes object like AutoReverseDiff(compile=true).", ) end From ef36d7d2fc9dde0d7a9898e4e05a7baa7110698f Mon Sep 17 00:00:00 2001 From: Shravan Goswami Date: Tue, 7 Oct 2025 12:13:12 +0530 Subject: [PATCH 7/7] remove symbols --- JuliaBUGS/History.md | 7 +-- JuliaBUGS/docs/src/example.md | 52 ++++++++--------- JuliaBUGS/examples/sir.jl | 3 +- JuliaBUGS/src/JuliaBUGS.jl | 57 +++---------------- JuliaBUGS/test/ext/JuliaBUGSAdvancedHMCExt.jl | 26 ++++----- JuliaBUGS/test/model/bugsmodel.jl | 19 +------ 6 files changed, 54 insertions(+), 110 deletions(-) diff --git a/JuliaBUGS/History.md b/JuliaBUGS/History.md index 65f850dd2..c004752ac 100644 --- a/JuliaBUGS/History.md +++ b/JuliaBUGS/History.md @@ -3,11 +3,10 @@ ## 0.10.4 - **DifferentiationInterface.jl integration**: JuliaBUGS now uses [DifferentiationInterface.jl](https://github.com/JuliaDiff/DifferentiationInterface.jl) for automatic differentiation, providing a unified interface to multiple AD backends. 
- - Add `adtype` parameter to `compile()` function for specifying AD backends - - Support convenient symbol shortcuts: `:ReverseDiff`, `:ForwardDiff`, `:Zygote`, `:Enzyme`, `:Mooncake` + - Add `adtype` parameter to `compile()` function for specifying AD backends via [ADTypes.jl](https://github.com/SciML/ADTypes.jl) + - Supports multiple backends: `AutoReverseDiff`, `AutoForwardDiff`, `AutoZygote`, `AutoEnzyme`, `AutoMooncake` - Gradient computation is prepared during compilation for optimal performance - - Example: `model = compile(model_def, data; adtype=:ReverseDiff)` - - Full control available via explicit ADTypes: `adtype=AutoReverseDiff(compile=true)` + - Example: `model = compile(model_def, data; adtype=AutoReverseDiff(compile=true))` - Backward compatible: models without `adtype` work as before ## 0.10.1 diff --git a/JuliaBUGS/docs/src/example.md b/JuliaBUGS/docs/src/example.md index 2fc947740..57ea54466 100644 --- a/JuliaBUGS/docs/src/example.md +++ b/JuliaBUGS/docs/src/example.md @@ -199,21 +199,19 @@ JuliaBUGS integrates with automatic differentiation (AD) through [Differentiatio To compile a model with gradient support, pass the `adtype` parameter to `compile`: ```julia -# Using explicit ADType from ADTypes.jl +# Compile with gradient support using ADTypes from ADTypes.jl using ADTypes model = compile(model_def, data; adtype=AutoReverseDiff(compile=true)) - -# Using convenient symbol shortcuts -model = compile(model_def, data; adtype=:ReverseDiff) # Equivalent to above ``` Available AD backends include: -- `:ReverseDiff` - ReverseDiff with tape compilation (recommended for most models) -- `:ForwardDiff` - ForwardDiff (efficient for models with few parameters) -- `:Zygote` - Zygote (source-to-source AD) -- `:Enzyme` - Enzyme (experimental, high-performance) +- `AutoReverseDiff(compile=true)` - ReverseDiff with tape compilation (recommended for most models) +- `AutoForwardDiff()` - ForwardDiff (efficient for models with few parameters) +- `AutoZygote()` - Zygote (source-to-source AD) +- `AutoEnzyme()` - Enzyme (experimental, high-performance) +- `AutoMooncake()` - Mooncake (high-performance reverse-mode AD) -For fine-grained control, use explicit `ADTypes` constructors: +For fine-grained control, you can configure the AD backend: ```julia # ReverseDiff without compilation @@ -230,7 +228,7 @@ For gradient-based inference, we use [`AdvancedHMC.jl`](https://github.com/Turin using AdvancedHMC, AbstractMCMC, LogDensityProblems, MCMCChains, ReverseDiff # Compile with gradient support -model = compile(model_def, data; adtype=:ReverseDiff) +model = compile(model_def, data; adtype=AutoReverseDiff(compile=true)) n_samples, n_adapts = 2000, 1000 @@ -403,34 +401,36 @@ JuliaBUGS integrates with multiple automatic differentiation (AD) backends throu ### Available Backends -The following AD backends are supported via convenient symbol shortcuts: +The following AD backends are supported via [ADTypes.jl](https://github.com/SciML/ADTypes.jl): -- **`:ReverseDiff`** (recommended) — Tape-based reverse-mode AD, highly efficient for models with many parameters. Uses compilation by default for optimal performance. -- **`:ForwardDiff`** — Forward-mode AD, efficient for models with few parameters (typically < 20). -- **`:Zygote`** — Source-to-source reverse-mode AD, general-purpose but may be slower than ReverseDiff for many models. -- **`:Enzyme`** — Experimental high-performance AD backend with LLVM-level transformations. -- **`:Mooncake`** — High-performance reverse-mode AD with advanced optimizations. 
+- **`AutoReverseDiff(compile=true)`** (recommended) — Tape-based reverse-mode AD, highly efficient for models with many parameters. Uses compilation by default for optimal performance. +- **`AutoForwardDiff()`** — Forward-mode AD, efficient for models with few parameters (typically < 20). +- **`AutoZygote()`** — Source-to-source reverse-mode AD, general-purpose but may be slower than ReverseDiff for many models. +- **`AutoEnzyme()`** — Experimental high-performance AD backend with LLVM-level transformations. +- **`AutoMooncake()`** — High-performance reverse-mode AD with advanced optimizations. ### Usage Examples -#### Basic Usage with Symbol Shortcuts +#### Basic Usage -The simplest way to specify an AD backend is using symbol shortcuts: +Specify an AD backend using ADTypes: ```julia +using ADTypes + # ReverseDiff with tape compilation (recommended for most models) -model = compile(model_def, data; adtype=:ReverseDiff) +model = compile(model_def, data; adtype=AutoReverseDiff(compile=true)) # ForwardDiff (good for small models with few parameters) -model = compile(model_def, data; adtype=:ForwardDiff) +model = compile(model_def, data; adtype=AutoForwardDiff()) # Zygote (source-to-source AD) -model = compile(model_def, data; adtype=:Zygote) +model = compile(model_def, data; adtype=AutoZygote()) ``` #### Advanced Configuration -For fine-grained control, use explicit `ADTypes` constructors: +For fine-grained control, you can configure the AD backends: ```julia using ADTypes @@ -438,20 +438,20 @@ using ADTypes # ReverseDiff without tape compilation model = compile(model_def, data; adtype=AutoReverseDiff(compile=false)) -# ReverseDiff with compilation (equivalent to :ReverseDiff) +# ReverseDiff with compilation (default, recommended) model = compile(model_def, data; adtype=AutoReverseDiff(compile=true)) ``` ### Performance Considerations -- **ReverseDiff with compilation** (`:ReverseDiff`) is recommended for most models, especially those with many parameters. Compilation adds a one-time overhead but significantly speeds up subsequent gradient evaluations. +- **ReverseDiff with compilation** (`AutoReverseDiff(compile=true)`) is recommended for most models, especially those with many parameters. Compilation adds a one-time overhead but significantly speeds up subsequent gradient evaluations. -- **ForwardDiff** (`:ForwardDiff`) is often faster for models with few parameters (< 20), as it avoids tape construction overhead. +- **ForwardDiff** (`AutoForwardDiff()`) is often faster for models with few parameters (< 20), as it avoids tape construction overhead. - **Tape compilation trade-off**: While `AutoReverseDiff(compile=true)` has higher initial compilation time, it provides faster gradient evaluations during sampling. For quick prototyping or models that will only be sampled a few times, `AutoReverseDiff(compile=false)` may be preferable. !!! warning "Compiled tapes and control flow" - Compiled ReverseDiff tapes cannot handle value-dependent control flow (e.g., `if x[1] > 0`). If your model has such control flow, use `AutoReverseDiff(compile=false)` or a different backend like `:ForwardDiff` or `:Mooncake`. See the [ReverseDiff documentation](https://juliadiff.org/ReverseDiff.jl/stable/api/#The-AbstractTape-API) for details. + Compiled ReverseDiff tapes cannot handle value-dependent control flow (e.g., `if x[1] > 0`). If your model has such control flow, use `AutoReverseDiff(compile=false)` or a different backend like `AutoForwardDiff()` or `AutoMooncake()`. 
See the [ReverseDiff documentation](https://juliadiff.org/ReverseDiff.jl/stable/api/#The-AbstractTape-API) for details. ### Compatibility diff --git a/JuliaBUGS/examples/sir.jl b/JuliaBUGS/examples/sir.jl index 108d47ce1..ccd4f3279 100644 --- a/JuliaBUGS/examples/sir.jl +++ b/JuliaBUGS/examples/sir.jl @@ -7,6 +7,7 @@ using JuliaBUGS: @model using Distributions using DifferentialEquations using LogDensityProblems, LogDensityProblemsAD +using ADTypes using AbstractMCMC, AdvancedHMC, MCMCChains using Distributed # For distributed example @@ -113,7 +114,7 @@ model = JuliaBUGS.set_evaluation_mode(model, JuliaBUGS.UseGraph()) # --- MCMC Sampling: NUTS with ForwardDiff AD --- # Create an AD-aware wrapper for the model using ForwardDiff for gradients -ad_model_forwarddiff = ADgradient(:ForwardDiff, model) +ad_model_forwarddiff = ADgradient(AutoForwardDiff(), model) # MCMC settings n_samples = 1000 diff --git a/JuliaBUGS/src/JuliaBUGS.jl b/JuliaBUGS/src/JuliaBUGS.jl index ef7dfd106..a6e4e5f0c 100644 --- a/JuliaBUGS/src/JuliaBUGS.jl +++ b/JuliaBUGS/src/JuliaBUGS.jl @@ -258,14 +258,13 @@ samplers like HMC/NUTS. The gradient computation is prepared during compilation - `initial_params::NamedTuple=NamedTuple()`: Initial parameter values (optional) - `skip_validation::Bool=false`: Skip function validation (for @model macro) - `eval_module::Module=@__MODULE__`: Module for evaluation -- `adtype`: AD backend specification. Can be: +- `adtype`: AD backend specification using ADTypes. Examples: - `AutoReverseDiff(compile=true)` - ReverseDiff with tape compilation (fastest) - `AutoReverseDiff(compile=false)` - ReverseDiff without compilation - - `:ReverseDiff` - Shorthand for `AutoReverseDiff(compile=true)` - - `:ForwardDiff` - Shorthand for `AutoForwardDiff()` - - `:Zygote` - Shorthand for `AutoZygote()` - - `:Enzyme` - Shorthand for `AutoEnzyme()` - - `:Mooncake` - Shorthand for `AutoMooncake()` + - `AutoForwardDiff()` - ForwardDiff backend + - `AutoZygote()` - Zygote backend + - `AutoEnzyme()` - Enzyme backend + - `AutoMooncake()` - Mooncake backend - Any other `ADTypes.AbstractADType` # Examples @@ -273,14 +272,11 @@ samplers like HMC/NUTS. The gradient computation is prepared during compilation # Basic compilation model = compile(model_def, data) -# With gradient support using explicit ADType +# With gradient support using ReverseDiff (recommended for most models) model = compile(model_def, data; adtype=AutoReverseDiff(compile=true)) -# With gradient support using symbol shorthand -model = compile(model_def, data; adtype=:ReverseDiff) # Same as above - # Using ForwardDiff for small models -model = compile(model_def, data; adtype=:ForwardDiff) +model = compile(model_def, data; adtype=AutoForwardDiff()) # Sample with NUTS chain = AbstractMCMC.sample(model, NUTS(0.8), 1000) @@ -325,49 +321,12 @@ function compile( # If adtype provided, wrap with gradient capabilities if adtype !== nothing - # Convert symbol to ADType if needed - adtype_obj = _resolve_adtype(adtype) - return _wrap_with_gradient(base_model, adtype_obj) + return _wrap_with_gradient(base_model, adtype) end return base_model end -""" - _resolve_adtype(adtype) -> ADTypes.AbstractADType - -Convert symbol shortcuts to ADTypes, or return the ADType as-is. 
- -Supported symbol shortcuts: -- `:ReverseDiff` -> `AutoReverseDiff(compile=true)` -- `:ForwardDiff` -> `AutoForwardDiff()` -- `:Zygote` -> `AutoZygote()` -- `:Enzyme` -> `AutoEnzyme()` -- `:Mooncake` -> `AutoMooncake()` -""" -function _resolve_adtype(adtype::Symbol) - if adtype === :ReverseDiff - return ADTypes.AutoReverseDiff(; compile=true) - elseif adtype === :ForwardDiff - return ADTypes.AutoForwardDiff() - elseif adtype === :Zygote - return ADTypes.AutoZygote() - elseif adtype === :Enzyme - return ADTypes.AutoEnzyme() - elseif adtype === :Mooncake - return ADTypes.AutoMooncake(; config=nothing) - else - error( - "Unknown AD backend symbol: $adtype. " * - "Supported symbols: :ReverseDiff, :ForwardDiff, :Zygote, :Enzyme, :Mooncake. " * - "Or use an ADTypes object like AutoReverseDiff(compile=true).", - ) - end -end - -# Pass through ADTypes objects unchanged -_resolve_adtype(adtype::ADTypes.AbstractADType) = adtype - # Helper function to prepare gradient - separated to handle world age issues function _wrap_with_gradient(base_model::Model.BUGSModel, adtype::ADTypes.AbstractADType) # Get initial parameters for preparation diff --git a/JuliaBUGS/test/ext/JuliaBUGSAdvancedHMCExt.jl b/JuliaBUGS/test/ext/JuliaBUGSAdvancedHMCExt.jl index 5877aa298..843963864 100644 --- a/JuliaBUGS/test/ext/JuliaBUGSAdvancedHMCExt.jl +++ b/JuliaBUGS/test/ext/JuliaBUGSAdvancedHMCExt.jl @@ -26,7 +26,7 @@ [Symbol("x[3]"), Symbol("x[1:2][1]"), Symbol("x[1:2][2]"), :y] end - @testset "Symbol AD backend shortcuts" begin + @testset "AD backend sampling" begin model_def = @bugs begin mu ~ dnorm(0, 1) for i in 1:N @@ -35,21 +35,21 @@ end data = (N=5, y=[1.0, 2.0, 1.5, 2.5, 1.8]) - # Test that symbol shortcut works - ad_model_symbol = compile(model_def, data; adtype=:ReverseDiff) - ad_model_explicit = compile(model_def, data; adtype=AutoReverseDiff(; compile=true)) + # Test that ReverseDiff backend works + ad_model_compiled = compile(model_def, data; adtype=AutoReverseDiff(; compile=true)) + ad_model_nocompile = compile(model_def, data; adtype=AutoReverseDiff(; compile=false)) - @test ad_model_symbol isa JuliaBUGS.Model.BUGSModelWithGradient - @test ad_model_explicit isa JuliaBUGS.Model.BUGSModelWithGradient + @test ad_model_compiled isa JuliaBUGS.Model.BUGSModelWithGradient + @test ad_model_nocompile isa JuliaBUGS.Model.BUGSModelWithGradient # Test that both produce equivalent results n_samples, n_adapts = 100, 100 - D = LogDensityProblems.dimension(ad_model_symbol) + D = LogDensityProblems.dimension(ad_model_compiled) initial_θ = rand(StableRNG(123), D) - samples_symbol = AbstractMCMC.sample( + samples_compiled = AbstractMCMC.sample( StableRNG(1234), - ad_model_symbol, + ad_model_compiled, NUTS(0.8), n_samples; progress=false, @@ -59,9 +59,9 @@ discard_initial=n_adapts, ) - samples_explicit = AbstractMCMC.sample( + samples_nocompile = AbstractMCMC.sample( StableRNG(1234), - ad_model_explicit, + ad_model_nocompile, NUTS(0.8), n_samples; progress=false, @@ -72,8 +72,8 @@ ) # Results should be very similar (same RNG seed) - @test summarize(samples_symbol)[:mu].nt.mean[1] ≈ - summarize(samples_explicit)[:mu].nt.mean[1] rtol = 0.1 + @test summarize(samples_compiled)[:mu].nt.mean[1] ≈ + summarize(samples_nocompile)[:mu].nt.mean[1] rtol = 0.1 end @testset "Inference results on examples: $example" for example in diff --git a/JuliaBUGS/test/model/bugsmodel.jl b/JuliaBUGS/test/model/bugsmodel.jl index 7a68d6264..71b19834d 100644 --- a/JuliaBUGS/test/model/bugsmodel.jl +++ b/JuliaBUGS/test/model/bugsmodel.jl @@ -410,22 +410,7 
@@ end end data = (y=1.5,) - @testset "Symbol shortcuts" begin - # Test :ReverseDiff shortcut - model_rd = compile(model_def, data; adtype=:ReverseDiff) - @test model_rd isa JuliaBUGS.Model.BUGSModelWithGradient - - # Test equivalence with explicit ADType - model_explicit = compile( - model_def, data; adtype=AutoReverseDiff(; compile=true) - ) - @test model_explicit isa JuliaBUGS.Model.BUGSModelWithGradient - - # Test that unknown symbol throws error - @test_throws ErrorException compile(model_def, data; adtype=:UnknownBackend) - end - - @testset "Explicit ADTypes" begin + @testset "ADTypes backends" begin # Test with compile=true model_compile = compile(model_def, data; adtype=AutoReverseDiff(; compile=true)) @test model_compile isa JuliaBUGS.Model.BUGSModelWithGradient @@ -445,7 +430,7 @@ end end @testset "Gradient computation" begin - model = compile(model_def, data; adtype=:ReverseDiff) + model = compile(model_def, data; adtype=AutoReverseDiff(; compile=true)) test_point = [0.0] # Test that gradient can be computed
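As a hedged sketch of how the prepared gradient can be exercised through the `LogDensityProblems` interface that gradient-enabled models implement (the check below is a hypothetical illustration, not the test body from the patch; `model` and `test_point` follow the snippet above):

```julia
using LogDensityProblems

# Evaluate the log density and its gradient in one call; the gradient
# uses the AD preparation done at compile time.
logp, grad = LogDensityProblems.logdensity_and_gradient(model, test_point)

@assert isfinite(logp)
@assert length(grad) == LogDensityProblems.dimension(model)
```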