|
563 | 563 | function AHMCAdaptor(::Hamiltonian, ::AHMC.AbstractMetric; kwargs...)
|
564 | 564 | return AHMC.Adaptation.NoAdaptation()
|
565 | 565 | end
|
566 |
| - |
567 |
| -########################## |
568 |
| -# HMC State Constructors # |
569 |
| -########################## |
570 |
| - |
571 |
| -function HMCState( |
572 |
| - rng::AbstractRNG, |
573 |
| - model::Model, |
574 |
| - spl::Sampler{<:Hamiltonian}, |
575 |
| - vi::AbstractVarInfo; |
576 |
| - kwargs..., |
577 |
| -) |
578 |
| - # Link everything if needed. |
579 |
| - waslinked = islinked(vi, spl) |
580 |
| - if !waslinked |
581 |
| - vi = link!!(vi, spl, model) |
582 |
| - end |
583 |
| - |
584 |
| - # Get the initial log pdf and gradient functions. |
585 |
| - ∂logπ∂θ = gen_∂logπ∂θ(vi, spl, model) |
586 |
| - logπ = Turing.LogDensityFunction( |
587 |
| - vi, |
588 |
| - model, |
589 |
| - DynamicPPL.SamplingContext(rng, spl, DynamicPPL.leafcontext(model.context)), |
590 |
| - ) |
591 |
| - |
592 |
| - # Get the metric type. |
593 |
| - metricT = getmetricT(spl.alg) |
594 |
| - |
595 |
| - # Create a Hamiltonian. |
596 |
| - θ_init = Vector{Float64}(spl.state.vi[spl]) |
597 |
| - metric = metricT(length(θ_init)) |
598 |
| - h = AHMC.Hamiltonian(metric, logπ, ∂logπ∂θ) |
599 |
| - |
600 |
| - # Find good eps if not provided one |
601 |
| - if iszero(spl.alg.ϵ) |
602 |
| - ϵ = AHMC.find_good_stepsize(rng, h, θ_init) |
603 |
| - @info "Found initial step size" ϵ |
604 |
| - else |
605 |
| - ϵ = spl.alg.ϵ |
606 |
| - end |
607 |
| - |
608 |
| - # Generate a kernel. |
609 |
| - kernel = make_ahmc_kernel(spl.alg, ϵ) |
610 |
| - |
611 |
| - # Generate a phasepoint. Replaced during sample_init! |
612 |
| - h, t = AHMC.sample_init(rng, h, θ_init) # this also ensure AHMC has the same dim as θ. |
613 |
| - |
614 |
| - # Unlink everything, if it was indeed linked before. |
615 |
| - if waslinked |
616 |
| - vi = invlink!!(vi, spl, model) |
617 |
| - end |
618 |
| - |
619 |
| - return HMCState(vi, 0, 0, kernel.τ, h, AHMCAdaptor(spl.alg, metric; ϵ=ϵ), t.z) |
620 |
| -end |
0 commit comments