|
| 1 | +# Example demonstrating the use of Gaussian Processes (GPs) within JuliaBUGS |
| 2 | +# for modeling golf putting accuracy based on distance. |
| 3 | +# This example uses AbstractGPs.jl for the GP implementation and AdvancedHMC.jl |
| 4 | +# for sampling from the posterior distribution. |
| 5 | + |
| 6 | +using JuliaBUGS |
| 7 | +using JuliaBUGS: @model |
| 8 | + |
| 9 | +# Required packages for GP modeling and MCMC |
| 10 | +using AbstractGPs, Distributions, LogExpFunctions |
| 11 | +using LogDensityProblems, LogDensityProblemsAD |
| 12 | +using AbstractMCMC, AdvancedHMC, MCMCChains |
| 13 | + |
| 14 | +# Differentiation backend |
| 15 | +using DifferentiationInterface |
| 16 | +using Mooncake: Mooncake |
| 17 | + |
| 18 | +# --- Data Definition --- |
| 19 | + |
# Golf putting data from Gelman et al. (BDA3, Chapter 5).
# Position i of every vector refers to the same distance.
golf_data = (
    # Distance in feet: 2, 3, ..., 20.
    distance=collect(2:20),
    # Number of putts attempted.
    n=[
        1443, 694, 455, 353, 272, 256, 240, 217, 200, 237,
        202, 192, 174, 167, 201, 195, 191, 147, 152,
    ],
    # Number of successful putts.
    y=[
        1346, 577, 337, 208, 149, 136, 111, 69, 67, 75,
        52, 46, 54, 28, 27, 31, 33, 20, 24,
    ],
)
| 66 | + |
# Inputs for the BUGS model, keyed by the names the model definition uses.
data = (;
    d=golf_data.distance,
    n=golf_data.n,
    y=golf_data.y,
    # Small value added to the GP kernel diagonal for numerical stability.
    jitter=1e-6,
    # Number of distinct distances (one observation per distance).
    N=length(golf_data.distance),
)
| 75 | + |
| 76 | +# --- BUGS Model Definition --- |
| 77 | + |
@model function gp_golf_putting((; v, l, f_latent, y), N, n, d, jitter)
    # Hyperpriors: `v` scales the kernel (variance), `l` is its lengthscale.
    v ~ Distributions.Gamma(2, 1)
    l ~ Distributions.Gamma(4, 1)

    # Latent function values at the N distances, drawn from the GP built by
    # `gp_predict`. They live on the logit scale of the success probability.
    f_latent[1:N] ~ gp_predict(v, l, d[1:N], jitter)

    # Observation model: successes out of `n` attempts are Binomial, with the
    # success probability given by the logistic transform of `f_latent`
    # (see `y_distribution`).
    y[1:N] ~ y_distribution(n[1:N], f_latent[1:N])
end
| 92 | + |
| 93 | +# --- Custom Primitive Definitions for BUGS --- |
| 94 | + |
# Make the AbstractGPs names used by the custom primitives known to JuliaBUGS,
# so they can be referenced from code that the model definition calls.
JuliaBUGS.@register_primitive(GP, with_lengthscale, SEKernel)
| 98 | + |
# Primitive callable from the BUGS model: returns the distribution of a GP
# evaluated at the distances `d`.
#   v      - multiplier applied to the kernel (variance)
#   l      - lengthscale of the Squared Exponential kernel
#   d      - input locations (distances)
#   jitter - second argument of the GP evaluation, added for numerical stability
JuliaBUGS.@register_primitive function gp_predict(v, l, d, jitter)
    return GP(v * with_lengthscale(SEKernel(), l))(d, jitter)
end
| 108 | + |
# Primitive for the likelihood: a product of independent Binomials, one per
# distance. `n` holds the attempt counts and `f_latent` the latent values,
# which are mapped to probabilities with the logistic function.
JuliaBUGS.@register_primitive function y_distribution(n, f_latent)
    success_prob = logistic.(f_latent)
    return product_distribution(Binomial.(n, success_prob))
end
| 114 | + |
| 115 | +# --- Model Instantiation --- |
| 116 | + |
# Instantiate the JuliaBUGS model.
# The first argument carries the stochastic variables: `missing` marks the
# quantities to infer, while `y` is supplied as observed data.
model = let variables = (; v=missing, l=missing, f_latent=fill(missing, data.N), y=data.y)
    gp_golf_putting(
        variables,
        data.N,       # number of observations
        data.n,       # attempts per distance
        data.d,       # distances
        data.jitter,  # numerical stability term
    )
end
| 126 | + |
| 127 | +# Optionally, set the evaluation mode. Using generated functions can be faster. |
| 128 | +# model = JuliaBUGS.set_evaluation_mode(model, JuliaBUGS.UseGeneratedLogDensityFunction()) |
| 129 | + |
| 130 | +# --- MCMC Setup with Custom LogDensityProblems Wrapper --- |
| 131 | + |
# AdvancedHMC needs gradients, so the JuliaBUGS model is wrapped together with
# a gradient preparation object. The wrapper is what implements the
# LogDensityProblems interface further below.
#
# Fields:
#   model - the JuliaBUGS model being sampled
#   prep  - preparation object produced by `prepare_gradient` for the Mooncake backend
struct BUGSMooncakeModel{M,G}
    model::M
    prep::G
