@@ -8,7 +8,7 @@ differentiation for fast Vector Jacobian Products.
8
8
9
9
### Arguments
10
10
11
- - `method`: the line search algorithm to use. Defaults to `Static() `, which means that the
11
+ - `method`: the line search algorithm to use. Defaults to `nothing `, which means that the
12
12
step size is fixed to the value of `alpha`.
13
13
- `autodiff`: the automatic differentiation backend to use for the line search. Defaults to
14
14
`AutoFiniteDiff()`, which means that finite differencing is used to compute the VJP.
@@ -22,19 +22,31 @@ differentiation for fast Vector Jacobian Products.
22
22
α
23
23
end
24
24
25
- function LineSearch (; method = Static () , autodiff = AutoFiniteDiff (), alpha = true )
25
+ function LineSearch (; method = nothing , autodiff = AutoFiniteDiff (), alpha = true )
26
26
return LineSearch (method, autodiff, alpha)
27
27
end
28
28
29
29
@inline function init_linesearch_cache (ls:: LineSearch , args... )
30
30
return init_linesearch_cache (ls. method, ls, args... )
31
31
end
32
32
33
+ @concrete struct NoLineSearchCache
34
+ α
35
+ end
36
+
37
+ function init_linesearch_cache (:: Nothing , ls, f:: F , u, p, fu, iip) where {F}
38
+ return NoLineSearchCache (convert (eltype (u), ls. α))
39
+ end
40
+
41
+ perform_linesearch! (cache:: NoLineSearchCache , u, du) = cache. α
42
+
33
43
# LineSearches.jl doesn't have a supertype so default to that
34
- function init_linesearch_cache (_, ls, f:: F , u, p, fu, iip) where {F <: Function }
44
+ function init_linesearch_cache (_, ls, f:: F , u, p, fu, iip) where {F}
35
45
return LineSearchesJLCache (ls, f, u, p, fu, iip)
36
46
end
37
47
48
+ # FIXME : The closures lead to too many unnecessary runtime dispatches which leads to the
49
+ # massive increase in precompilation times.
38
50
# Wrapper over LineSearches.jl algorithms
39
51
@concrete mutable struct LineSearchesJLCache
40
52
f
45
57
ls
46
58
end
47
59
48
- function LineSearchesJLCache (ls:: LineSearch , f:: F , u:: Number , p, _,
49
- :: Val{false} ) where {F <: Function }
60
+ function LineSearchesJLCache (ls:: LineSearch , f:: F , u:: Number , p, _, :: Val{false} ) where {F}
50
61
eval_f (u, du, α) = eval_f (u - α * du)
51
62
eval_f (u) = f (u, p)
52
63
@@ -87,8 +98,7 @@ function LineSearchesJLCache(ls::LineSearch, f::F, u::Number, p, _,
87
98
return LineSearchesJLCache (eval_f, ϕ, dϕ, ϕdϕ, convert (eltype (u), ls. α), ls)
88
99
end
89
100
90
- function LineSearchesJLCache (ls:: LineSearch , f:: F , u, p, fu1,
91
- IIP:: Val{iip} ) where {iip, F <: Function }
101
+ function LineSearchesJLCache (ls:: LineSearch , f:: F , u, p, fu1, IIP:: Val{iip} ) where {iip, F}
92
102
fu = iip ? deepcopy (fu1) : nothing
93
103
u_ = _mutable_zero (u)
94
104
202
212
end
203
213
204
214
function init_linesearch_cache (alg:: LiFukushimaLineSearch , ls:: LineSearch , f:: F , _u, p, _fu,
205
- :: Val{iip} ) where {iip, F <: Function }
215
+ :: Val{iip} ) where {iip, F}
206
216
fu = iip ? deepcopy (_fu) : nothing
207
217
u = iip ? deepcopy (_u) : nothing
208
218
return LiFukushimaLineSearchCache {iip} (f, p, u, fu, alg, ls. α)
0 commit comments