@@ -197,37 +197,48 @@ function getFitResult(hess, para, lp, optim_result, options::FitOptions, counts,
197197end
198198
199199"""
200- sample_model_epochs!(options::FitOptions, h::Histogram{T,1,E}; nsamples)
201-
202- Sample `nsamples` from the posterior distribution of the parameters of the model
203- of piece-wise constant epochs, given the observed histogram `h` and the fit
204- options `options`. See also [`FitOptions`](@ref) for how to specify the fit
205- options and `setinit!` to specify the initial parameters. Return a `Chains`
206- object from the `MCMCDiagnostics` module of `Turing`, which contains the samples
207- from the posterior distribution.
200+ sample_model_epochs!(options::FitOptions, h::Histogram{T,1,E}, init::AbstractVector{<:Real}; nsamples = 10_000, findmode = false)
201+
202+ Sample `nsamples` from the posterior distribution of the parameters, starting
203+ from initial point `init`.
204+
205+ Requires the observed histogram `h` and the fit options `options`.
206+ Return a `Chains` object from the `MCMCDiagnostics` module of `Turing`,
207+ which contains the samples from the posterior distribution.
208+ If `findmode` is true, the function will first find the mode of the
209+ posterior distribution using optimization, and then use that as the
210+ initial point for sampling. Otherwise, it will use the provided
211+ `init` as the initial point for sampling.
208212"""
209- function sample_model_epochs! (options:: FitOptions , h:: Histogram{T,1,E} ;
210- nsamples:: Int = 10_000
213+ function sample_model_epochs! (options:: FitOptions , h:: Histogram{T,1,E} ,
214+ init :: AbstractVector{<:Real} ; nsamples:: Int = 10_000 , findmode = false
211215) where {T<: Integer ,E<: Tuple{AbstractVector{<:Integer}} }
212- sample_model_epochs! (options, h. edges[1 ], h. weights, Val (isnaive (options)); nsamples)
216+ @assert length (init)% 2 == 0 " initial parameters should be of length 2*nepochs"
217+ setnepochs! (options, length (init)÷ 2 )
218+ setinit! (options, init)
219+ sample_model_epochs! (options, h. edges[1 ], h. weights, Val (isnaive (options)); nsamples, findmode)
213220end
214221
215222function sample_model_epochs! (
216223 options:: FitOptions , edges:: AbstractVector{<:Integer} , counts:: AbstractVector{<:Integer} ,
217224 :: Val{true} ;
218- nsamples:: Int = 10_000
225+ nsamples:: Int = 10_000 , findmode = false
219226)
220227 # get a good initial guess
221228 iszero (options. init) && initialize! (options, counts)
222229
223230 model = model_epochs (edges, counts, options. mu, options. locut, options. prior)
224231 logger = ConsoleLogger (stdout , Logging. Error)
225- mle = with_logger (logger) do
226- Turing. Optimisation. estimate_mode (
227- model, MLE (), options. solver; initial_params= options. init, options. opt...
228- )
232+ if findmode
233+ mle = with_logger (logger) do
234+ Turing. Optimisation. estimate_mode (
235+ model, MLE (), options. solver; initial_params= options. init, options. opt...
236+ )
237+ end
238+ setinit! (options, mle. values)
229239 end
230- init_ = InitFromParams (mle)
240+ pars_ = Dict (DynamicPPL. VarName {:TN} () => options. init)
241+ init_ = InitFromParams (pars_)
231242 chain = with_logger (logger) do
232243 sample (model, NUTS (), nsamples; initial_params= init_)
233244 end
0 commit comments