Skip to content

Commit bcfcc16

Browse files
committed
Avoid Runtime Checks for Zygote Being loaded
1 parent 6c52956 commit bcfcc16

10 files changed

+33
-18
lines changed

.JuliaFormatter.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
style = "sciml"
22
format_markdown = true
33
annotate_untyped_fields_with_any = false
4+
format_docstrings = true

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,13 @@ UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
3030
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
3131
FastLevenbergMarquardt = "7a0df574-e128-4d35-8cbd-3d84502bf7ce"
3232
LeastSquaresOptim = "0fc2ff8b-aaa3-5acd-a817-1944a5e08891"
33+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
3334

3435
[extensions]
3536
NonlinearSolveBandedMatricesExt = "BandedMatrices"
3637
NonlinearSolveFastLevenbergMarquardtExt = "FastLevenbergMarquardt"
3738
NonlinearSolveLeastSquaresOptimExt = "LeastSquaresOptim"
39+
NonlinearSolveZygoteExt = "Zygote"
3840

3941
[compat]
4042
ADTypes = "0.2"

ext/NonlinearSolveZygoteExt.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
module NonlinearSolveZygoteExt
2+
3+
import NonlinearSolve, Zygote
4+
5+
NonlinearSolve.is_extension_loaded(::Val{:Zygote}) = true
6+
7+
end

src/NonlinearSolve.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@ import DiffEqBase: AbstractNonlinearTerminationMode,
3838
const AbstractSparseADType = Union{ADTypes.AbstractSparseFiniteDifferences,
3939
ADTypes.AbstractSparseForwardMode, ADTypes.AbstractSparseReverseMode}
4040

41+
# Type-Inference Friendly Check for Extension Loading
42+
is_extension_loaded(::Val) = false
43+
4144
abstract type AbstractNonlinearSolveLineSearchAlgorithm end
4245

4346
abstract type AbstractNonlinearSolveAlgorithm <: AbstractNonlinearAlgorithm end

