Skip to content

Commit a96ab41

Browse files
devmotion, github-actions[bot], yebai
authored
Reduce allocations in stepsize.jl (#390)
* Reduce allocations in stepsize.jl * Fix format Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Fix finalize and add type parameter * Fix format Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/adaptation/stepsize.jl --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Hong Ge <[email protected]>
1 parent f87870d commit a96ab41

File tree

1 file changed

+42
-23
lines changed

1 file changed

+42
-23
lines changed

src/adaptation/stepsize.jl

Lines changed: 42 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ mutable struct DAState{T<:AbstractScalarOrVec{<:AbstractFloat}}
88
H_bar::T
99
end
1010

11-
computeμ(ϵ::AbstractScalarOrVec{<:AbstractFloat}) = log.(10 * ϵ)
11+
computeμ(ϵ::AbstractFloat) = log(10 * ϵ)
1212

1313
function DAState(ϵ::T) where {T}
1414
μ = computeμ(ϵ)
@@ -17,22 +17,34 @@ end
1717

1818
function DAState(ϵ::AbstractVector{T}) where {T}
1919
n = length(ϵ)
20-
μ = computeμ(ϵ)
20+
μ = map(computeμ, ϵ)
2121
return DAState(0, ϵ, μ, zeros(T, n), zeros(T, n))
2222
end
2323

2424
function reset!(das::DAState{T}) where {T<:AbstractFloat}
2525
das.m = 0
2626
das.μ = computeμ(das.ϵ)
2727
das.x_bar = zero(T)
28-
return das.H_bar = zero(T)
28+
das.H_bar = zero(T)
29+
return das
2930
end
3031

3132
function reset!(das::DAState{<:AbstractVector{T}}) where {T<:AbstractFloat}
3233
das.m = 0
33-
das.μ .= computeμ(das.ϵ)
34-
das.x_bar .= zero(T)
35-
return das.H_bar .= zero(T)
34+
map!(computeμ, das.μ, das.ϵ)
35+
fill!(das.x_bar, zero(T))
36+
fill!(das.H_bar, zero(T))
37+
return das
38+
end
39+
40+
function finalize!(das::DAState{<:AbstractFloat})
41+
das.ϵ = exp(das.x_bar)
42+
return das
43+
end
44+
45+
function finalize!(das::DAState{<:AbstractVector{<:AbstractFloat}})
46+
map!(exp, das.ϵ, das.x_bar)
47+
return das
3648
end
3749

3850
mutable struct MSSState{T<:AbstractScalarOrVec{<:AbstractFloat}}
@@ -51,7 +63,7 @@ getϵ(ss::StepSizeAdaptor) = ss.state.ϵ
5163
struct FixedStepSize{T<:AbstractScalarOrVec{<:AbstractFloat}} <: StepSizeAdaptor
5264
ϵ::T
5365
end
54-
Base.show(io::IO, a::FixedStepSize) = print(io, "FixedStepSize($(a.ϵ))")
66+
Base.show(io::IO, a::FixedStepSize) = print(io, "FixedStepSize(", a.ϵ, ")")
5567

5668
getϵ(fss::FixedStepSize) = fss.ϵ
5769

@@ -82,7 +94,17 @@ end
8294
function Base.show(io::IO, a::NesterovDualAveraging)
8395
return print(
8496
io,
85-
"NesterovDualAveraging(γ=$(a.γ), t_0=$(a.t_0), κ=$(a.κ), δ=$(a.δ), state.ϵ=$(getϵ(a)))",
97+
"NesterovDualAveraging(γ=",
98+
a.γ,
99+
", t_0=",
100+
a.t_0,
101+
", κ=",
102+
a.κ,
103+
", δ=",
104+
a.δ,
105+
", state.ϵ=",
106+
getϵ(a),
107+
")",
86108
)
87109
end
88110

@@ -95,35 +117,29 @@ end
95117
function NesterovDualAveraging(
96118
δ::T, ϵ::VT
97119
) where {T<:AbstractFloat,VT<:AbstractScalarOrVec{T}}
98-
return NesterovDualAveraging(T(0.05), T(10.0), T(0.75), δ, ϵ)
120+
return NesterovDualAveraging(T(1//20), T(10), T(3//4), δ, ϵ)
99121
end
100122

101123
# Ref: https://github.com/stan-dev/stan/blob/develop/src/stan/mcmc/stepsize_adaptation.hpp
102124
# Note: This function is not merged with `adapt!` to emphasize the fact that
103125
# step size adaptation is not dependent on `θ`.
126+
# Note 2: `da.state` and `α` support vectorised HMC but should do so together.
104127
function adapt_stepsize!(
105-
da::NesterovDualAveraging{T}, α::AbstractScalarOrVec{<:T}
128+
da::NesterovDualAveraging{T}, α::AbstractScalarOrVec{T}
106129
) where {T<:AbstractFloat}
107130
@debug "Adapting step size..." α
108131

109-
# Clip average MH acceptance probability
110-
if α isa AbstractVector
111-
α[α .> 1] .= one(T)
112-
else
113-
α = α > 1 ? one(T) : α
114-
end
115-
116132
(; state, γ, t_0, κ, δ) = da
117133
(; μ, m, x_bar, H_bar) = state
118134

119135
m = m + 1
120136

121137
η_H = one(T) / (m + t_0)
122-
H_bar = (one(T) - η_H) * H_bar .+ η_H * (δ .- α)
138+
H_bar = (one(T) - η_H) .* H_bar .+ η_H .* (δ .- min.(one(T), α))
123139

124-
x = μ .- H_bar * sqrt(m) / γ # x ≡ logϵ
140+
x = μ .- H_bar .* (sqrt(m) / γ) # x ≡ logϵ
125141
η_x = m^(-κ)
126-
x_bar = (one(T) - η_x) * x_bar .+ η_x * x
142+
x_bar = (one(T) - η_x) .* x_bar .+ η_x .* x
127143

128144
ϵ = exp.(x)
129145
@debug "Adapting step size..." new_ϵ = ϵ old_ϵ = da.state.ϵ
@@ -151,9 +167,12 @@ function adapt!(
151167
return nothing
152168
end
153169

154-
reset!(da::NesterovDualAveraging) = reset!(da.state)
170+
function reset!(da::NesterovDualAveraging)
171+
reset!(da.state)
172+
return da
173+
end
155174

156175
function finalize!(da::NesterovDualAveraging)
157-
da.state.ϵ = exp.(da.state.x_bar)
158-
return nothing
176+
finalize!(da.state)
177+
return da
159178
end

0 commit comments

Comments
 (0)