Skip to content

Commit 4d2f37a

Browse files
committed
Make it a ext
1 parent 5dbd10d commit 4d2f37a

File tree

5 files changed

+34
-19
lines changed

5 files changed

+34
-19
lines changed

Project.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ PETSc = "ace2c81b-2b5f-4b1e-a30d-d662738edfe0"
4444
SIAMFANLEquations = "084e46ad-d928-497d-ad5e-07fa361a48c4"
4545
SpeedMapping = "f1835b91-879b-4a3f-a438-e4baacf14412"
4646
Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
47+
TaylorDiff = "b36ab563-344f-407b-a36a-4f200bebf99c"
4748

4849
[extensions]
4950
NonlinearSolveFastLevenbergMarquardtExt = "FastLevenbergMarquardt"
@@ -56,6 +57,7 @@ NonlinearSolvePETScExt = ["PETSc", "MPI"]
5657
NonlinearSolveSIAMFANLEquationsExt = "SIAMFANLEquations"
5758
NonlinearSolveSpeedMappingExt = "SpeedMapping"
5859
NonlinearSolveSundialsExt = "Sundials"
60+
NonlinearSolveTaylorDiffExt = "TaylorDiff"
5961

6062
[compat]
6163
ADTypes = "1.9"
@@ -113,7 +115,6 @@ StaticArrays = "1.9"
113115
StaticArraysCore = "1.4"
114116
Sundials = "4.23.1"
115117
SymbolicIndexingInterface = "0.3.31"
116-
Symbolics = "6"
117118
TaylorDiff = "0.3"
118119
Test = "1.10"
119120
Zygote = "0.6.69"
@@ -148,8 +149,9 @@ SpeedMapping = "f1835b91-879b-4a3f-a438-e4baacf14412"
148149
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
149150
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
150151
Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
152+
TaylorDiff = "b36ab563-344f-407b-a36a-4f200bebf99c"
151153
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
152154
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
153155

154156
[targets]
155-
test = ["Aqua", "BandedMatrices", "BenchmarkTools", "CUDA", "Enzyme", "ExplicitImports", "FastLevenbergMarquardt", "FixedPointAcceleration", "Hwloc", "InteractiveUtils", "LeastSquaresOptim", "LineSearches", "MINPACK", "NLSolvers", "NLsolve", "NaNMath", "NonlinearProblemLibrary", "OrdinaryDiffEqTsit5", "PETSc", "Pkg", "Random", "ReTestItems", "SIAMFANLEquations", "SparseConnectivityTracer", "SpeedMapping", "StableRNGs", "StaticArrays", "Sundials", "Test", "Zygote"]
157+
test = ["Aqua", "BandedMatrices", "BenchmarkTools", "CUDA", "Enzyme", "ExplicitImports", "FastLevenbergMarquardt", "FixedPointAcceleration", "Hwloc", "InteractiveUtils", "LeastSquaresOptim", "LineSearches", "MINPACK", "NLSolvers", "NLsolve", "NaNMath", "NonlinearProblemLibrary", "OrdinaryDiffEqTsit5", "PETSc", "Pkg", "Random", "ReTestItems", "SIAMFANLEquations", "SparseConnectivityTracer", "SpeedMapping", "StableRNGs", "StaticArrays", "Sundials", "TaylorDiff", "Test", "Zygote"]

ext/NonlinearSolveTaylorDiffExt.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
module NonlinearSolveTaylorDiffExt
2+
using NonlinearSolve: HalleyDescentCache, NonlinearFunction
3+
import NonlinearSolve: evaluate_hvvp
4+
using TaylorDiff: derivative, derivative!
5+
using FastClosures: @closure
6+
7+
function evaluate_hvvp(
8+
hvvp, cache::HalleyDescentCache, f::NonlinearFunction{iip}, p, u, δu) where {iip}
9+
if iip
10+
binary_f = @closure (y, x) -> f(y, x, p)
11+
derivative!(hvvp, binary_f, cache.fu, u, δu, Val(2))
12+
else
13+
unary_f = Base.Fix2(f, p)
14+
hvvp = derivative(unary_f, u, δu, Val(2))
15+
end
16+
hvvp
17+
end
18+
19+
end

