Skip to content

Commit 52316db

Browse files
Add some inter-op examples (#249)
partially address #245 --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent d034d65 commit 52316db

File tree

6 files changed

+491
-0
lines changed

6 files changed

+491
-0
lines changed

examples/.JuliaFormatter.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
style = "blue"
2+
always_use_return = false

examples/Project.toml

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
[deps]
2+
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
3+
AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918"
4+
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
5+
AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
6+
DifferentialEquations = "0c46a032-eb83-5123-abaf-570d42b7fbaa"
7+
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
8+
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
9+
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
10+
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
11+
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
12+
JuliaBUGS = "ba9fb4c0-828e-4473-b6a1-cd2560fee5bf"
13+
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
14+
LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
15+
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
16+
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
17+
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
18+
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
19+
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
20+
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"

examples/README.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# JuliaBUGS Examples
2+
3+
We adapted some examples to show how to use JuliaBUGS in this repo.
4+
5+
## Sources
6+
7+
* SIR: https://github.com/TuringLang/Turing-Workshop/tree/main/2023-MRC-BSU-and-UKHSA/Part-2-More-Julia-and-some-Bayesian-inference
8+
* GP: https://turinglang.org/docs/tutorials/gaussian-processes-introduction/
9+
* BNN: https://turinglang.org/docs/tutorials/bayesian-neural-networks/

examples/bnn.jl

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
using JuliaBUGS
2+
3+
using AbstractMCMC
4+
using ADTypes
5+
using AdvancedHMC
6+
using DifferentiationInterface
7+
using FillArrays
8+
using Functors
9+
using LinearAlgebra
10+
using LogDensityProblems
11+
using LogDensityProblemsAD
12+
using Lux
13+
using MCMCChains
14+
using Mooncake
15+
using Random
16+
17+
## data simulation
18+
19+
# Number of points to generate
20+
N = 80
21+
M = round(Int, N / 4)
22+
rng = Random.default_rng()
23+
Random.seed!(rng, 1234)
24+
25+
# Generate artificial data
26+
x1s = rand(rng, Float32, M) * 4.5f0;
27+
x2s = rand(rng, Float32, M) * 4.5f0;
28+
xt1s = Array([[x1s[i] + 0.5f0; x2s[i] + 0.5f0] for i in 1:M])
29+
x1s = rand(rng, Float32, M) * 4.5f0;
30+
x2s = rand(rng, Float32, M) * 4.5f0;
31+
append!(xt1s, Array([[x1s[i] - 5.0f0; x2s[i] - 5.0f0] for i in 1:M]))
32+
33+
x1s = rand(rng, Float32, M) * 4.5f0;
34+
x2s = rand(rng, Float32, M) * 4.5f0;
35+
xt0s = Array([[x1s[i] + 0.5f0; x2s[i] - 5.0f0] for i in 1:M])
36+
x1s = rand(rng, Float32, M) * 4.5f0;
37+
x2s = rand(rng, Float32, M) * 4.5f0;
38+
append!(xt0s, Array([[x1s[i] - 5.0f0; x2s[i] + 0.5f0] for i in 1:M]))
39+
40+
# Store all the data for later
41+
xs = [xt1s; xt0s]
42+
xs_hcat = Float64.(reduce(hcat, xs))
43+
ts = [ones(2 * M); zeros(2 * M)]
44+
45+
alpha = 0.09
46+
sigma = sqrt(1.0 / alpha)
47+
48+
##
49+
50+
# Construct a neural network using Lux
51+
nn_initial = Chain(Dense(2 => 3, tanh), Dense(3 => 2, tanh), Dense(2 => 1, σ))
52+
53+
# Initialize the model weights and state
54+
ps, st = Lux.setup(rng, nn_initial)
55+
56+
Lux.parameterlength(nn_initial) # number of parameters in NN
57+
58+
function vector_to_parameters(ps_new::AbstractVector, ps::NamedTuple)
59+
@assert length(ps_new) == Lux.parameterlength(ps)
60+
i = 1
61+
function get_ps(x)
62+
z = reshape(view(ps_new, i:(i + length(x) - 1)), size(x))
63+
i += length(x)
64+
return z
65+
end
66+
return fmap(get_ps, ps)
67+
end
68+
69+
const nn = StatefulLuxLayer{true}(nn_initial, nothing, st)
70+
71+
model_def = @bugs begin
72+
parameters[1:nparameters] ~ parameter_distribution(nparameters, sigma)
73+
predictions[1:N] = make_prediction(parameters[1:nparameters], xs[:, :])
74+
for i in 1:N
75+
ts[i] ~ Bernoulli(predictions[i])
76+
end
77+
end
78+
79+
JuliaBUGS.@register_primitive function parameter_distribution(nparameters, sigma)
80+
return MvNormal(zeros(nparameters), Diagonal(abs2.(sigma .* ones(nparameters))))
81+
end
82+
83+
JuliaBUGS.@register_primitive function make_prediction(parameters, xs; ps=ps, nn=nn)
84+
return Lux.apply(nn, f32(xs), f32(vector_to_parameters(parameters, ps)))
85+
end
86+
87+
@eval JuliaBUGS begin
88+
ps = Main.ps
89+
nn = Main.nn
90+
Lux = Main.Lux
91+
f32 = Main.f32
92+
vector_to_parameters = Main.vector_to_parameters
93+
end
94+
95+
data = (nparameters=Lux.parameterlength(nn), xs=xs_hcat, ts=ts, N=length(ts), sigma=sigma)
96+
97+
model = compile(model_def, data)
98+
99+
ad_model = ADgradient(AutoMooncake(; config=Mooncake.Config()), model)
100+
101+
# sampling is slow, so sample 10 of them to verify that this can work
102+
samples_and_stats = AbstractMCMC.sample(
103+
ad_model,
104+
NUTS(0.65),
105+
10;
106+
chain_type=Chains,
107+
# n_adapts=1000,
108+
# discard_initial=1000
109+
)

examples/gp.jl

Lines changed: 217 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,217 @@
1+
# Example demonstrating the use of Gaussian Processes (GPs) within JuliaBUGS
2+
# for modeling golf putting accuracy based on distance.
3+
# This example uses AbstractGPs.jl for the GP implementation and AdvancedHMC.jl
4+
# for sampling from the posterior distribution.
5+
6+
using JuliaBUGS
7+
using JuliaBUGS: @model
8+
9+
# Required packages for GP modeling and MCMC
10+
using AbstractGPs, Distributions, LogExpFunctions
11+
using LogDensityProblems, LogDensityProblemsAD
12+
using AbstractMCMC, AdvancedHMC, MCMCChains
13+
14+
# Differentiation backend
15+
using DifferentiationInterface
16+
using Mooncake: Mooncake
17+
18+
# --- Data Definition ---
19+
20+
# Golf putting data from Gelman et al. (BDA3, Chapter 5)
21+
golf_data = (
22+
distance=[2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20], # Distance in feet
23+
n=[ # Number of putts attempted
24+
1443,
25+
694,
26+
455,
27+
353,
28+
272,
29+
256,
30+
240,
31+
217,
32+
200,
33+
237,
34+
202,
35+
192,
36+
174,
37+
167,
38+
201,
39+
195,
40+
191,
41+
147,
42+
152,
43+
],
44+
y=[ # Number of successful putts
45+
1346,
46+
577,
47+
337,
48+
208,
49+
149,
50+
136,
51+
111,
52+
69,
53+
67,
54+
75,
55+
52,
56+
46,
57+
54,
58+
28,
59+
27,
60+
31,
61+
33,
62+
20,
63+
24,
64+
],
65+
)
66+
67+
# Prepare data in the format expected by the BUGS model
68+
data = (
69+
d=golf_data.distance,
70+
n=golf_data.n,
71+
y=golf_data.y,
72+
jitter=1e-6, # Small value added to GP kernel diagonal for numerical stability
73+
N=length(golf_data.distance),
74+
)
75+
76+
# --- BUGS Model Definition ---
77+
78+
@model function gp_golf_putting((; v, l, f_latent, y), N, n, d, jitter)
79+
# Priors for GP hyperparameters
80+
v ~ Distributions.Gamma(2, 1) # Variance
81+
l ~ Distributions.Gamma(4, 1) # Lengthscale
82+
83+
# Latent GP function values
84+
# f_latent represents the underlying putting success probability (on logit scale)
85+
# modeled by a GP.
86+
f_latent[1:N] ~ gp_predict(v, l, d[1:N], jitter)
87+
88+
# Likelihood: Binomial distribution for observed successes
89+
# The success probability for each distance is the logistic transformation of the latent GP value.
90+
y[1:N] ~ y_distribution(n[1:N], f_latent[1:N])
91+
end
92+
93+
# --- Custom Primitive Definitions for BUGS ---
94+
95+
# Register the GP kernel type with JuliaBUGS
96+
# This allows using AbstractGPs types directly in the model definition.
97+
JuliaBUGS.@register_primitive GP with_lengthscale SEKernel
98+
99+
# Define a function callable within the BUGS model to compute GP predictions.
100+
# BUGS requires functions to operate on basic numerical types, so this wraps the GP call.
101+
JuliaBUGS.@register_primitive function gp_predict(v, l, d, jitter)
102+
# Create a GP with a Squared Exponential kernel using the provided hyperparameters
103+
kernel = v * with_lengthscale(SEKernel(), l)
104+
gp = GP(kernel)
105+
# Return the distribution representing the GP evaluated at distances `d` with jitter
106+
return gp(d, jitter)
107+
end
108+
109+
# Define a function for the observation model (likelihood).
110+
# This creates a product distribution of Binomials, one for each distance.
111+
JuliaBUGS.@register_primitive function y_distribution(n, f_latent)
112+
return product_distribution(Binomial.(n, logistic.(f_latent)))
113+
end
114+
115+
# --- Model Instantiation ---
116+
117+
# Create the JuliaBUGS model instance
118+
# Provide initial values (missing for parameters to be inferred) and observed data
119+
model = gp_golf_putting(
120+
(; v=missing, l=missing, f_latent=fill(missing, data.N), y=data.y),
121+
data.N, # Number of observations
122+
data.n, # Observed attempts
123+
data.d, # Observed distances
124+
data.jitter, # Numerical stability term
125+
)
126+
127+
# Optionally, set the evaluation mode. Using generated functions can be faster.
128+
# model = JuliaBUGS.set_evaluation_mode(model, JuliaBUGS.UseGeneratedLogDensityFunction())
129+
130+
# --- MCMC Setup with Custom LogDensityProblems Wrapper ---
131+
132+
# We need a wrapper around the JuliaBUGS model to interface with LogDensityProblems
133+
# and utilize automatic differentiation (AD) via Mooncake.jl for gradient computation,
134+
# which is required by AdvancedHMC.
135+
136+
struct BUGSMooncakeModel{T,P}
137+
model::T # The JuliaBUGS model
138+
prep::P # Pre-allocated workspace for gradient computation using Mooncake
139+
end
140+
141+
# Define the function to compute the log density using the JuliaBUGS model's internal function
142+
f(x) = model.log_density_computation_function(model.evaluation_env, x)
143+
144+
# Prepare the differentiation backend (Mooncake)
145+
backend = AutoMooncake(; config=nothing)
146+
x_init = rand(LogDensityProblems.dimension(model)) # Initial point for testing/preparation
147+
prep = prepare_gradient(f, backend, x_init)
148+
149+
# Create the wrapped model instance
150+
bugsmooncake = BUGSMooncakeModel(model, prep)
151+
152+
# --- LogDensityProblems Interface Implementation for the Wrapper ---
153+
154+
# Define logdensity function for the wrapper
155+
function LogDensityProblems.logdensity(model::BUGSMooncakeModel, x::AbstractVector)
156+
return f(x) # Calls the underlying JuliaBUGS log density function
157+
end
158+
159+
# Define logdensity_and_gradient function using the prepared DifferentiationInterface setup
160+
function LogDensityProblems.logdensity_and_gradient(
161+
model::BUGSMooncakeModel, x::AbstractVector
162+
)
163+
# Computes both the log density and its gradient using Mooncake AD
164+
return DifferentiationInterface.value_and_gradient(
165+
f, model.prep, AutoMooncake(; config=nothing), x
166+
)
167+
end
168+
169+
# Define dimension function
170+
function LogDensityProblems.dimension(model::BUGSMooncakeModel)
171+
return LogDensityProblems.dimension(model.model) # Delegates to the original model
172+
end
173+
174+
# Define a custom bundle_samples function to convert the AdvancedHMC.Transition to a Chains object
175+
function AbstractMCMC.bundle_samples(
176+
ts::Vector{<:AdvancedHMC.Transition},
177+
logdensitymodel::AbstractMCMC.LogDensityModel{<:BUGSMooncakeModel},
178+
sampler::AdvancedHMC.AbstractHMCSampler,
179+
state,
180+
chain_type::Type{Chains};
181+
discard_initial=0,
182+
thinning=1,
183+
kwargs...,
184+
)
185+
stats_names = collect(keys(merge((; lp=ts[1].z.ℓπ.value), AdvancedHMC.stat(ts[1]))))
186+
stats_values = [
187+
vcat([ts[i].z.ℓπ.value..., collect(values(AdvancedHMC.stat(ts[i])))...]) for
188+
i in eachindex(ts)
189+
]
190+
191+
return JuliaBUGS.gen_chains(
192+
logdensitymodel.logdensity.model,
193+
[t.z.θ for t in ts],
194+
stats_names,
195+
stats_values;
196+
discard_initial=discard_initial,
197+
thinning=thinning,
198+
kwargs...,
199+
)
200+
end
201+
202+
# Specify capabilities (indicates gradient availability)
203+
function LogDensityProblems.capabilities(::Type{<:BUGSMooncakeModel})
204+
return LogDensityProblems.LogDensityOrder{1}() # Can compute up to the gradient
205+
end
206+
207+
# --- MCMC Sampling ---
208+
209+
# Sample from the posterior distribution using AdvancedHMC's NUTS sampler
210+
samples_and_stats = AbstractMCMC.sample(
211+
AbstractMCMC.LogDensityModel(bugsmooncake), # Wrap the model for AbstractMCMC
212+
AdvancedHMC.NUTS(0.65), # No-U-Turn Sampler
213+
1000; # Total number of samples
214+
chain_type=Chains, # Store results as MCMCChains object
215+
n_adapts=500, # Number of adaptation steps for NUTS
216+
discard_initial=500, # Number of initial samples (warmup) to discard;
217+
)

0 commit comments

Comments
 (0)