Skip to content

Commit 59c988e

Browse files
committed
Short circuit linesearch for now
1 parent a24f20e commit 59c988e

File tree

2 files changed

+19
-9
lines changed

2 files changed

+19
-9
lines changed

src/linesearch.jl

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ differentiation for fast Vector Jacobian Products.
88
99
### Arguments
1010
11-
- `method`: the line search algorithm to use. Defaults to `Static()`, which means that the
11+
- `method`: the line search algorithm to use. Defaults to `nothing`, which means that the
1212
step size is fixed to the value of `alpha`.
1313
- `autodiff`: the automatic differentiation backend to use for the line search. Defaults to
1414
`AutoFiniteDiff()`, which means that finite differencing is used to compute the VJP.
@@ -22,19 +22,31 @@ differentiation for fast Vector Jacobian Products.
2222
α
2323
end
2424

25-
function LineSearch(; method = Static(), autodiff = AutoFiniteDiff(), alpha = true)
25+
function LineSearch(; method = nothing, autodiff = AutoFiniteDiff(), alpha = true)
2626
return LineSearch(method, autodiff, alpha)
2727
end
2828

2929
@inline function init_linesearch_cache(ls::LineSearch, args...)
3030
return init_linesearch_cache(ls.method, ls, args...)
3131
end
3232

33+
@concrete struct NoLineSearchCache
34+
α
35+
end
36+
37+
function init_linesearch_cache(::Nothing, ls, f::F, u, p, fu, iip) where {F}
38+
return NoLineSearchCache(convert(eltype(u), ls.α))
39+
end
40+
41+
perform_linesearch!(cache::NoLineSearchCache, u, du) = cache.α
42+
3343
# LineSearches.jl doesn't have a supertype so default to that
34-
function init_linesearch_cache(_, ls, f::F, u, p, fu, iip) where {F <: Function}
44+
function init_linesearch_cache(_, ls, f::F, u, p, fu, iip) where {F}
3545
return LineSearchesJLCache(ls, f, u, p, fu, iip)
3646
end
3747

48+
# FIXME: The closures lead to too many unnecessary runtime dispatches which leads to the
49+
# massive increase in precompilation times.
3850
# Wrapper over LineSearches.jl algorithms
3951
@concrete mutable struct LineSearchesJLCache
4052
f
@@ -45,8 +57,7 @@ end
4557
ls
4658
end
4759

48-
function LineSearchesJLCache(ls::LineSearch, f::F, u::Number, p, _,
49-
::Val{false}) where {F <: Function}
60+
function LineSearchesJLCache(ls::LineSearch, f::F, u::Number, p, _, ::Val{false}) where {F}
5061
eval_f(u, du, α) = eval_f(u - α * du)
5162
eval_f(u) = f(u, p)
5263

@@ -87,8 +98,7 @@ function LineSearchesJLCache(ls::LineSearch, f::F, u::Number, p, _,
8798
return LineSearchesJLCache(eval_f, ϕ, dϕ, ϕdϕ, convert(eltype(u), ls.α), ls)
8899
end
89100

90-
function LineSearchesJLCache(ls::LineSearch, f::F, u, p, fu1,
91-
IIP::Val{iip}) where {iip, F <: Function}
101+
function LineSearchesJLCache(ls::LineSearch, f::F, u, p, fu1, IIP::Val{iip}) where {iip, F}
92102
fu = iip ? deepcopy(fu1) : nothing
93103
u_ = _mutable_zero(u)
94104

@@ -202,7 +212,7 @@ end
202212
end
203213

204214
function init_linesearch_cache(alg::LiFukushimaLineSearch, ls::LineSearch, f::F, _u, p, _fu,
205-
::Val{iip}) where {iip, F <: Function}
215+
::Val{iip}) where {iip, F}
206216
fu = iip ? deepcopy(_fu) : nothing
207217
u = iip ? deepcopy(_u) : nothing
208218
return LiFukushimaLineSearchCache{iip}(f, p, u, fu, alg, ls.α)

src/utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ function __get_concrete_algorithm(alg, prob)
216216
use_sparse_ad ? AutoSparseFiniteDiff() : AutoFiniteDiff()
217217
else
218218
tag = NonlinearSolveTag()
219-
use_sparse_ad ? AutoSparseForwardDiff(; tag) : AutoForwardDiff(; tag)
219+
(use_sparse_ad ? AutoSparseForwardDiff : AutoForwardDiff)(; tag)
220220
end
221221
return set_ad(alg, ad)
222222
end

0 commit comments

Comments
 (0)