Merged
65 changes: 42 additions & 23 deletions src/adaptation/stepsize.jl
@@ -8,7 +8,7 @@ mutable struct DAState{T<:AbstractScalarOrVec{<:AbstractFloat}}
H_bar::T
end

-computeμ(ϵ::AbstractScalarOrVec{<:AbstractFloat}) = log.(10 * ϵ)
+computeμ(ϵ::AbstractFloat) = log(10 * ϵ)
Member:
Caution is required here: these support the vectorised version of HMC. Do you know how map would differ from broadcasting here?

Member Author:
The results of the calculations won't be affected by this change, but using the non-broadcasted formulation for scalars and map for vectors of floats will remove the broadcasting overhead and reduce stress on the compiler, i.e., generally it reduces compilation time. Sometimes it also helps type inference (but this case is too simple for this effect I assume).

In my experience, broadcasting is useful when one is actually broadcasting over values of different sizes and dimensions, but otherwise it is often a suboptimal choice.
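
The scalar/vector split the author describes can be sketched as follows (a minimal illustration, not the package's actual code; the vector method name and signature are assumptions for the example):

```julia
# Scalar method: plain function call, no broadcast machinery involved.
computeμ(ϵ::AbstractFloat) = log(10 * ϵ)

# Vector method: `map` applies the scalar method element-wise without
# constructing a `Broadcasted` object, which is cheaper to compile.
computeμ(ϵ::AbstractVector{<:AbstractFloat}) = map(computeμ, ϵ)

computeμ(0.1)          # scalar path
computeμ([0.1, 0.2])   # vector path via map; one allocation for the result
```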


function DAState(ϵ::T) where {T}
μ = computeμ(ϵ)
@@ -17,22 +17,34 @@ end

function DAState(ϵ::AbstractVector{T}) where {T}
n = length(ϵ)
-    μ = computeμ(ϵ)
+    μ = map(computeμ, ϵ)
Member:
Suggested change:
-    μ = map(computeμ, ϵ)
+    μ = computeμ(ϵ)

return DAState(0, ϵ, μ, zeros(T, n), zeros(T, n))
end

function reset!(das::DAState{T}) where {T<:AbstractFloat}
das.m = 0
das.μ = computeμ(das.ϵ)
das.x_bar = zero(T)
-    return das.H_bar = zero(T)
+    das.H_bar = zero(T)
+    return das
end

function reset!(das::DAState{<:AbstractVector{T}}) where {T<:AbstractFloat}
das.m = 0
-    das.μ .= computeμ(das.ϵ)
-    das.x_bar .= zero(T)
-    return das.H_bar .= zero(T)
+    map!(computeμ, das.μ, das.ϵ)
Member:
Let's keep this as-is for now. We could refactor the vectorised HMC interface, but better to do it separately in a concerted effort:

Suggested change:
-    map!(computeμ, das.μ, das.ϵ)
+    das.μ .= computeμ(das.ϵ)

Member Author:

This suggestion would go against the main intention of the PR, which is reducing unnecessary allocations: with map! (or das.μ .= computeμ.(das.ϵ), though broadcasting is more stressful for the compiler) no intermediate array is created in this line, whereas with the suggested right-hand side a new array is allocated and then copied into das.μ (as a side remark, copyto! should be simpler for the compiler than broadcasting).
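
The allocation difference the author describes can be sketched like this (a standalone illustration with made-up data, not the package's code):

```julia
ϵ = [0.1, 0.2, 0.3]

# In-place: `map!` writes each result directly into μ; no temporary array.
μ = similar(ϵ)
map!(x -> log(10 * x), μ, ϵ)

# Allocating alternative (the suggestion's right-hand side): the call on
# the right materialises a fresh array, which `.=` then copies into μ2.
μ2 = similar(ϵ)
μ2 .= map(x -> log(10 * x), ϵ)

μ == μ2  # same values either way; only the allocation behaviour differs
```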

Member (@yebai, Mar 26, 2025):
Fair point!

I recently discovered the AcceleratedKernels package, which provides a unified interface for parallelisation on CPUs, clusters, and GPUs. We could consider switching to AcceleratedKernels.map! for the vectorised HMC implementation, hence the suggestion above.

EDIT: I opened an issue for this suggestion. #412

+    fill!(das.x_bar, zero(T))
+    fill!(das.H_bar, zero(T))
+    return das
end

+function finalize!(das::DAState{<:AbstractFloat})
+    das.ϵ = exp(das.x_bar)
+    return das
+end
+
+function finalize!(das::DAState{<:AbstractVector{<:AbstractFloat}})
+    map!(exp, das.ϵ, das.x_bar)
+    return das
+end

mutable struct MSSState{T<:AbstractScalarOrVec{<:AbstractFloat}}
@@ -51,7 +63,7 @@ getϵ(ss::StepSizeAdaptor) = ss.state.ϵ
struct FixedStepSize{T<:AbstractScalarOrVec{<:AbstractFloat}} <: StepSizeAdaptor
ϵ::T
end
-Base.show(io::IO, a::FixedStepSize) = print(io, "FixedStepSize($(a.ϵ))")
+Base.show(io::IO, a::FixedStepSize) = print(io, "FixedStepSize(", a.ϵ, ")")

getϵ(fss::FixedStepSize) = fss.ϵ

@@ -82,7 +94,17 @@ end
function Base.show(io::IO, a::NesterovDualAveraging)
return print(
io,
-        "NesterovDualAveraging(γ=$(a.γ), t_0=$(a.t_0), κ=$(a.κ), δ=$(a.δ), state.ϵ=$(getϵ(a)))",
+        "NesterovDualAveraging(γ=",
+        a.γ,
+        ", t_0=",
+        a.t_0,
+        ", κ=",
+        a.κ,
+        ", δ=",
+        a.δ,
+        ", state.ϵ=",
+        getϵ(a),
+        ")",
)
end

@@ -95,35 +117,29 @@ end
function NesterovDualAveraging(
δ::T, ϵ::VT
) where {T<:AbstractFloat,VT<:AbstractScalarOrVec{T}}
-    return NesterovDualAveraging(T(0.05), T(10.0), T(0.75), δ, ϵ)
+    return NesterovDualAveraging(T(1//20), T(10), T(3//4), δ, ϵ)
end

# Ref: https://github.com/stan-dev/stan/blob/develop/src/stan/mcmc/stepsize_adaptation.hpp
# Note: This function is not merged with `adapt!` to emphasize the fact that
# step size adaptation is not dependent on `θ`.
# Note 2: `da.state` and `α` support vectorised HMC but should do so together.
function adapt_stepsize!(
-    da::NesterovDualAveraging{T}, α::AbstractScalarOrVec{<:T}
+    da::NesterovDualAveraging{T}, α::AbstractScalarOrVec{T}
) where {T<:AbstractFloat}
@debug "Adapting step size..." α

-    # Clip average MH acceptance probability
-    if α isa AbstractVector
-        α[α .> 1] .= one(T)
-    else
-        α = α > 1 ? one(T) : α
-    end

(; state, γ, t_0, κ, δ) = da
(; μ, m, x_bar, H_bar) = state

m = m + 1

η_H = one(T) / (m + t_0)
-    H_bar = (one(T) - η_H) * H_bar .+ η_H * (δ .- α)
+    H_bar = (one(T) - η_H) .* H_bar .+ η_H .* (δ .- min.(one(T), α))
Member (@yebai, Mar 17, 2025):
HG: I'll have to review these more carefully later this week.

EDIT: This looks good. I am surprised the previous code didn't break any tests, as it didn't properly support vectorised adaptation.
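
The min.(one(T), α) clipping adopted above handles both the scalar and vectorised cases without mutating the caller's array, which can be sketched as (an illustrative snippet, not the package's code):

```julia
T = Float64

# Broadcast `min` caps acceptance probabilities at 1; for a scalar input
# it degrades to a plain `min`, and for a vector it returns a new array
# instead of writing into α as the old `α[α .> 1] .= one(T)` branch did.
clip(α) = min.(one(T), α)

clip(1.3)              # scalar: 1.0
clip([0.4, 1.3, 0.9])  # vector: [0.4, 1.0, 0.9]; input left untouched
```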


-    x = μ .- H_bar * sqrt(m) / γ # x ≡ logϵ
+    x = μ .- H_bar .* (sqrt(m) / γ) # x ≡ logϵ
η_x = m^(-κ)
-    x_bar = (one(T) - η_x) * x_bar .+ η_x * x
+    x_bar = (one(T) - η_x) .* x_bar .+ η_x .* x

ϵ = exp.(x)
@debug "Adapting step size..." new_ϵ = ϵ old_ϵ = da.state.ϵ
@@ -151,9 +167,12 @@ function adapt!(
return nothing
end

-reset!(da::NesterovDualAveraging) = reset!(da.state)
+function reset!(da::NesterovDualAveraging)
+    reset!(da.state)
+    return da
+end

function finalize!(da::NesterovDualAveraging)
-    da.state.ϵ = exp.(da.state.x_bar)
-    return nothing
+    finalize!(da.state)
Member:
nice improvement!

+    return da
end