-
Notifications
You must be signed in to change notification settings - Fork 48
Reduce allocations in stepsize.jl #390
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
e83ae54
b597b98
d8ee6f0
018dcc4
bb73bce
8b4bb3c
1dc6ccc
4412e7b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -8,7 +8,7 @@ mutable struct DAState{T<:AbstractScalarOrVec{<:AbstractFloat}} | |||||
| H_bar::T | ||||||
| end | ||||||
|
|
||||||
| computeμ(ϵ::AbstractScalarOrVec{<:AbstractFloat}) = log.(10 * ϵ) | ||||||
| computeμ(ϵ::AbstractFloat) = log(10 * ϵ) | ||||||
yebai marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
|
|
||||||
| function DAState(ϵ::T) where {T} | ||||||
| μ = computeμ(ϵ) | ||||||
|
|
@@ -17,22 +17,34 @@ end | |||||
|
|
||||||
| function DAState(ϵ::AbstractVector{T}) where {T} | ||||||
| n = length(ϵ) | ||||||
| μ = computeμ(ϵ) | ||||||
| μ = map(computeμ, ϵ) | ||||||
yebai marked this conversation as resolved.
Show resolved
Hide resolved
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
| return DAState(0, ϵ, μ, zeros(T, n), zeros(T, n)) | ||||||
| end | ||||||
|
|
||||||
| function reset!(das::DAState{T}) where {T<:AbstractFloat} | ||||||
| das.m = 0 | ||||||
| das.μ = computeμ(das.ϵ) | ||||||
| das.x_bar = zero(T) | ||||||
| return das.H_bar = zero(T) | ||||||
| das.H_bar = zero(T) | ||||||
| return das | ||||||
| end | ||||||
|
|
||||||
| function reset!(das::DAState{<:AbstractVector{T}}) where {T<:AbstractFloat} | ||||||
| das.m = 0 | ||||||
| das.μ .= computeμ(das.ϵ) | ||||||
| das.x_bar .= zero(T) | ||||||
| return das.H_bar .= zero(T) | ||||||
| map!(computeμ, das.μ, das.ϵ) | ||||||
yebai marked this conversation as resolved.
Show resolved
Hide resolved
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Let's keep this as-is for now. We could refactor the vectorised HMC interface, but better to do it separately in a concerted effort:
Suggested change
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This suggestion would go against the main intention of the PR, reducing unnecessary allocations: With
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Fair point! I recently discovered the AcceleratedKernels package, which provides a unified interface for parallelisation on CPUs, clusters, and GPUs. We could consider switching to EDIT: I opened an issue for this suggestion. #412 |
||||||
| fill!(das.x_bar, zero(T)) | ||||||
| fill!(das.H_bar, zero(T)) | ||||||
| return das | ||||||
| end | ||||||
|
|
||||||
| function finalize!(das::DAState{<:AbstractFloat}) | ||||||
| das.ϵ = exp(das.x_bar) | ||||||
| return das | ||||||
| end | ||||||
|
|
||||||
| function finalize!(das::DAState{<:AbstractVector{<:AbstractFloat}}) | ||||||
| map!(exp, das.ϵ, das.x_bar) | ||||||
| return das | ||||||
| end | ||||||
|
|
||||||
| mutable struct MSSState{T<:AbstractScalarOrVec{<:AbstractFloat}} | ||||||
|
|
@@ -51,7 +63,7 @@ getϵ(ss::StepSizeAdaptor) = ss.state.ϵ | |||||
| struct FixedStepSize{T<:AbstractScalarOrVec{<:AbstractFloat}} <: StepSizeAdaptor | ||||||
| ϵ::T | ||||||
| end | ||||||
| Base.show(io::IO, a::FixedStepSize) = print(io, "FixedStepSize($(a.ϵ))") | ||||||
| Base.show(io::IO, a::FixedStepSize) = print(io, "FixedStepSize(", a.ϵ, ")") | ||||||
|
|
||||||
| getϵ(fss::FixedStepSize) = fss.ϵ | ||||||
|
|
||||||
|
|
@@ -82,7 +94,17 @@ end | |||||
| function Base.show(io::IO, a::NesterovDualAveraging) | ||||||
| return print( | ||||||
| io, | ||||||
| "NesterovDualAveraging(γ=$(a.γ), t_0=$(a.t_0), κ=$(a.κ), δ=$(a.δ), state.ϵ=$(getϵ(a)))", | ||||||
| "NesterovDualAveraging(γ=", | ||||||
| a.γ, | ||||||
| ", t_0=", | ||||||
| a.t_0, | ||||||
| ", κ=", | ||||||
| a.κ, | ||||||
| ", δ=", | ||||||
| a.δ, | ||||||
| ", state.ϵ=", | ||||||
| getϵ(a), | ||||||
| ")", | ||||||
| ) | ||||||
| end | ||||||
|
|
||||||
|
|
@@ -95,35 +117,29 @@ end | |||||
| function NesterovDualAveraging( | ||||||
| δ::T, ϵ::VT | ||||||
| ) where {T<:AbstractFloat,VT<:AbstractScalarOrVec{T}} | ||||||
| return NesterovDualAveraging(T(0.05), T(10.0), T(0.75), δ, ϵ) | ||||||
| return NesterovDualAveraging(T(1//20), T(10), T(3//4), δ, ϵ) | ||||||
| end | ||||||
|
|
||||||
| # Ref: https://github.com/stan-dev/stan/blob/develop/src/stan/mcmc/stepsize_adaptation.hpp | ||||||
| # Note: This function is not merged with `adapt!` to empahsize the fact that | ||||||
| # step size adaptation is not dependent on `θ`. | ||||||
| # Note 2: `da.state` and `α` support vectorised HMC but should do so together. | ||||||
| function adapt_stepsize!( | ||||||
yebai marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
| da::NesterovDualAveraging{T}, α::AbstractScalarOrVec{<:T} | ||||||
| da::NesterovDualAveraging{T}, α::AbstractScalarOrVec{T} | ||||||
| ) where {T<:AbstractFloat} | ||||||
| @debug "Adapting step size..." α | ||||||
|
|
||||||
| # Clip average MH acceptance probability | ||||||
| if α isa AbstractVector | ||||||
| α[α .> 1] .= one(T) | ||||||
| else | ||||||
| α = α > 1 ? one(T) : α | ||||||
| end | ||||||
|
|
||||||
| (; state, γ, t_0, κ, δ) = da | ||||||
| (; μ, m, x_bar, H_bar) = state | ||||||
|
|
||||||
| m = m + 1 | ||||||
|
|
||||||
| η_H = one(T) / (m + t_0) | ||||||
| H_bar = (one(T) - η_H) * H_bar .+ η_H * (δ .- α) | ||||||
| H_bar = (one(T) - η_H) .* H_bar .+ η_H .* (δ .- min.(one(T), α)) | ||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. HG: I'll have to review these more carefully later this week. EDIT: This looks good. I am surprised the previous code didn't break any tests, as it didn't properly support vectorised adaptation. |
||||||
|
|
||||||
| x = μ .- H_bar * sqrt(m) / γ # x ≡ logϵ | ||||||
| x = μ .- H_bar .* (sqrt(m) / γ) # x ≡ logϵ | ||||||
| η_x = m^(-κ) | ||||||
| x_bar = (one(T) - η_x) * x_bar .+ η_x * x | ||||||
| x_bar = (one(T) - η_x) .* x_bar .+ η_x .* x | ||||||
|
|
||||||
| ϵ = exp.(x) | ||||||
| @debug "Adapting step size..." new_ϵ = ϵ old_ϵ = da.state.ϵ | ||||||
|
|
@@ -151,9 +167,12 @@ function adapt!( | |||||
| return nothing | ||||||
| end | ||||||
|
|
||||||
| reset!(da::NesterovDualAveraging) = reset!(da.state) | ||||||
| function reset!(da::NesterovDualAveraging) | ||||||
| reset!(da.state) | ||||||
| return da | ||||||
| end | ||||||
|
|
||||||
| function finalize!(da::NesterovDualAveraging) | ||||||
| da.state.ϵ = exp.(da.state.x_bar) | ||||||
| return nothing | ||||||
| finalize!(da.state) | ||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Nice improvement! |
||||||
| return da | ||||||
| end | ||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Caution is required here: these support the vectorised version of
HMC. Do you know how `map` would differ from broadcasting here? There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The results of the calculations won't be affected by this change, but using the non-broadcasted formulation for scalars and
`map` for vectors of floats will remove the broadcasting overhead and reduce stress on the compiler, i.e., generally it reduces compilation time. Sometimes it also helps type inference (but this case is too simple for this effect I assume). In my experience, broadcasting is useful if one's actually broadcasting values of different size and dimensions but otherwise often a suboptimal choice.