src/extension_algs.jl

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,15 @@ for solving `NonlinearLeastSquaresProblem`.
88
99
## Arguments:
1010
11-
- `alg`: Algorithm to use. Can be `:lm` or `:dogleg`.
12-
- `linsolve`: Linear solver to use. Can be `:qr`, `:cholesky` or `:lsmr`. If
13-
`nothing`, then `LeastSquaresOptim.jl` will choose the best linear solver based
14-
on the Jacobian structure.
15-
- `autodiff`: Automatic differentiation / Finite Differences. Can be `:central` or `:forward`.
11+
- `alg`: Algorithm to use. Can be `:lm` or `:dogleg`.
12+
- `linsolve`: Linear solver to use. Can be `:qr`, `:cholesky` or `:lsmr`. If `nothing`,
13+
then `LeastSquaresOptim.jl` will choose the best linear solver based on the Jacobian
14+
structure.
15+
- `autodiff`: Automatic differentiation / Finite Differences. Can be `:central` or
16+
`:forward`.
1617
1718
!!! note
19+
1820
This algorithm is only available if `LeastSquaresOptim.jl` is installed.
1921
"""
2022
struct LeastSquaresOptimJL{alg, linsolve} <: AbstractNonlinearSolveAlgorithm
@@ -36,21 +38,24 @@ end
3638
"""
3739
FastLevenbergMarquardtJL(linsolve = :cholesky)
3840
39-
Wrapper over [FastLevenbergMarquardt.jl](https://github.com/kamesy/FastLevenbergMarquardt.jl) for solving
40-
`NonlinearLeastSquaresProblem`.
41+
Wrapper over [FastLevenbergMarquardt.jl](https://github.com/kamesy/FastLevenbergMarquardt.jl)
42+
for solving `NonlinearLeastSquaresProblem`.
4143
4244
!!! warning
45+
4346
This is not really the fastest solver. It is called that since the original package
4447
is called "Fast". `LevenbergMarquardt()` is almost always a better choice.
4548
4649
!!! warning
50+
4751
This algorithm requires the jacobian function to be provided!
4852
4953
## Arguments:
5054
51-
- `linsolve`: Linear solver to use. Can be `:qr` or `:cholesky`.
55+
- `linsolve`: Linear solver to use. Can be `:qr` or `:cholesky`.
5256
5357
!!! note
58+
5459
This algorithm is only available if `FastLevenbergMarquardt.jl` is installed.
5560
"""
5661
@concrete struct FastLevenbergMarquardtJL{linsolve} <: AbstractNonlinearSolveAlgorithm

src/jacobian.jl

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -166,12 +166,8 @@ function __concrete_vjp_autodiff(vjp_autodiff, jvp_autodiff, uf)
166166
# Short circuit if we see that FiniteDiff was used for J computation
167167
jvp_autodiff isa AutoFiniteDiff && return jvp_autodiff
168168
# Check if Zygote is loaded then use Zygote else use FiniteDiff
169-
if haskey(Base.loaded_modules,
170-
Base.PkgId(Base.UUID("e88e6eb3-aa80-5325-afca-941959d7151f"), "Zygote"))
171-
return AutoZygote()
172-
else
173-
return AutoFiniteDiff()
174-
end
169+
is_extension_loaded(Val{:Zygote}()) && return AutoZygote()
170+
return AutoFiniteDiff()
175171
end
176172
else
177173
ad = __get_nonsparse_ad(vjp_autodiff)

src/linesearch.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,8 +114,7 @@ function LineSearchesJLCache(ls::LineSearch, f::F, u, p, fu1, IIP::Val{iip}) whe
114114
g₀ = _mutable_zero(u)
115115

116116
autodiff = if ls.autodiff === nothing
117-
if !iip && haskey(Base.loaded_modules,
118-
Base.PkgId(Base.UUID("e88e6eb3-aa80-5325-afca-941959d7151f"), "Zygote"))
117+
if !iip && is_extension_loaded(Val{:Zygote}())
119118
AutoZygote()
120119
else
121120
AutoFiniteDiff()

src/pseudotransient.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ the time-stepping and algorithm, please see the paper:
1212
SIAM Journal on Scientific Computing,25, 553-569.](https://doi.org/10.1137/S106482750241044X)
1313
1414
### Keyword Arguments
15+
1516
- `autodiff`: determines the backend used for the Jacobian. Note that this argument is
1617
ignored if an analytical Jacobian is passed, as that will be used instead. Defaults to
1718
`nothing` which means that a default is selected according to the problem specification!

src/utils.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ Construct the AD type from the arguments. This is mostly needed for compatibilit
2323
code.
2424
2525
!!! warning
26+
2627
`chunk_size`, `standardtag`, `diff_type`, and `autodiff::Union{Val, Bool}` are
2728
deprecated and will be removed in v3. Update your code to directly specify
2829
`autodiff=<ADTypes>`.

test/basictests.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1015,10 +1015,10 @@ end
10151015
prob = NonlinearProblem(NonlinearFunction{false}(F; jvp = JVP), u0, u0)
10161016
sol = solve(prob, NewtonRaphson(; linsolve = KrylovJL_GMRES()); abstol = 1e-13)
10171017

1018-
@test norm(F(sol.u, u0)) 1e-8
1018+
@test norm(F(sol.u, u0)) 1e-6
10191019

10201020
prob = NonlinearProblem(NonlinearFunction{true}(F!; jvp = JVP!), u0, u0)
10211021
sol = solve(prob, NewtonRaphson(; linsolve = KrylovJL_GMRES()); abstol = 1e-13)
10221022

1023-
@test norm(F(sol.u, u0)) 1e-8
1023+
@test norm(F(sol.u, u0)) 1e-6
10241024
end

0 commit comments

Comments
 (0)