end
| 140 | + |
# Log density of the JuliaBUGS model at the parameter vector `x`.
# Note: this reads the global `model`, so it evaluates whatever `model` is
# bound to at call time.
function f(x)
    return model.log_density_computation_function(model.evaluation_env, x)
end
| 143 | + |
# Differentiation backend: Mooncake, driven through DifferentiationInterface.
backend = AutoMooncake(; config=nothing)

# Random point with the model's dimension; it is only used to prepare the gradient.
x_init = rand(LogDensityProblems.dimension(model))
prep = DifferentiationInterface.prepare_gradient(f, backend, x_init)

# Wrapped model instance handed to the sampler.
bugsmooncake = BUGSMooncakeModel(model, prep)
| 151 | + |
| 152 | +# --- LogDensityProblems Interface Implementation for the Wrapper --- |
| 153 | + |
# Log density of the wrapped JuliaBUGS model at `x`.
#
# The value is computed from the model stored in the wrapper (`model.model`)
# rather than from a global, so a `BUGSMooncakeModel` built around any model
# reports that model's own log density.
#
# Arguments:
#   model - the wrapper; its `model` field is the JuliaBUGS model
#   x     - parameter vector at which to evaluate
# Returns whatever the JuliaBUGS log-density function returns for `x`.
function LogDensityProblems.logdensity(model::BUGSMooncakeModel, x::AbstractVector)
    bugs_model = model.model
    return bugs_model.log_density_computation_function(bugs_model.evaluation_env, x)
end
| 158 | + |
# Log density and its gradient at `x`, computed with Mooncake through
# DifferentiationInterface.
# `f` is differentiated here because `model.prep` was prepared for `f`;
# presumably the preparation must be paired with that same function and
# backend type (confirm against the DifferentiationInterface docs).
function LogDensityProblems.logdensity_and_gradient(
    model::BUGSMooncakeModel, x::AbstractVector
)
    ad_backend = AutoMooncake(; config=nothing)
    return DifferentiationInterface.value_and_gradient(f, model.prep, ad_backend, x)
end
| 168 | + |
# The wrapper has the same dimension as the JuliaBUGS model it holds.
LogDensityProblems.dimension(model::BUGSMooncakeModel) =
    LogDensityProblems.dimension(model.model)
| 173 | + |
# Convert the AdvancedHMC transitions produced for a `BUGSMooncakeModel` into
# a `Chains` object via `JuliaBUGS.gen_chains`.
#
# Each sample contributes its position `z.θ` plus one row of statistics: the
# log density `z.ℓπ.value` (named `lp`) followed by the sampler statistics
# from `AdvancedHMC.stat`. Statistic names are taken from the first transition,
# so `ts` must be non-empty.
function AbstractMCMC.bundle_samples(
    ts::Vector{<:AdvancedHMC.Transition},
    logdensitymodel::AbstractMCMC.LogDensityModel{<:BUGSMooncakeModel},
    sampler::AdvancedHMC.AbstractHMCSampler,
    state,
    chain_type::Type{Chains};
    discard_initial=0,
    thinning=1,
    kwargs...,
)
    first_transition = ts[1]
    named_stats = merge(
        (; lp=first_transition.z.ℓπ.value), AdvancedHMC.stat(first_transition)
    )
    stats_names = collect(keys(named_stats))

    # Splatting into one array literal promotes all statistics to a common type.
    stats_values = map(ts) do t
        [t.z.ℓπ.value..., values(AdvancedHMC.stat(t))...]
    end

    positions = [t.z.θ for t in ts]
    bugs_model = logdensitymodel.logdensity.model

    return JuliaBUGS.gen_chains(
        bugs_model, positions, stats_names, stats_values; discard_initial, thinning, kwargs...
    )
end
| 201 | + |
# Advertise first-order capability: the wrapper provides the log density and its gradient.
LogDensityProblems.capabilities(::Type{<:BUGSMooncakeModel}) =
    LogDensityProblems.LogDensityOrder{1}()
| 206 | + |
| 207 | +# --- MCMC Sampling --- |
| 208 | + |
# Draw posterior samples with AdvancedHMC's No-U-Turn Sampler (NUTS).
samples_and_stats = let target = AbstractMCMC.LogDensityModel(bugsmooncake),
    nuts = AdvancedHMC.NUTS(0.65)

    AbstractMCMC.sample(
        target,               # wrapped model, in the form AbstractMCMC expects
        nuts,
        1000;                 # total number of samples
        chain_type=Chains,    # store results as an MCMCChains object
        n_adapts=500,         # number of adaptation steps for NUTS
        discard_initial=500,  # number of initial (warmup) samples to discard
    )
end
0 commit comments