Skip to content

Commit 31c9569

Browse files
committed
feat: integrate SciMLJacobianOperators into NonlinearSolve
1 parent 09aa6e8 commit 31c9569

File tree

6 files changed

+164
-308
lines changed

6 files changed

+164
-308
lines changed

docs/src/devdocs/operators.md

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,21 +6,6 @@
66
NonlinearSolve.AbstractNonlinearSolveOperator
77
```
88

9-
## Jacobian Operators
10-
11-
```@docs
12-
NonlinearSolve.JacobianOperator
13-
NonlinearSolve.VecJacOperator
14-
NonlinearSolve.JacVecOperator
15-
```
16-
17-
### Stateful Jacobian Operators
18-
19-
```@docs
20-
NonlinearSolve.StatefulJacobianOperator
21-
NonlinearSolve.StatefulJacobianNormalFormOperator
22-
```
23-
249
## Low-Rank Jacobian Operators
2510

2611
```@docs

lib/SciMLJacobianOperators/src/SciMLJacobianOperators.jl

Lines changed: 152 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ using ConcreteStructs: @concrete
55
using ConstructionBase: ConstructionBase
66
using DifferentiationInterface: DifferentiationInterface
77
using FastClosures: @closure
8+
using LinearAlgebra: LinearAlgebra
89
using SciMLBase: SciMLBase, AbstractNonlinearProblem, AbstractNonlinearFunction
910
using SciMLOperators: AbstractSciMLOperator
1011
using Setfield: @set!
@@ -23,6 +24,57 @@ struct JVP <: AbstractMode end
2324
flip_mode(::VJP) = JVP()
2425
flip_mode(::JVP) = VJP()
2526

27+
"""
28+
JacobianOperator{iip, T} <: AbstractJacobianOperator{T} <: AbstractSciMLOperator{T}
29+
30+
A Jacobian operator that provides both JVP and VJP without materializing the Jacobian (if possible).
31+
32+
### Constructor
33+
34+
```julia
35+
JacobianOperator(prob::AbstractNonlinearProblem, fu, u; jvp_autodiff = nothing,
36+
vjp_autodiff = nothing, skip_vjp::Val = Val(false), skip_jvp::Val = Val(false))
37+
```
38+
39+
By default, the `JacobianOperator` will compute `JVP`. Use `Base.adjoint` or
40+
`Base.transpose` to switch to `VJP`.
41+
42+
### Computing the VJP
43+
44+
Computing the VJP is done according to the following rules:
45+
46+
- If `f` has a `vjp` method, then we use that.
47+
- If `f` has a `jac` method and no `vjp_autodiff` is provided, then we use `jacᵀ * v`.
48+
- If `vjp_autodiff` is provided, we use DifferentiationInterface.jl to compute the VJP.
49+
50+
### Computing the JVP
51+
52+
Computing the JVP is done according to the following rules:
53+
54+
- If `f` has a `jvp` method, then we use that.
55+
- If `f` has a `jac` method and no `jvp_autodiff` is provided, then we use `jac * v`.
56+
- If `jvp_autodiff` is provided, we use DifferentiationInterface.jl to compute the JVP.
57+
58+
### Special Case (Number)
59+
60+
For Number inputs, VJP and JVP are not distinct. Hence, if either `vjp` or `jvp` is
61+
provided, then we use that. If neither is provided, then we use `v * jac` if `jac` is
62+
provided. Finally, we use the respective autodiff methods to compute the derivative
63+
using DifferentiationInterface.jl and multiply by `v`.
64+
65+
### Methods Provided
66+
67+
!!! warning
68+
69+
Currently it is expected that `p` during problem construction is the same as `p` during
70+
operator evaluation. This restriction will be lifted in the future.
71+
72+
- `(op::JacobianOperator)(v, u, p)`: Computes `∂f(u, p)/∂u * v` or `∂f(u, p)/∂uᵀ * v`.
73+
- `(op::JacobianOperator)(res, v, u, p)`: Computes `∂f(u, p)/∂u * v` or `∂f(u, p)/∂uᵀ * v`
74+
and stores the result in `res`.
75+
76+
See also [`VecJacOperator`](@ref) and [`JacVecOperator`](@ref).
77+
"""
2678
@concrete struct JacobianOperator{iip, T <: Real} <: AbstractJacobianOperator{T}
2779
mode <: AbstractMode
2880