lib/NonlinearSolveBase/src/descent/halley.jl

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,15 @@ Improve the NewtonDescent with higher-order terms. First compute the descent dir
55
Then compute the hessian-vector-vector product and solve for the second-order correction term as ``J b = H a a``.
66
Finally, compute the descent direction as ``δu = a * a / (b / 2 - a)``.
77
8+
Note that `import TaylorDiff` is required to use this descent algorithm.
9+
810
See also [`NewtonDescent`](@ref).
911
"""
1012
@kwdef @concrete struct HalleyDescent <: AbstractDescentAlgorithm
1113
linsolve = nothing
1214
precs = DEFAULT_PRECS
1315
end
1416

15-
using TaylorDiff: derivative
16-
1717
function Base.show(io::IO, d::HalleyDescent)
1818
modifiers = String[]
1919
d.linsolve !== nothing && push!(modifiers, "linsolve = $(d.linsolve)")
@@ -30,6 +30,7 @@ supports_line_search(::HalleyDescent) = true
3030
δus
3131
b
3232
fu
33+
hvvp
3334
lincache
3435
timer
3536
end
@@ -43,13 +44,14 @@ function __internal_init(prob::NonlinearProblem, alg::HalleyDescent, J, fu, u; s
4344
@bb δu = similar(u)
4445
@bb b = similar(u)
4546
@bb fu = similar(fu)
47+
@bb hvvp = similar(fu)
4648
δus = N 1 ? nothing : map(2:N) do i
4749
@bb δu_ = similar(u)
4850
end
49-
INV && return HalleyDescentCache{true}(prob.f, prob.p, δu, δus, b, nothing, timer)
50-
lincache = LinearSolverCache(
51+
lincache = INV ? nothing :
52+
LinearSolverCache(
5153
alg, alg.linsolve, J, _vec(fu), _vec(u); stats, abstol, reltol, linsolve_kwargs...)
52-
return HalleyDescentCache{false}(prob.f, prob.p, δu, δus, b, fu, lincache, timer)
54+
return HalleyDescentCache{false}(prob.f, prob.p, δu, δus, b, fu, hvvp, lincache, timer)
5355
end
5456

5557
function __internal_solve!(cache::HalleyDescentCache{INV}, J, fu, u, idx::Val = Val(1);
@@ -73,7 +75,7 @@ function __internal_solve!(cache::HalleyDescentCache{INV}, J, fu, u, idx::Val =
7375
end
7476
b = cache.b
7577
# compute the hessian-vector-vector product
76-
hvvp = evaluate_hvvp(cache, cache.f, cache.p, u, δu)
78+
hvvp = evaluate_hvvp(cache.hvvp, cache, cache.f, cache.p, u, δu)
7779
# second linear solve, reuse factorization if possible
7880
if INV
7981
@bb b = J × vec(hvvp)
@@ -94,13 +96,4 @@ function __internal_solve!(cache::HalleyDescentCache{INV}, J, fu, u, idx::Val =
9496
return DescentResult(; δu)
9597
end
9698

97-
function evaluate_hvvp(
98-
cache::HalleyDescentCache, f::NonlinearFunction{iip}, p, u, δu) where {iip}
99-
if iip
100-
binary_f = @closure (y, x) -> f(y, x, p)
101-
derivative(binary_f, cache.fu, u, δu, Val{3}())
102-
else
103-
unary_f = Base.Fix2(f, p)
104-
derivative(unary_f, u, δu, Val{3}())
105-
end
106-
end
99+
evaluate_hvvp(hvvp, cache, f, p, u, δu) = error("not implemented. please import TaylorDiff")

lib/NonlinearSolveFirstOrder/src/halley.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""
2-
Halley(; concrete_jac = nothing, linsolve = nothing, linesearch = NoLineSearch(),
2+
Halley(; concrete_jac = nothing, linsolve = nothing, linesearch = nothing,
33
precs = DEFAULT_PRECS, autodiff = nothing)
44
55
An experimental Halley's method implementation. Improves the convergence rate of Newton's method by using second-order derivative information to correct the descent direction.

test/23_test_problems_tests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
@testsetup module RobustnessTesting
22
using NonlinearSolve, LinearAlgebra, LinearSolve, NonlinearProblemLibrary, Test
3+
import TaylorDiff
34

45
problems = NonlinearProblemLibrary.problems
56
dicts = NonlinearProblemLibrary.dicts

0 commit comments

Comments
 (0)