Skip to content

Commit c7ca39a

Browse files
committed
Specialize on functions
1 parent c137cb7 commit c7ca39a

File tree

1 file changed

+9
-5
lines changed

1 file changed

+9
-5
lines changed

src/linesearch.jl

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,9 @@ end
3131
end
3232

3333
# LineSearches.jl doesn't have a supertype so default to that
34-
init_linesearch_cache(_, ls, f, u, p, fu, iip) = LineSearchesJLCache(ls, f, u, p, fu, iip)
34+
function init_linesearch_cache(_, ls, f::F, u, p, fu, iip) where {F <: Function}
35+
return LineSearchesJLCache(ls, f, u, p, fu, iip)
36+
end
3537

3638
# Wrapper over LineSearches.jl algorithms
3739
@concrete mutable struct LineSearchesJLCache
@@ -43,7 +45,8 @@ init_linesearch_cache(_, ls, f, u, p, fu, iip) = LineSearchesJLCache(ls, f, u, p
4345
ls
4446
end
4547

46-
function LineSearchesJLCache(ls::LineSearch, f, u::Number, p, _, ::Val{false})
48+
function LineSearchesJLCache(ls::LineSearch, f::F, u::Number, p, _,
49+
::Val{false}) where {F <: Function}
4750
eval_f(u, du, α) = eval_f(u - α * du)
4851
eval_f(u) = f(u, p)
4952

@@ -84,7 +87,8 @@ function LineSearchesJLCache(ls::LineSearch, f, u::Number, p, _, ::Val{false})
8487
return LineSearchesJLCache(eval_f, ϕ, dϕ, ϕdϕ, convert(eltype(u), ls.α), ls)
8588
end
8689

87-
function LineSearchesJLCache(ls::LineSearch, f, u, p, fu1, IIP::Val{iip}) where {iip}
90+
function LineSearchesJLCache(ls::LineSearch, f::F, u, p, fu1,
91+
IIP::Val{iip}) where {iip, F <: Function}
8892
fu = iip ? deepcopy(fu1) : nothing
8993
u_ = _mutable_zero(u)
9094

@@ -200,8 +204,8 @@ end
200204
α
201205
end
202206

203-
function init_linesearch_cache(alg::LiFukushimaLineSearch, ls::LineSearch, f, _u, p, _fu,
204-
::Val{iip}) where {iip}
207+
function init_linesearch_cache(alg::LiFukushimaLineSearch, ls::LineSearch, f::F, _u, p, _fu,
208+
::Val{iip}) where {iip, F <: Function}
205209
fu = iip ? deepcopy(_fu) : nothing
206210
u = iip ? deepcopy(_u) : nothing
207211
return LiFukushimaLineSearchCache{iip}(f, p, u, fu, alg, ls.α)

0 commit comments

Comments
 (0)