@@ -65,8 +117,8 @@ function JacobianOperator(prob::AbstractNonlinearProblem, fu, u; jvp_autodiff =
65117
vjp_op = prepare_vjp(skip_vjp, prob, f, u, fu; autodiff = vjp_autodiff)
66118
jvp_op = prepare_jvp(skip_jvp, prob, f, u, fu; autodiff = jvp_autodiff)
67119

68-
output_cache = iip ? similar(fu, T) : nothing
69-
input_cache = iip ? similar(u, T) : nothing
120+
output_cache = similar(fu, T)
121+
input_cache = similar(u, T)
70122

71123
return JacobianOperator{iip, T}(
72124
JVP(), jvp_op, vjp_op, (length(fu), length(u)), output_cache, input_cache)
@@ -112,14 +164,106 @@ function (op::JacobianOperator)(Jv, v, u, p)
112164
end
113165
end
114166

167+
"""
168+
VecJacOperator(args...; autodiff = nothing, kwargs...)
169+
170+
Constructs a [`JacobianOperator`](@ref) which only provides the VJP, using
171+
`vjp_autodiff = autodiff`.
172+
"""
115173
function VecJacOperator(args...; autodiff = nothing, kwargs...)
116174
return JacobianOperator(args...; kwargs..., skip_jvp = True, vjp_autodiff = autodiff)'
117175
end
118176

177+
"""
178+
JacVecOperator(args...; autodiff = nothing, kwargs...)
179+
180+
Constructs a [`JacobianOperator`](@ref) which only provides the JVP, using
181+
`jvp_autodiff = autodiff`.
182+
"""
119183
function JacVecOperator(args...; autodiff = nothing, kwargs...)
120184
return JacobianOperator(args...; kwargs..., skip_vjp = True, jvp_autodiff = autodiff)
121185
end
122186

187+
"""
188+
StatefulJacobianOperator(jac_op::JacobianOperator, u, p)
189+
190+
Wrapper over a [`JacobianOperator`](@ref) which stores the input `u` and `p` and defines
191+
`mul!` and `*` for computing VJPs and JVPs.
192+
"""
193+
@concrete struct StatefulJacobianOperator{M <: AbstractMode, T} <:
194+
AbstractJacobianOperator{T}
195+
mode::M
196+
jac_op <: JacobianOperator
197+
u
198+
p
199+
200+
function StatefulJacobianOperator(jac_op::JacobianOperator, u, p)
201+
return new{
202+
typeof(jac_op.mode), eltype(jac_op), typeof(jac_op), typeof(u), typeof(p)}(
203+
jac_op.mode, jac_op, u, p)
204+
end
205+
end
206+
207+
Base.size(J::StatefulJacobianOperator) = size(J.jac_op)
208+
Base.size(J::StatefulJacobianOperator, d::Integer) = size(J.jac_op, d)
209+
210+
for op in (:adjoint, :transpose)
211+
@eval function Base.$(op)(operator::StatefulJacobianOperator)
212+
return StatefulJacobianOperator($(op)(operator.jac_op), operator.u, operator.p)
213+
end
214+
end
215+
216+
Base.:*(J::StatefulJacobianOperator, v::AbstractArray) = J.jac_op(v, J.u, J.p)
217+
218+
function LinearAlgebra.mul!(
219+
Jv::AbstractArray, J::StatefulJacobianOperator, v::AbstractArray)
220+
J.jac_op(Jv, v, J.u, J.p)
221+
return Jv
222+
end
223+
224+
"""
225+
StatefulJacobianNormalFormOperator(vjp_operator, jvp_operator, cache)
226+
227+
This constructs a Normal Form Jacobian Operator, i.e. it constructs the operator
228+
corresponding to `JᵀJ` where `J` is the Jacobian Operator. This is not meant to be directly
229+
constructed, rather it is constructed with `*` on two [`StatefulJacobianOperator`](@ref)s.
230+
"""
231+
@concrete mutable struct StatefulJacobianNormalFormOperator{T} <:
232+
AbstractJacobianOperator{T}
233+
vjp_operator <: StatefulJacobianOperator{VJP}
234+
jvp_operator <: StatefulJacobianOperator{JVP}
235+
cache
236+
end
237+
238+
function Base.size(J::StatefulJacobianNormalFormOperator)
239+
return size(J.vjp_operator, 1), size(J.jvp_operator, 2)
240+
end
241+
242+
function Base.:*(J1::StatefulJacobianOperator{VJP}, J2::StatefulJacobianOperator{JVP})
243+
cache = J2 * J2.jac_op.input_cache
244+
T = promote_type(eltype(J1), eltype(J2))
245+
return StatefulJacobianNormalFormOperator{T}(J1, J2, cache)
246+
end
247+
248+
function LinearAlgebra.mul!(C::StatefulJacobianNormalFormOperator,
249+
A::StatefulJacobianOperator{VJP}, B::StatefulJacobianOperator{JVP})
250+
C.vjp_operator = A
251+
C.jvp_operator = B
252+
return C
253+
end
254+
255+
function Base.:*(JᵀJ::StatefulJacobianNormalFormOperator, x::AbstractArray)
256+
return JᵀJ.vjp_operator * (JᵀJ.jvp_operator * x)
257+
end
258+
259+
# In-place application of the normal-form operator: applies `JᵀJ.jvp_operator` to `x`,
# storing the intermediate in `JᵀJ.cache`, then applies `JᵀJ.vjp_operator` to that
# intermediate, writing the final product into `JᵀJx`.
#
# Arguments:
#   - `JᵀJx`: destination array, overwritten with the result.
#   - `JᵀJ`:  the normal-form operator; its `cache` field is overwritten as scratch space.
#   - `x`:    the array the operator is applied to.
#
# Returns `JᵀJx`.
function LinearAlgebra.mul!(
        JᵀJx::AbstractArray, JᵀJ::StatefulJacobianNormalFormOperator, x::AbstractArray)
    # `mul!` has to be module-qualified: this file only brings the `LinearAlgebra` module
    # name into scope (`using LinearAlgebra: LinearAlgebra`), so a bare `mul!` is undefined.
    LinearAlgebra.mul!(JᵀJ.cache, JᵀJ.jvp_operator, x)
    LinearAlgebra.mul!(JᵀJx, JᵀJ.vjp_operator, JᵀJ.cache)
    return JᵀJx
end
265+
266+
# Helper Functions
123267
prepare_vjp(::Val{true}, args...; kwargs...) = nothing
124268

125269
function prepare_vjp(::Val{false}, prob::AbstractNonlinearProblem,
@@ -163,8 +307,8 @@ function prepare_vjp(::Val{false}, prob::AbstractNonlinearProblem,
163307
DI.pullback!(fₚ, fu_cache, reshape(vJ, size(u)), autodiff, u, v, di_extras)
164308
end
165309
else
166-
di_extras = DI.prepare_pullback(f, autodiff, u, fu)
167-
return @closure (v, u, p) -> DI.pullback(f, autodiff, u, v, di_extras)
310+
di_extras = DI.prepare_pullback(fₚ, autodiff, u, fu)
311+
return @closure (v, u, p) -> DI.pullback(fₚ, autodiff, u, v, di_extras)
168312
end
169313
end
170314

@@ -206,8 +350,8 @@ function prepare_jvp(::Val{false}, prob::AbstractNonlinearProblem,
206350
fu_cache = copy(fu)
207351
di_extras = DI.prepare_pushforward(fₚ, fu_cache, autodiff, u, u)
208352
return @closure (Jv, v, u, p) -> begin
209-
DI.pushforward!(fₚ, fu_cache, reshape(Jv, size(fu_cache)), autodiff, u, v,
210-
di_extras)
353+
DI.pushforward!(
354+
fₚ, fu_cache, reshape(Jv, size(fu_cache)), autodiff, u, v, di_extras)
211355
return
212356
end
213357
else
@@ -231,5 +375,7 @@ function prepare_scalar_op(::Val{false}, prob::AbstractNonlinearProblem,
231375
end
232376

233377
export JacobianOperator, VecJacOperator, JacVecOperator
378+
export StatefulJacobianOperator
379+
export StatefulJacobianNormalFormOperator
234380

235381
end

src/NonlinearSolve.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ using Preferences: Preferences, @load_preference, @set_preferences!
4040
using RecursiveArrayTools: recursivecopy!, recursivefill!
4141
using SciMLBase: AbstractNonlinearAlgorithm, JacobianWrapper, AbstractNonlinearProblem,
4242
AbstractSciMLOperator, _unwrap_val, has_jac, isinplace, NLStats
43+
using SciMLJacobianOperators: JacobianOperator, VecJacOperator, JacVecOperator,
44+
StatefulJacobianOperator, StatefulJacobianNormalFormOperator
4345
using SparseArrays: AbstractSparseMatrix, SparseMatrixCSC
4446
using SparseDiffTools: SparseDiffTools, AbstractSparsityDetection,
4547
ApproximateJacobianSparsity, JacPrototypeSparsityDetection,
@@ -72,7 +74,6 @@ include("descent/dogleg.jl")
7274
include("descent/damped_newton.jl")
7375
include("descent/geodesic_acceleration.jl")
7476

75-
include("internal/operators.jl")
7677
include("internal/jacobian.jl")
7778
include("internal/forward_diff.jl")
7879
include("internal/linear_solve.jl")

src/globalization/trust_region.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -386,11 +386,13 @@ function __internal_init(
386386
p1, p2, p3, p4 = __get_parameters(T, alg.method)
387387
ϵ = T(1e-8)
388388

389+
reverse_ad = get_concrete_reverse_ad(alg.reverse_ad, prob; check_reverse_mode = true)
389390
vjp_operator = alg.method isa RUS.__Yuan || alg.method isa RUS.__Bastin ?
390-
VecJacOperator(prob, fu, u; autodiff = alg.reverse_ad) : nothing
391+
VecJacOperator(prob, fu, u; autodiff = reverse_ad) : nothing
391392

393+
forward_ad = get_concrete_forward_ad(alg.forward_ad, prob; check_forward_mode = true)
392394
jvp_operator = alg.method isa RUS.__Bastin ?
393-
JacVecOperator(prob, fu, u; autodiff = alg.forward_ad) : nothing
395+
JacVecOperator(prob, fu, u; autodiff = forward_ad) : nothing
394396

395397
if alg.method isa RUS.__Yuan
396398
Jᵀfu_cache = StatefulJacobianOperator(vjp_operator, u, prob.p) * _vec(fu)

src/internal/jacobian.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@ Construct a cache for the Jacobian of `f` w.r.t. `u`.
2525
- `jvp_autodiff`: Automatic Differentiation or Finite Differencing backend for computing
2626
the Jacobian-vector product.
2727
- `linsolve`: Linear Solver Algorithm used to determine if we need a concrete jacobian
28-
or if possible we can just use a [`NonlinearSolve.JacobianOperator`](@ref) instead.
28+
or if possible we can just use a [`SciMLJacobianOperators.JacobianOperator`](@ref)
29+
instead.
2930
"""
3031
@concrete mutable struct JacobianCache{iip} <: AbstractNonlinearSolveJacobianCache{iip}
3132
J
@@ -85,8 +86,7 @@ function JacobianCache(prob, alg, f::F, fu_, u, p; stats, autodiff = nothing,
8586
__similar(fu, promote_type(eltype(fu), eltype(u)), length(fu), length(u)) :
8687
copy(f.jac_prototype)
8788
elseif f.jac_prototype === nothing
88-
zero(init_jacobian(
89-
jac_cache; preserve_immutable = Val(true)))
89+
zero(init_jacobian(jac_cache; preserve_immutable = Val(true)))
9090
else
9191
f.jac_prototype
9292
end
@@ -114,9 +114,9 @@ end
114114

115115
@inline (cache::JacobianCache)(u = cache.u) = cache(cache.J, u, cache.p)
116116
@inline function (cache::JacobianCache)(::Nothing)
117-
J = cache.J
118-
J isa JacobianOperator && return StatefulJacobianOperator(J, cache.u, cache.p)
119-
return J
117+
cache.J isa JacobianOperator &&
118+
return StatefulJacobianOperator(cache.J, cache.u, cache.p)
119+
return cache.J
120120
end
121121

122122
function (cache::JacobianCache)(J::JacobianOperator, u, p = cache.p)

0 commit comments

Comments
 (0)