Skip to content

Commit e4e412c

Browse files
Merge pull request #71 from ArndtLab/mcmc
fix HMC initialization bug
2 parents 3966deb + 3e9f003 commit e4e412c

File tree

2 files changed

+29
-19
lines changed

2 files changed

+29
-19
lines changed

src/mle_optimization.jl

Lines changed: 28 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -197,37 +197,48 @@ function getFitResult(hess, para, lp, optim_result, options::FitOptions, counts,
197197
end
198198

199199
"""
200-
sample_model_epochs!(options::FitOptions, h::Histogram{T,1,E}; nsamples)
201-
202-
Sample `nsamples` from the posterior distribution of the parameters of the model
203-
of piece-wise constant epochs, given the observed histogram `h` and the fit
204-
options `options`. See also [`FitOptions`](@ref) for how to specify the fit
205-
options and `setinit!` to specify the initial parameters. Return a `Chains`
206-
object from the `MCMCDiagnostics` module of `Turing`, which contains the samples
207-
from the posterior distribution.
200+
sample_model_epochs!(options::FitOptions, h::Histogram{T,1,E}, init::AbstractVector{<:Real}; nsamples = 10_000, findmode = false)
201+
202+
Sample `nsamples` from the posterior distribution of the parameters, starting
203+
from initial point `init`.
204+
205+
Requires the observed histogram `h` and the fit options `options`.
206+
Return a `Chains` object from the `MCMCDiagnostics` module of `Turing`,
207+
which contains the samples from the posterior distribution.
208+
If `findmode` is true, the function will first find the mode of the
209+
posterior distribution using optimization, and then use that as the
210+
initial point for sampling. Otherwise, it will use the provided
211+
`init` as the initial point for sampling.
208212
"""
209-
function sample_model_epochs!(options::FitOptions, h::Histogram{T,1,E};
210-
nsamples::Int=10_000
213+
function sample_model_epochs!(options::FitOptions, h::Histogram{T,1,E},
214+
init::AbstractVector{<:Real}; nsamples::Int=10_000, findmode = false
211215
) where {T<:Integer,E<:Tuple{AbstractVector{<:Integer}}}
212-
sample_model_epochs!(options, h.edges[1], h.weights, Val(isnaive(options)); nsamples)
216+
@assert length(init)%2 == 0 "initial parameters should be of length 2*nepochs"
217+
setnepochs!(options, length(init)÷2)
218+
setinit!(options, init)
219+
sample_model_epochs!(options, h.edges[1], h.weights, Val(isnaive(options)); nsamples, findmode)
213220
end
214221

215222
function sample_model_epochs!(
216223
options::FitOptions, edges::AbstractVector{<:Integer}, counts::AbstractVector{<:Integer},
217224
::Val{true};
218-
nsamples::Int=10_000
225+
nsamples::Int=10_000, findmode = false
219226
)
220227
# get a good initial guess
221228
iszero(options.init) && initialize!(options, counts)
222229

223230
model = model_epochs(edges, counts, options.mu, options.locut, options.prior)
224231
logger = ConsoleLogger(stdout, Logging.Error)
225-
mle = with_logger(logger) do
226-
Turing.Optimisation.estimate_mode(
227-
model, MLE(), options.solver; initial_params=options.init, options.opt...
228-
)
232+
if findmode
233+
mle = with_logger(logger) do
234+
Turing.Optimisation.estimate_mode(
235+
model, MLE(), options.solver; initial_params=options.init, options.opt...
236+
)
237+
end
238+
setinit!(options, mle.values)
229239
end
230-
init_ = InitFromParams(mle)
240+
pars_ = Dict(DynamicPPL.VarName{:TN}() => options.init)
241+
init_ = InitFromParams(pars_)
231242
chain = with_logger(logger) do
232243
sample(model, NUTS(), nsamples; initial_params=init_)
233244
end

test/runtests.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -132,8 +132,7 @@ end
132132
@test !any(best.opt.at_lboundary)
133133
@test !any(best.opt.at_uboundary[2:end])
134134
fcor = correctestimate!(fop, best, h)
135-
setinit!(fop, get_para(best))
136-
chain = sample_model_epochs!(fop, h; nsamples = 10)
135+
chain = sample_model_epochs!(fop, h, get_para(best); nsamples = 10, findmode = true)
137136

138137
resid = compute_residuals(h, mu, rho, TN)
139138
@test !any(isnan.(resid))

0 commit comments

Comments
 (0)