Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,6 @@ Manifest.toml
*.local.*

# gitingest generated files
digest.txt
digest.txt

tmp/
9 changes: 9 additions & 0 deletions JuliaBUGS/History.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,14 @@
# JuliaBUGS Changelog

## 0.10.4

- **DifferentiationInterface.jl integration**: JuliaBUGS now uses [DifferentiationInterface.jl](https://github.com/JuliaDiff/DifferentiationInterface.jl) for automatic differentiation, providing a unified interface to multiple AD backends.
- Add `adtype` parameter to `compile()` function for specifying AD backends via [ADTypes.jl](https://github.com/SciML/ADTypes.jl)
- Supports multiple backends: `AutoReverseDiff`, `AutoForwardDiff`, `AutoZygote`, `AutoEnzyme`, `AutoMooncake`
- Gradient computation is prepared during compilation for optimal performance
- Example: `model = compile(model_def, data; adtype=AutoReverseDiff(compile=true))`
- Backward compatible: models without `adtype` work as before

## 0.10.1

Expose docs for changes in [v0.10.0](https://github.com/TuringLang/JuliaBUGS.jl/releases/tag/JuliaBUGS-v0.10.0)
Expand Down
4 changes: 3 additions & 1 deletion JuliaBUGS/Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "JuliaBUGS"
uuid = "ba9fb4c0-828e-4473-b6a1-cd2560fee5bf"
version = "0.10.3"
version = "0.10.4"
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it really non-breaking to make such a big change?

Copy link
Member Author

@shravanngoswamii shravanngoswamii Oct 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The `adtype` parameter is optional (it defaults to `nothing`), and all existing code works unchanged, so this change is backward compatible.

However, if you prefer 0.11.0, I can change it.


[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand All @@ -9,6 +9,7 @@ AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
Expand Down Expand Up @@ -52,6 +53,7 @@ AdvancedHMC = "0.6, 0.7, 0.8"
AdvancedMH = "0.8"
BangBang = "0.4.1"
Bijectors = "0.13, 0.14, 0.15.5"
DifferentiationInterface = "0.7"
Distributions = "0.23.8, 0.24, 0.25"
Documenter = "0.27, 1"
GLMakie = "0.10, 0.11, 0.12, 0.13"
Expand Down
15 changes: 7 additions & 8 deletions JuliaBUGS/benchmark/benchmark.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,16 +83,15 @@ function _create_results_dataframe(results::OrderedDict{Symbol,BenchmarkResult})
),
)
end
DataFrames.rename!(
df,
:Density_Time => "Density Time (µs)",
:Density_Gradient_Time => "Density+Gradient Time (µs)",
)
return df
end

function _print_results_table(
results::OrderedDict{Symbol,BenchmarkResult}; backend=Val(:text)
)
function _print_results_table(results::OrderedDict{Symbol,BenchmarkResult}; backend=:text)
df = _create_results_dataframe(results)
return pretty_table(
df;
header=["Model", "Parameters", "Density Time (µs)", "Density+Gradient Time (µs)"],
backend=backend,
)
return pretty_table(df; backend=backend)
end
4 changes: 2 additions & 2 deletions JuliaBUGS/benchmark/run_benchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ for (model_name, model) in zip(examples_to_benchmark, juliabugs_models)
end

println("### Stan results:")
_print_results_table(stan_results; backend=Val(:markdown))
_print_results_table(stan_results; backend=:markdown)

println("### JuliaBUGS Mooncake results:")
_print_results_table(juliabugs_results; backend=Val(:markdown))
_print_results_table(juliabugs_results; backend=:markdown)
204 changes: 162 additions & 42 deletions JuliaBUGS/docs/src/example.md

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion JuliaBUGS/examples/sir.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ using JuliaBUGS: @model
using Distributions
using DifferentialEquations
using LogDensityProblems, LogDensityProblemsAD
using ADTypes
using AbstractMCMC, AdvancedHMC, MCMCChains
using Distributed # For distributed example

Expand Down Expand Up @@ -113,7 +114,7 @@ model = JuliaBUGS.set_evaluation_mode(model, JuliaBUGS.UseGraph())
# --- MCMC Sampling: NUTS with ForwardDiff AD ---

# Create an AD-aware wrapper for the model using ForwardDiff for gradients
ad_model_forwarddiff = ADgradient(:ForwardDiff, model)
ad_model_forwarddiff = ADgradient(AutoForwardDiff(), model)

# MCMC settings
n_samples = 1000
Expand Down
48 changes: 41 additions & 7 deletions JuliaBUGS/ext/JuliaBUGSAdvancedHMCExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@ module JuliaBUGSAdvancedHMCExt
using AbstractMCMC
using AdvancedHMC
using ADTypes
import DifferentiationInterface as DI
using JuliaBUGS
using JuliaBUGS: BUGSModel, getparams, initialize!
using JuliaBUGS: BUGSModel, BUGSModelWithGradient, getparams, initialize!
using JuliaBUGS.LogDensityProblems
using JuliaBUGS.LogDensityProblemsAD
using JuliaBUGS.Model: _logdensity_switched
using JuliaBUGS.Random
using MCMCChains: Chains

Expand Down Expand Up @@ -40,10 +42,11 @@ end
function _gibbs_internal_hmc(
rng::Random.AbstractRNG, cond_model::BUGSModel, sampler, ad_backend, state
)
# Wrap model with AD gradient computation
logdensitymodel = AbstractMCMC.LogDensityModel(
LogDensityProblemsAD.ADgradient(ad_backend, cond_model)
)
# Create gradient model on-the-fly using DifferentiationInterface
x = getparams(cond_model)
prep = DI.prepare_gradient(_logdensity_switched, ad_backend, x, DI.Constant(cond_model))
ad_model = BUGSModelWithGradient(ad_backend, prep, cond_model)
logdensitymodel = AbstractMCMC.LogDensityModel(ad_model)

# Take HMC/NUTS step
if isnothing(state)
Expand All @@ -53,7 +56,7 @@ function _gibbs_internal_hmc(
logdensitymodel,
sampler;
n_adapts=0, # Disable adaptation within Gibbs
initial_params=getparams(cond_model),
initial_params=x,
)
else
# Use existing state for subsequent steps
Expand All @@ -67,7 +70,7 @@ end

function AbstractMCMC.bundle_samples(
ts::Vector{<:AdvancedHMC.Transition},
logdensitymodel::AbstractMCMC.LogDensityModel{<:LogDensityProblemsAD.ADGradientWrapper},
logdensitymodel::AbstractMCMC.LogDensityModel{<:BUGSModelWithGradient},
sampler::AdvancedHMC.AbstractHMCSampler,
state,
chain_type::Type{Chains};
Expand Down Expand Up @@ -98,4 +101,35 @@ function AbstractMCMC.bundle_samples(
)
end

# Keep backward compatibility with LogDensityProblemsAD wrapper
# Converts a vector of AdvancedHMC transitions into an `MCMCChains.Chains` when the
# model was wrapped with the legacy `LogDensityProblemsAD.ADGradientWrapper` (rather
# than the new `BUGSModelWithGradient`). Mirrors the BUGSModelWithGradient method above.
function AbstractMCMC.bundle_samples(
    ts::Vector{<:AdvancedHMC.Transition},
    logdensitymodel::AbstractMCMC.LogDensityModel{<:LogDensityProblemsAD.ADGradientWrapper},
    sampler::AdvancedHMC.AbstractHMCSampler,
    state,
    chain_type::Type{Chains};
    discard_initial=0,  # number of leading samples to drop (burn-in)
    thinning=1,         # keep every `thinning`-th sample
    kwargs...,
)
    # Position (parameter vector) of each transition's phase point.
    param_samples = [t.z.θ for t in ts]

    # Statistic names: log density (`lp`) first, then whatever per-transition
    # statistics AdvancedHMC reports. `merge` on the first transition fixes the
    # name order so it matches the `vcat` layout of each row below.
    stats_names = collect(keys(merge((; lp=ts[1].z.ℓπ.value), AdvancedHMC.stat(ts[1]))))
    stats_values = [
        vcat(ts[i].z.ℓπ.value, collect(values(AdvancedHMC.stat(ts[i])))) for
        i in eachindex(ts)
    ]

    # Delegate to gen_chains for proper parameter naming from BUGSModel
    return JuliaBUGS.gen_chains(
        logdensitymodel,
        param_samples,
        stats_names,
        stats_values;
        discard_initial=discard_initial,
        thinning=thinning,
        kwargs...,
    )
end

end
41 changes: 35 additions & 6 deletions JuliaBUGS/ext/JuliaBUGSAdvancedMHExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@ module JuliaBUGSAdvancedMHExt
using AbstractMCMC
using AdvancedMH
using ADTypes
import DifferentiationInterface as DI
using JuliaBUGS
using JuliaBUGS: BUGSModel, getparams, initialize!
using JuliaBUGS: BUGSModel, BUGSModelWithGradient, getparams, initialize!
using JuliaBUGS.LogDensityProblems
using JuliaBUGS.LogDensityProblemsAD
using JuliaBUGS.Model: _logdensity_switched
using JuliaBUGS.Random
using MCMCChains: Chains

Expand Down Expand Up @@ -52,10 +54,11 @@ end
function _gibbs_internal_mh(
rng::Random.AbstractRNG, cond_model::BUGSModel, sampler, ad_backend, state
)
# Wrap model with AD gradient computation for gradient-based proposals
logdensitymodel = AbstractMCMC.LogDensityModel(
LogDensityProblemsAD.ADgradient(ad_backend, cond_model)
)
# Create gradient model on-the-fly using DifferentiationInterface
x = getparams(cond_model)
prep = DI.prepare_gradient(_logdensity_switched, ad_backend, x, DI.Constant(cond_model))
ad_model = BUGSModelWithGradient(ad_backend, prep, cond_model)
logdensitymodel = AbstractMCMC.LogDensityModel(ad_model)

# Take MH step with gradient information
if isnothing(state)
Expand All @@ -64,7 +67,7 @@ function _gibbs_internal_mh(
logdensitymodel,
sampler;
n_adapts=0, # Disable adaptation within Gibbs
initial_params=getparams(cond_model),
initial_params=x,
)
else
t, s = AbstractMCMC.step(rng, logdensitymodel, sampler, state; n_adapts=0)
Expand Down Expand Up @@ -103,6 +106,32 @@ function AbstractMCMC.bundle_samples(
)
end

# Bundle Metropolis-Hastings transitions from a gradient-wrapped BUGS model
# into an `MCMCChains.Chains`, delegating to `gen_chains` for parameter naming.
function AbstractMCMC.bundle_samples(
    ts::Vector{<:AdvancedMH.Transition},
    logdensitymodel::AbstractMCMC.LogDensityModel{<:BUGSModelWithGradient},
    sampler::AdvancedMH.MHSampler,
    state,
    chain_type::Type{Chains};
    discard_initial=0,
    thinning=1,
    kwargs...,
)
    # One parameter vector per transition; the only per-sample statistic
    # MH reports is the log density of the accepted state.
    draws = map(t -> t.params, ts)
    stat_names = [:lp]
    stat_rows = map(t -> [t.lp], ts)

    return JuliaBUGS.gen_chains(
        logdensitymodel,
        draws,
        stat_names,
        stat_rows;
        discard_initial=discard_initial,
        thinning=thinning,
        kwargs...,
    )
end

# Keep backward compatibility with LogDensityProblemsAD wrapper
function AbstractMCMC.bundle_samples(
ts::Vector{<:AdvancedMH.Transition},
logdensitymodel::AbstractMCMC.LogDensityModel{<:LogDensityProblemsAD.ADGradientWrapper},
Expand Down
21 changes: 20 additions & 1 deletion JuliaBUGS/ext/JuliaBUGSMCMCChainsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,12 @@ module JuliaBUGSMCMCChainsExt

using AbstractMCMC
using JuliaBUGS
using JuliaBUGS: BUGSModel, find_generated_quantities_variables, evaluate!!, getparams
using JuliaBUGS:
BUGSModel,
BUGSModelWithGradient,
find_generated_quantities_variables,
evaluate!!,
getparams
using JuliaBUGS.AbstractPPL
using JuliaBUGS.Accessors
using JuliaBUGS.LogDensityProblemsAD
Expand All @@ -21,6 +26,20 @@ function JuliaBUGS.gen_chains(
)
end

# Chains construction for gradient-wrapped models: unwrap the underlying
# BUGSModel and reuse its `gen_chains` method for parameter naming.
function JuliaBUGS.gen_chains(
    model::AbstractMCMC.LogDensityModel{<:BUGSModelWithGradient},
    samples,
    stats_names,
    stats_values;
    kwargs...,
)
    return JuliaBUGS.gen_chains(
        model.logdensity.base_model, samples, stats_names, stats_values; kwargs...
    )
end

# Keep backward compatibility with LogDensityProblemsAD wrapper
function JuliaBUGS.gen_chains(
model::AbstractMCMC.LogDensityModel{<:LogDensityProblemsAD.ADGradientWrapper},
samples,
Expand Down
62 changes: 60 additions & 2 deletions JuliaBUGS/src/JuliaBUGS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ using Accessors
using ADTypes
using BangBang
using Bijectors: Bijectors
using DifferentiationInterface
using Distributions
using Graphs, MetaGraphsNext
using LinearAlgebra
Expand All @@ -17,6 +18,7 @@ using Serialization: Serialization
using StaticArrays

import Base: ==, hash, Symbol, size
import DifferentiationInterface as DI
import Distributions: truncated

export @bugs
Expand Down Expand Up @@ -239,20 +241,54 @@ function validate_bugs_expression(expr, line_num)
end

"""
compile(model_def, data[, initial_params]; skip_validation=false)
compile(model_def, data[, initial_params]; skip_validation=false, adtype=nothing)

Compile the model with model definition and data. Optionally, initializations can be provided.
If initializations are not provided, values will be sampled from the prior distributions.

By default, validates that all functions in the model are in the BUGS allowlist (suitable for @bugs macro).
Set `skip_validation=true` to skip validation (for @model macro usage).

If `adtype` is provided, returns a `BUGSModelWithGradient` that supports gradient-based MCMC
samplers like HMC/NUTS. The gradient computation is prepared during compilation for optimal performance.

# Arguments
- `model_def::Expr`: Model definition from @bugs macro
- `data::NamedTuple`: Observed data
- `initial_params::NamedTuple=NamedTuple()`: Initial parameter values (optional)
- `skip_validation::Bool=false`: Skip function validation (for @model macro)
- `eval_module::Module=@__MODULE__`: Module for evaluation
- `adtype`: AD backend specification using ADTypes. Examples:
- `AutoReverseDiff(compile=true)` - ReverseDiff with tape compilation (fastest)
- `AutoReverseDiff(compile=false)` - ReverseDiff without compilation
- `AutoForwardDiff()` - ForwardDiff backend
- `AutoZygote()` - Zygote backend
- `AutoEnzyme()` - Enzyme backend
- `AutoMooncake()` - Mooncake backend
- Any other `ADTypes.AbstractADType`

# Examples
```julia
# Basic compilation
model = compile(model_def, data)

# With gradient support using ReverseDiff (recommended for most models)
model = compile(model_def, data; adtype=AutoReverseDiff(compile=true))

# Using ForwardDiff for small models
model = compile(model_def, data; adtype=AutoForwardDiff())

# Sample with NUTS
chain = AbstractMCMC.sample(model, NUTS(0.8), 1000)
```
"""
function compile(
model_def::Expr,
data::NamedTuple,
initial_params::NamedTuple=NamedTuple();
skip_validation::Bool=false,
eval_module::Module=@__MODULE__,
adtype::Union{Nothing,ADTypes.AbstractADType,Symbol}=nothing,
)
# Validate functions by default (for @bugs macro usage)
# Skip validation only for @model macro
Expand Down Expand Up @@ -281,7 +317,29 @@ function compile(
values(eval_env),
),
)
return BUGSModel(g, nonmissing_eval_env, model_def, data, initial_params)
base_model = BUGSModel(g, nonmissing_eval_env, model_def, data, initial_params)

# If adtype provided, wrap with gradient capabilities
if adtype !== nothing
return _wrap_with_gradient(base_model, adtype)
end

return base_model
end

# Helper function to prepare gradient - separated to handle world age issues.
# Returns a `BUGSModelWithGradient` whose DifferentiationInterface preparation
# was built against the model's current parameter vector.
function _wrap_with_gradient(base_model::Model.BUGSModel, adtype::ADTypes.AbstractADType)
    # Both calls go through invokelatest: compilation may have just eval'd
    # generated functions, so direct calls could hit world-age errors.
    params = Base.invokelatest(getparams, base_model)
    gradient_prep = Base.invokelatest(
        DI.prepare_gradient,
        Model._logdensity_switched,
        adtype,
        params,
        DI.Constant(base_model),
    )
    return Model.BUGSModelWithGradient(adtype, gradient_prep, base_model)
end
# function compile(
# model_str::String,
Expand Down
Loading
Loading