Skip to content

Commit dc39702

Browse files
format
1 parent 726b68a commit dc39702

File tree

1 file changed

+14
-17
lines changed

1 file changed

+14
-17
lines changed

ext/TuringCallbacksJuliaBUGSExt.jl

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -25,37 +25,34 @@ handling both scalar and vector parameters correctly.
2525
function TuringCallbacks.params_and_values(
2626
model::AbstractMCMC.LogDensityModel{<:BUGSModel},
2727
transition::AdvancedHMC.Transition;
28-
kwargs...
28+
kwargs...,
2929
)
3030
bugs_model = model.logdensity
3131
gd = bugs_model.graph_evaluation_data
3232
param_names = gd.sorted_parameters
3333
param_values = transition.z.θ
34-
35-
# Build pairs of (name, value) by mapping the flattened vector back to parameters
36-
pairs = Tuple{String, Float64}[]
34+
35+
pairs = Tuple{String,Float64}[]
3736
pos = 1
38-
37+
3938
for vn in param_names
4039
len = if bugs_model.transformed
4140
bugs_model.transformed_var_lengths[vn]
4241
else
4342
bugs_model.untransformed_var_lengths[vn]
4443
end
45-
44+
4645
if len == 1
47-
# Scalar parameter
4846
push!(pairs, (string(vn), param_values[pos]))
4947
pos += 1
5048
else
51-
# Vector/array parameter - log each element individually
52-
for i in 1:len
49+
for i = 1:len
5350
push!(pairs, (string(vn) * "[$i]", param_values[pos]))
5451
pos += 1
5552
end
5653
end
5754
end
58-
55+
5956
return pairs
6057
end
6158

@@ -70,17 +67,17 @@ step size, tree depth, etc.
7067
function TuringCallbacks.extras(
7168
model::AbstractMCMC.LogDensityModel{<:BUGSModel},
7269
transition::AdvancedHMC.Transition;
73-
kwargs...
70+
kwargs...,
7471
)
7572
# Extract HMC statistics from transition
7673
stats = AdvancedHMC.stat(transition)
7774
names = collect(keys(stats))
7875
vals = collect(values(stats))
79-
76+
8077
# Add log probability at the front
8178
pushfirst!(names, :lp)
8279
pushfirst!(vals, transition.z.ℓπ.value)
83-
80+
8481
return zip(string.(names), vals)
8582
end
8683

@@ -92,12 +89,12 @@ Extract hyperparameters from a NUTS sampler used with JuliaBUGS models.
9289
function TuringCallbacks.hyperparams(
9390
model::AbstractMCMC.LogDensityModel{<:BUGSModel},
9491
sampler::AdvancedHMC.NUTS;
95-
kwargs...
92+
kwargs...,
9693
)
9794
return [
9895
"target_acceptance" => sampler.δ,
9996
"max_depth" => sampler.max_depth,
100-
"Δ_max" => sampler.Δ_max
97+
"Δ_max" => sampler.Δ_max,
10198
]
10299
end
103100

@@ -108,14 +105,14 @@ Return metric names to track for NUTS hyperparameters with JuliaBUGS models.
108105
"""
109106
function TuringCallbacks.hyperparam_metrics(
110107
model::AbstractMCMC.LogDensityModel{<:BUGSModel},
111-
sampler::AdvancedHMC.NUTS
108+
sampler::AdvancedHMC.NUTS,
112109
)
113110
return [
114111
"extras/acceptance_rate/stat/Mean",
115112
"extras/max_hamiltonian_energy_error/stat/Mean",
116113
"extras/lp/stat/Mean",
117114
"extras/n_steps/stat/Mean",
118-
"extras/tree_depth/stat/Mean"
115+
"extras/tree_depth/stat/Mean",
119116
]
120117
end
121118

0 commit comments

Comments
 (0)