Skip to content

Commit cf8dd0c

Browse files
Merge pull request #286 from avik-pal/ap/jvp
Allow specifying custom jvp
2 parents 1c416ba + ad257fd commit cf8dd0c

File tree

4 files changed

+58
-4
lines changed

4 files changed

+58
-4
lines changed

Project.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "NonlinearSolve"
22
uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
33
authors = ["SciML"]
4-
version = "2.8.0"
4+
version = "2.8.1"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -19,6 +19,7 @@ PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
1919
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
2020
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
2121
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
22+
SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"
2223
SimpleNonlinearSolve = "727e6d20-b764-4bd8-a329-72de5adea6c7"
2324
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
2425
SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
@@ -48,14 +49,15 @@ FastLevenbergMarquardt = "0.1"
4849
FiniteDiff = "2"
4950
ForwardDiff = "0.10.3"
5051
LeastSquaresOptim = "0.8"
51-
LinearAlgebra = "1.9"
5252
LineSearches = "7"
53+
LinearAlgebra = "1.9"
5354
LinearSolve = "2.12"
5455
NonlinearProblemLibrary = "0.1"
5556
PrecompileTools = "1"
5657
RecursiveArrayTools = "2"
5758
Reexport = "0.2, 1"
5859
SciMLBase = "2.4"
60+
SciMLOperators = "0.3"
5961
SimpleNonlinearSolve = "0.1.23"
6062
SparseArrays = "1.9"
6163
SparseDiffTools = "2.11"

src/NonlinearSolve.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import PrecompileTools: @recompile_invalidations, @compile_workload, @setup_work
2323
import RecursiveArrayTools: ArrayPartition,
2424
AbstractVectorOfArray, recursivecopy!, recursivefill!
2525
import SciMLBase: AbstractNonlinearAlgorithm, NLStats, _unwrap_val, has_jac, isinplace
26+
import SciMLOperators: FunctionOperator
2627
import StaticArraysCore: StaticArray, SVector, SArray, MArray
2728
import UnPack: @unpack
2829

src/jacobian.jl

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,21 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u, p, ::Val
7777
# FIXME: To properly support needsJᵀJ without Jacobian, we need to implement
7878
# a reverse diff operation with the seed being `Jx`, this is not yet implemented
7979
J = if !(linsolve_needs_jac || alg_wants_jac || needsJᵀJ)
80-
# We don't need to construct the Jacobian
81-
JacVec(uf, u; autodiff = __get_nonsparse_ad(alg.ad))
80+
if f.jvp === nothing
81+
# We don't need to construct the Jacobian
82+
JacVec(uf, u; autodiff = __get_nonsparse_ad(alg.ad))
83+
else
84+
if iip
85+
jvp = (_, u, v) -> (du = similar(fu); f.jvp(du, v, u, p); du)
86+
jvp! = (du, _, u, v) -> f.jvp(du, v, u, p)
87+
else
88+
jvp = (_, u, v) -> f.jvp(v, u, p)
89+
jvp! = (du, _, u, v) -> (du .= f.jvp(v, u, p))
90+
end
91+
op = SparseDiffTools.FwdModeAutoDiffVecProd(f, u, (), jvp, jvp!)
92+
FunctionOperator(op, u, fu; isinplace = Val(true), outofplace = Val(false),
93+
p, islinear = true)
94+
end
8295
else
8396
if has_analytic_jac
8497
f.jac_prototype === nothing ? undefmatrix(u) : f.jac_prototype

test/basictests.jl

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -973,3 +973,41 @@ end
973973
termination_condition).u .≈ sqrt(2.0))
974974
end
975975
end
976+
977+
# Miscelleneous Tests
978+
@testset "Custom JVP" begin
979+
function F(u::Vector{Float64}, p::Vector{Float64})
980+
Δ = Tridiagonal(-ones(99), 2 * ones(100), -ones(99))
981+
return u + 0.1 * u .* Δ * u - p
982+
end
983+
984+
function F!(du::Vector{Float64}, u::Vector{Float64}, p::Vector{Float64})
985+
Δ = Tridiagonal(-ones(99), 2 * ones(100), -ones(99))
986+
du .= u + 0.1 * u .* Δ * u - p
987+
return nothing
988+
end
989+
990+
function JVP(v::Vector{Float64}, u::Vector{Float64}, p::Vector{Float64})
991+
Δ = Tridiagonal(-ones(99), 2 * ones(100), -ones(99))
992+
return v + 0.1 * (u .* Δ * v + v .* Δ * u)
993+
end
994+
995+
function JVP!(du::Vector{Float64}, v::Vector{Float64}, u::Vector{Float64},
996+
p::Vector{Float64})
997+
Δ = Tridiagonal(-ones(99), 2 * ones(100), -ones(99))
998+
du .= v + 0.1 * (u .* Δ * v + v .* Δ * u)
999+
return nothing
1000+
end
1001+
1002+
u0 = rand(100)
1003+
1004+
prob = NonlinearProblem(NonlinearFunction{false}(F; jvp = JVP), u0, u0)
1005+
sol = solve(prob, NewtonRaphson(; linsolve = KrylovJL_GMRES()))
1006+
1007+
@test norm(F(sol.u, u0)) 1e-8
1008+
1009+
prob = NonlinearProblem(NonlinearFunction{true}(F!; jvp = JVP!), u0, u0)
1010+
sol = solve(prob, NewtonRaphson(; linsolve = KrylovJL_GMRES()))
1011+
1012+
@test norm(F(sol.u, u0)) 1e-8
1013+
end

0 commit comments

Comments
 (0)