@@ -25,37 +25,34 @@ handling both scalar and vector parameters correctly.
2525function TuringCallbacks. params_and_values(
2626 model:: AbstractMCMC.LogDensityModel{<:BUGSModel} ,
2727 transition:: AdvancedHMC.Transition ;
28- kwargs...
28+ kwargs... ,
2929)
3030 bugs_model = model. logdensity
3131 gd = bugs_model. graph_evaluation_data
3232 param_names = gd. sorted_parameters
3333 param_values = transition. z. θ
34-
35- # Build pairs of (name, value) by mapping the flattened vector back to parameters
36- pairs = Tuple{String, Float64}[]
34+
35+ pairs = Tuple{String,Float64}[]
3736 pos = 1
38-
37+
3938 for vn in param_names
4039 len = if bugs_model. transformed
4140 bugs_model. transformed_var_lengths[vn]
4241 else
4342 bugs_model. untransformed_var_lengths[vn]
4443 end
45-
44+
4645 if len == 1
47- # Scalar parameter
4846 push!(pairs, (string(vn), param_values[pos]))
4947 pos += 1
5048 else
51- # Vector/array parameter - log each element individually
52- for i in 1 : len
49+ for i = 1 : len
5350 push!(pairs, (string(vn) * " [$i ]" , param_values[pos]))
5451 pos += 1
5552 end
5653 end
5754 end
58-
55+
5956 return pairs
6057end
6158
@@ -70,17 +67,17 @@ step size, tree depth, etc.
7067function TuringCallbacks. extras(
7168 model:: AbstractMCMC.LogDensityModel{<:BUGSModel} ,
7269 transition:: AdvancedHMC.Transition ;
73- kwargs...
70+ kwargs... ,
7471)
7572 # Extract HMC statistics from transition
7673 stats = AdvancedHMC. stat(transition)
7774 names = collect(keys(stats))
7875 vals = collect(values(stats))
79-
76+
8077 # Add log probability at the front
8178 pushfirst!(names, :lp)
8279 pushfirst!(vals, transition. z. ℓπ. value)
83-
80+
8481 return zip(string.(names), vals)
8582end
8683
@@ -92,12 +89,12 @@ Extract hyperparameters from a NUTS sampler used with JuliaBUGS models.
9289function TuringCallbacks. hyperparams(
9390 model:: AbstractMCMC.LogDensityModel{<:BUGSModel} ,
9491 sampler:: AdvancedHMC.NUTS ;
95- kwargs...
92+ kwargs... ,
9693)
9794 return [
9895 " target_acceptance" => sampler. δ,
9996 " max_depth" => sampler. max_depth,
100- " Δ_max" => sampler. Δ_max
97+ " Δ_max" => sampler. Δ_max,
10198 ]
10299end
103100
@@ -108,14 +105,14 @@ Return metric names to track for NUTS hyperparameters with JuliaBUGS models.
108105"""
109106function TuringCallbacks. hyperparam_metrics(
110107 model:: AbstractMCMC.LogDensityModel{<:BUGSModel} ,
111- sampler:: AdvancedHMC.NUTS
108+ sampler:: AdvancedHMC.NUTS ,
112109)
113110 return [
114111 " extras/acceptance_rate/stat/Mean" ,
115112 " extras/max_hamiltonian_energy_error/stat/Mean" ,
116113 " extras/lp/stat/Mean" ,
117114 " extras/n_steps/stat/Mean" ,
118- " extras/tree_depth/stat/Mean"
115+ " extras/tree_depth/stat/Mean" ,
119116 ]
120117end
121118
0 commit comments