Skip to content

Commit 3e4f8cd

Browse files
authored
Some type annotations inspired by some JET errors. (#1194)
* Some type annotations inspired by some JET errors. Not all these were related to JET errors, but since they appered because of unclear types in certain signatures I fiigured I'd add a few. * Update gradient_descent.jl * Options->NamedTuple * Update golden_section.jl * Update types.jl * Update types.jl
1 parent 80e498b commit 3e4f8cd

File tree

18 files changed

+43
-52
lines changed

18 files changed

+43
-52
lines changed

src/multivariate/solvers/constrained/ipnewton/ipnewton.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ end
121121

122122
function initial_state(
123123
method::IPNewton,
124-
options,
124+
options::Options,
125125
d::TwiceDifferentiable,
126126
constraints::TwiceDifferentiableConstraints,
127127
initial_x::AbstractArray{T},
@@ -270,7 +270,7 @@ function update_state!(
270270
constraints::TwiceDifferentiableConstraints,
271271
state::IPNewtonState{T},
272272
method::IPNewton,
273-
options,
273+
options::Options,
274274
) where {T}
275275
state.f_x_previous, state.L_previous = state.f_x, state.L
276276
bstate, bstep, bounds = state.bstate, state.bstep, constraints.bounds
@@ -331,7 +331,7 @@ end
331331
function solve_step!(
332332
state::IPNewtonState,
333333
constraints,
334-
options,
334+
options::Options,
335335
show_linesearch::Bool = false,
336336
)
337337
x, s, μ, bounds = state.x, state.s, state.μ, constraints.bounds

src/multivariate/solvers/first_order/accelerated_gradient_descent.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,10 +88,10 @@ end
8888
function trace!(
8989
tr,
9090
d,
91-
state,
92-
iteration,
91+
state::AcceleratedGradientDescentState,
92+
iteration::Integer,
9393
method::AcceleratedGradientDescent,
94-
options,
94+
options::Options,
9595
curr_time = time(),
9696
)
9797
common_trace!(tr, d, state, iteration, method, options, curr_time)

src/multivariate/solvers/first_order/adam.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ function _get_init_params(method::Adam)
6767
method.α(1), method.β₁, method.β₂
6868
end
6969

70-
function initial_state(method::Adam, options, d, initial_x::AbstractArray{T}) where {T}
70+
function initial_state(method::Adam, options::Options, d, initial_x::AbstractArray{T}) where {T}
7171
initial_x = copy(initial_x)
7272

7373
value_gradient!!(d, initial_x)
@@ -122,6 +122,6 @@ function update_state!(d, state::AdamState{T}, method::Adam) where {T}
122122
false # break on linesearch error
123123
end
124124

125-
function trace!(tr, d, state, iteration, method::Adam, options, curr_time = time())
125+
function trace!(tr, d, state::AdamState, iteration::Integer, method::Adam, options::Options, curr_time = time())
126126
common_trace!(tr, d, state, iteration, method, options, curr_time)
127127
end

src/multivariate/solvers/first_order/adamax.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ function _get_init_params(method::AdaMax)
6767
method.α(1), method.β₁, method.β₂
6868
end
6969

70-
function initial_state(method::AdaMax, options, d, initial_x::AbstractArray{T}) where {T}
70+
function initial_state(method::AdaMax, options::Options, d, initial_x::AbstractArray{T}) where {T}
7171
initial_x = copy(initial_x)
7272

7373
value_gradient!!(d, initial_x)
@@ -115,6 +115,6 @@ function update_state!(d, state::AdaMaxState{T}, method::AdaMax) where {T}
115115
false # break on linesearch error
116116
end
117117

118-
function trace!(tr, d, state, iteration, method::AdaMax, options, curr_time = time())
118+
function trace!(tr, d, state::AdaMaxState, iteration::Integer, method::AdaMax, options::Options, curr_time = time())
119119
common_trace!(tr, d, state, iteration, method, options, curr_time)
120120
end

src/multivariate/solvers/first_order/bfgs.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ function reset!(method, state::BFGSState, obj, x)
8888
end
8989
end
9090

91-
function initial_state(method::BFGS, options, d, initial_x::AbstractArray{T}) where {T}
91+
function initial_state(method::BFGS, options::Options, d, initial_x::AbstractArray{T}) where {T}
9292
n = length(initial_x)
9393
initial_x = copy(initial_x)
9494
retract!(method.manifold, initial_x)
@@ -186,7 +186,7 @@ function update_h!(d, state, method::BFGS)
186186
end
187187
end
188188

189-
function trace!(tr, d, state, iteration, method::BFGS, options, curr_time = time())
189+
function trace!(tr, d, state::BFGSState, iteration::Integer, method::BFGS, options::Options, curr_time = time())
190190
dt = Dict()
191191
dt["time"] = curr_time
192192
if options.extended_trace

src/multivariate/solvers/first_order/cg.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ function reset!(cg::ConjugateGradient, cgs::ConjugateGradientState, obj, x)
113113
cgs.s .= .-cgs.pg
114114
cgs.f_x_previous = typeof(cgs.f_x_previous)(NaN)
115115
end
116-
function initial_state(method::ConjugateGradient, options, d, initial_x)
116+
function initial_state(method::ConjugateGradient, options::Options, d, initial_x)
117117
T = eltype(initial_x)
118118
initial_x = copy(initial_x)
119119
retract!(method.manifold, initial_x)

src/multivariate/solvers/first_order/gradient_descent.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -97,10 +97,10 @@ end
9797
function trace!(
9898
tr,
9999
d,
100-
state,
101-
iteration,
100+
state::GradientDescentState,
101+
iteration::Integer,
102102
method::GradientDescent,
103-
options,
103+
options::Options,
104104
curr_time = time(),
105105
)
106106
common_trace!(tr, d, state, iteration, method, options, curr_time)

src/multivariate/solvers/first_order/l_bfgs.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -152,14 +152,14 @@ mutable struct LBFGSState{Tx,Tdx,Tdg,T,G} <: AbstractOptimizerState
152152
s::Tx
153153
@add_linesearch_fields()
154154
end
155-
function reset!(method, state::LBFGSState, obj, x)
155+
function reset!(method::LBFGS, state::LBFGSState, obj, x::AbstractArray)
156156
retract!(method.manifold, x)
157157
value_gradient!(obj, x)
158158
project_tangent!(method.manifold, gradient(obj), x)
159159

160160
state.pseudo_iteration = 0
161161
end
162-
function initial_state(method::LBFGS, options, d, initial_x)
162+
function initial_state(method::LBFGS, options::Options, d, initial_x::AbstractArray)
163163
T = real(eltype(initial_x))
164164
n = length(initial_x)
165165
initial_x = copy(initial_x)
@@ -228,7 +228,7 @@ function update_state!(d, state::LBFGSState, method::LBFGS)
228228
end
229229

230230

231-
function update_h!(d, state, method::LBFGS)
231+
function update_h!(d, state::LBFGSState, method::LBFGS)
232232
n = length(state.x)
233233
# Measure the change in the gradient
234234
state.dg .= gradient(d) .- state.g_previous
@@ -247,6 +247,6 @@ function update_h!(d, state, method::LBFGS)
247247
false
248248
end
249249

250-
function trace!(tr, d, state, iteration, method::LBFGS, options, curr_time = time())
250+
function trace!(tr, d, state::LBFGSState, iteration::Integer, method::LBFGS, options::Options, curr_time = time())
251251
common_trace!(tr, d, state, iteration, method, options, curr_time)
252252
end

src/multivariate/solvers/first_order/momentum_gradient_descent.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ mutable struct MomentumGradientDescentState{Tx,T} <: AbstractOptimizerState
2828
@add_linesearch_fields()
2929
end
3030

31-
function initial_state(method::MomentumGradientDescent, options, d, initial_x)
31+
function initial_state(method::MomentumGradientDescent, options::Options, d, initial_x::AbstractArray)
3232
T = eltype(initial_x)
3333
initial_x = copy(initial_x)
3434
retract!(method.manifold, initial_x)
@@ -70,10 +70,10 @@ end
7070
function trace!(
7171
tr,
7272
d,
73-
state,
74-
iteration,
73+
state::MomentumGradientDescentState,
74+
iteration::Integer,
7575
method::MomentumGradientDescent,
76-
options,
76+
options::Options,
7777
curr_time = time(),
7878
)
7979
common_trace!(tr, d, state, iteration, method, options, curr_time)

src/multivariate/solvers/first_order/ngmres.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ end
223223
const ngmres_oaccel_warned = Ref{Bool}(false)
224224
function initial_state(
225225
method::AbstractNGMRES,
226-
options,
226+
options::Options,
227227
d,
228228
initial_x::AbstractArray{eTx},
229229
) where {eTx}
@@ -452,10 +452,10 @@ end
452452
function trace!(
453453
tr,
454454
d,
455-
state,
456-
iteration,
455+
state::NGMRESState,
456+
iteration::Integer,
457457
method::AbstractNGMRES,
458-
options,
458+
options::Options,
459459
curr_time = time(),
460460
)
461461
dt = Dict()

0 commit comments

Comments
 (0)