@@ -8,7 +8,7 @@ mutable struct DAState{T<:AbstractScalarOrVec{<:AbstractFloat}}
8
8
H_bar:: T
9
9
end
10
10
11
- computeμ (ϵ:: AbstractScalarOrVec{<: AbstractFloat} ) = log . (10 * ϵ)
11
+ computeμ (ϵ:: AbstractFloat ) = log (10 * ϵ)
12
12
13
13
function DAState (ϵ:: T ) where {T}
14
14
μ = computeμ (ϵ)
17
17
18
18
function DAState (ϵ:: AbstractVector{T} ) where {T}
19
19
n = length (ϵ)
20
- μ = computeμ ( ϵ)
20
+ μ = map (computeμ, ϵ)
21
21
return DAState (0 , ϵ, μ, zeros (T, n), zeros (T, n))
22
22
end
23
23
24
24
function reset! (das:: DAState{T} ) where {T<: AbstractFloat }
25
25
das. m = 0
26
26
das. μ = computeμ (das. ϵ)
27
27
das. x_bar = zero (T)
28
- return das. H_bar = zero (T)
28
+ das. H_bar = zero (T)
29
+ return das
29
30
end
30
31
31
32
function reset! (das:: DAState{<:AbstractVector{T}} ) where {T<: AbstractFloat }
32
33
das. m = 0
33
- das. μ .= computeμ (das. ϵ)
34
- das. x_bar .= zero (T)
35
- return das. H_bar .= zero (T)
34
+ map! (computeμ, das. μ, das. ϵ)
35
+ fill! (das. x_bar, zero (T))
36
+ fill! (das. H_bar, zero (T))
37
+ return das
38
+ end
39
+
40
+ function finalize! (das:: DAState{<:AbstractFloat} )
41
+ das. ϵ = exp (das. x_bar)
42
+ return das
43
+ end
44
+
45
+ function finalize! (das:: DAState{<:AbstractVector{<:AbstractFloat}} )
46
+ map! (exp, das. ϵ, das. x_bar)
47
+ return das
36
48
end
37
49
38
50
mutable struct MSSState{T<: AbstractScalarOrVec{<:AbstractFloat} }
@@ -51,7 +63,7 @@ getϵ(ss::StepSizeAdaptor) = ss.state.ϵ
51
63
struct FixedStepSize{T<: AbstractScalarOrVec{<:AbstractFloat} } <: StepSizeAdaptor
52
64
ϵ:: T
53
65
end
54
- Base. show (io:: IO , a:: FixedStepSize ) = print (io, " FixedStepSize($( a. ϵ) )" )
66
+ Base. show (io:: IO , a:: FixedStepSize ) = print (io, " FixedStepSize(" , a. ϵ, " )" )
55
67
56
68
getϵ (fss:: FixedStepSize ) = fss. ϵ
57
69
82
94
function Base. show (io:: IO , a:: NesterovDualAveraging )
83
95
return print (
84
96
io,
85
- " NesterovDualAveraging(γ=$(a. γ) , t_0=$(a. t_0) , κ=$(a. κ) , δ=$(a. δ) , state.ϵ=$(getϵ (a)) )" ,
97
+ " NesterovDualAveraging(γ=" ,
98
+ a. γ,
99
+ " , t_0=" ,
100
+ a. t_0,
101
+ " , κ=" ,
102
+ a. κ,
103
+ " , δ=" ,
104
+ a. δ,
105
+ " , state.ϵ=" ,
106
+ getϵ (a),
107
+ " )" ,
86
108
)
87
109
end
88
110
95
117
function NesterovDualAveraging (
96
118
δ:: T , ϵ:: VT
97
119
) where {T<: AbstractFloat ,VT<: AbstractScalarOrVec{T} }
98
- return NesterovDualAveraging (T (0.05 ), T (10.0 ), T (0.75 ), δ, ϵ)
120
+ return NesterovDualAveraging (T (1 // 20 ), T (10 ), T (3 // 4 ), δ, ϵ)
99
121
end
100
122
101
123
# Ref: https://github.com/stan-dev/stan/blob/develop/src/stan/mcmc/stepsize_adaptation.hpp
102
124
# Note: This function is not merged with `adapt!` to empahsize the fact that
103
125
# step size adaptation is not dependent on `θ`.
126
+ # Note 2: `da.state` and `α` support vectorised HMC but should do so together.
104
127
function adapt_stepsize! (
105
- da:: NesterovDualAveraging{T} , α:: AbstractScalarOrVec{<: T}
128
+ da:: NesterovDualAveraging{T} , α:: AbstractScalarOrVec{T}
106
129
) where {T<: AbstractFloat }
107
130
@debug " Adapting step size..." α
108
131
109
- # Clip average MH acceptance probability
110
- if α isa AbstractVector
111
- α[α .> 1 ] .= one (T)
112
- else
113
- α = α > 1 ? one (T) : α
114
- end
115
-
116
132
(; state, γ, t_0, κ, δ) = da
117
133
(; μ, m, x_bar, H_bar) = state
118
134
119
135
m = m + 1
120
136
121
137
η_H = one (T) / (m + t_0)
122
- H_bar = (one (T) - η_H) * H_bar .+ η_H * (δ .- α )
138
+ H_bar = (one (T) - η_H) . * H_bar .+ η_H . * (δ .- min .( one (T), α) )
123
139
124
- x = μ .- H_bar * sqrt (m) / γ # x ≡ logϵ
140
+ x = μ .- H_bar .* ( sqrt (m) / γ) # x ≡ logϵ
125
141
η_x = m^ (- κ)
126
- x_bar = (one (T) - η_x) * x_bar .+ η_x * x
142
+ x_bar = (one (T) - η_x) . * x_bar .+ η_x . * x
127
143
128
144
ϵ = exp .(x)
129
145
@debug " Adapting step size..." new_ϵ = ϵ old_ϵ = da. state. ϵ
@@ -151,9 +167,12 @@ function adapt!(
151
167
return nothing
152
168
end
153
169
154
- reset! (da:: NesterovDualAveraging ) = reset! (da. state)
170
+ function reset! (da:: NesterovDualAveraging )
171
+ reset! (da. state)
172
+ return da
173
+ end
155
174
156
175
function finalize! (da:: NesterovDualAveraging )
157
- da . state . ϵ = exp . (da. state. x_bar )
158
- return nothing
176
+ finalize! (da. state)
177
+ return da
159
178
end
0 commit comments