Skip to content

Commit f97cb2b

Browse files
authored
Simplify callback interface
1 parent 3823ea9 commit f97cb2b

File tree

1 file changed

+9
-9
lines changed

1 file changed

+9
-9
lines changed

src/sample.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ Return `N` samples from the MCMC `sampler` for the provided `model`.
4949
5050
A callback function `f` with type signature
5151
```julia
52-
f(rng, model, sampler, N, iteration, transition; kwargs...)
52+
f(rng, model, sampler, transition, iteration)
5353
```
5454
may be provided as keyword argument `callback`. It is called after every sampling step.
5555
"""
@@ -60,7 +60,7 @@ function mcmcsample(
6060
N::Integer;
6161
progress = true,
6262
progressname = "Sampling",
63-
callback = (args...; kwargs...) -> nothing,
63+
callback = (args...) -> nothing,
6464
chain_type::Type=Any,
6565
kwargs...
6666
)
@@ -75,7 +75,7 @@ function mcmcsample(
7575
transition = step!(rng, model, sampler, N; iteration=1, kwargs...)
7676

7777
# Run callback.
78-
callback(rng, model, sampler, N, 1, transition; kwargs...)
78+
callback(rng, model, sampler, transition, 1)
7979

8080
# Save the transition.
8181
transitions = transitions_init(transition, model, sampler, N; kwargs...)
@@ -90,7 +90,7 @@ function mcmcsample(
9090
transition = step!(rng, model, sampler, N, transition; iteration=i, kwargs...)
9191

9292
# Run callback.
93-
callback(rng, model, sampler, N, i, transition; kwargs...)
93+
callback(rng, model, sampler, transition, i)
9494

9595
# Save the transition.
9696
transitions_save!(transitions, i, transition, model, sampler, N; kwargs...)
@@ -119,7 +119,7 @@ and should return `true` when sampling should end, and `false` otherwise.
119119
120120
A callback function `f` with type signature
121121
```julia
122-
f(rng, model, sampler, N, iteration, transition; kwargs...)
122+
f(rng, model, sampler, transition, iteration)
123123
```
124124
may be provided as keyword argument `callback`. It is called after every sampling step.
125125
"""
@@ -131,7 +131,7 @@ function mcmcsample(
131131
chain_type::Type=Any,
132132
progress = true,
133133
progressname = "Convergence sampling",
134-
callback = (args...; kwargs...) -> nothing,
134+
callback = (args...) -> nothing,
135135
kwargs...
136136
)
137137
# Perform any necessary setup.
@@ -142,7 +142,7 @@ function mcmcsample(
142142
transition = step!(rng, model, sampler, 1; iteration=1, kwargs...)
143143

144144
# Run callback.
145-
callback(rng, model, sampler, 1, 1, transition; kwargs...)
145+
callback(rng, model, sampler, transition, 1)
146146

147147
# Save the transition.
148148
transitions = transitions_init(transition, model, sampler; kwargs...)
@@ -155,7 +155,7 @@ function mcmcsample(
155155
transition = step!(rng, model, sampler, 1, transition; iteration=i, kwargs...)
156156

157157
# Run callback.
158-
callback(rng, model, sampler, 1, i, transition; kwargs...)
158+
callback(rng, model, sampler, transition, i)
159159

160160
# Save the transition.
161161
transitions_save!(transitions, i, transition, model, sampler; kwargs...)
@@ -244,4 +244,4 @@ function mcmcpsample(
244244

245245
# Concatenate the chains together.
246246
return reduce(chainscat, chains)
247-
end
247+
end

0 commit comments

Comments
 (0)