@@ -47,6 +47,14 @@ function AbstractMCMC.sample(
47
47
callback = nothing ,
48
48
kwargs... ,
49
49
)
50
+ if haskey (kwargs, :nadapts )
51
+ throw (
52
+ ArgumentError (
53
+ " keyword argument `nadapts` is unsupported. Please use `n_adapts` to specify the number of adaptation steps." ,
54
+ ),
55
+ )
56
+ end
57
+
50
58
if callback === nothing
51
59
callback = HMCProgressCallback (N, progress = progress, verbose = verbose)
52
60
progress = false # don't use AMCMC's progress-funtionality
@@ -78,6 +86,13 @@ function AbstractMCMC.sample(
78
86
callback = nothing ,
79
87
kwargs... ,
80
88
)
89
+ if haskey (kwargs, :nadapts )
90
+ throw (
91
+ ArgumentError (
92
+ " keyword argument `nadapts` is unsupported. Please use `n_adapts` to specify the number of adaptation steps." ,
93
+ ),
94
+ )
95
+ end
81
96
82
97
if callback === nothing
83
98
callback = HMCProgressCallback (N, progress = progress, verbose = verbose)
@@ -141,8 +156,17 @@ function AbstractMCMC.step(
141
156
model:: AbstractMCMC.LogDensityModel ,
142
157
spl:: AbstractHMCSampler ,
143
158
state:: HMCState ;
159
+ n_adapts:: Int = 0 ,
144
160
kwargs... ,
145
161
)
162
+ if haskey (kwargs, :nadapts )
163
+ throw (
164
+ ArgumentError (
165
+ " keyword argument `nadapts` is unsupported. Please use `n_adapts` to specify the number of adaptation steps." ,
166
+ ),
167
+ )
168
+ end
169
+
146
170
# Compute transition.
147
171
i = state. i + 1
148
172
t_old = state. transition
@@ -158,7 +182,6 @@ function AbstractMCMC.step(
158
182
159
183
# Adapt h and spl.
160
184
tstat = stat (t)
161
- n_adapts = kwargs[:n_adapts ]
162
185
h, κ, isadapted = adapt! (h, κ, adaptor, i, n_adapts, t. z. θ, tstat. acceptance_rate)
163
186
tstat = merge (tstat, (is_adapt = isadapted,))
164
187
@@ -189,8 +212,8 @@ struct HMCProgressCallback{P}
189
212
" If `progress` is not specified and this is `true` some information will be logged upon completion of adaptation."
190
213
verbose:: Bool
191
214
" Number of divergent transitions fo far."
192
- num_divergent_transitions:: Ref {Int}
193
- num_divergent_transitions_during_adaption:: Ref {Int}
215
+ num_divergent_transitions:: Base.RefValue {Int}
216
+ num_divergent_transitions_during_adaption:: Base.RefValue {Int}
194
217
end
195
218
196
219
function HMCProgressCallback (n_samples; progress = true , verbose = false )
@@ -200,7 +223,16 @@ function HMCProgressCallback(n_samples; progress = true, verbose = false)
200
223
HMCProgressCallback (pm, progress, verbose, Ref (0 ), Ref (0 ))
201
224
end
202
225
203
- function (cb:: HMCProgressCallback )(rng, model, spl, t, state, i; nadapts = 0 , kwargs... )
226
+ function (cb:: HMCProgressCallback )(
227
+ rng,
228
+ model,
229
+ spl,
230
+ t,
231
+ state,
232
+ i;
233
+ n_adapts:: Int = 0 ,
234
+ kwargs... ,
235
+ )
204
236
progress = cb. progress
205
237
verbose = cb. verbose
206
238
pm = cb. pm
@@ -243,8 +275,8 @@ function (cb::HMCProgressCallback)(rng, model, spl, t, state, i; nadapts = 0, kw
243
275
),
244
276
)
245
277
# Report finish of adapation
246
- elseif verbose && isadapted && i == nadapts
247
- @info " Finished $nadapts adapation steps" adaptor κ. τ. integrator metric
278
+ elseif verbose && isadapted && i == n_adapts
279
+ @info " Finished $(n_adapts) adapation steps" adaptor κ. τ. integrator metric
248
280
end
249
281
end
250
282
0 commit comments