@@ -5,6 +5,7 @@ using ConcreteStructs: @concrete
5
5
using ConstructionBase: ConstructionBase
6
6
using DifferentiationInterface: DifferentiationInterface
7
7
using FastClosures: @closure
8
+ using LinearAlgebra: LinearAlgebra
8
9
using SciMLBase: SciMLBase, AbstractNonlinearProblem, AbstractNonlinearFunction
9
10
using SciMLOperators: AbstractSciMLOperator
10
11
using Setfield: @set!
@@ -23,6 +24,57 @@ struct JVP <: AbstractMode end
23
24
flip_mode (:: VJP ) = JVP ()
24
25
flip_mode (:: JVP ) = VJP ()
25
26
27
+ """
28
+ JacobianOperator{iip, T} <: AbstractJacobianOperator{T} <: AbstractSciMLOperator{T}
29
+
30
+ A Jacobian Operator Provides both JVP and VJP without materializing either (if possible).
31
+
32
+ ### Constructor
33
+
34
+ ```julia
35
+ JacobianOperator(prob::AbstractNonlinearProblem, fu, u; jvp_autodiff = nothing,
36
+ vjp_autodiff = nothing, skip_vjp::Val = Val(false), skip_jvp::Val = Val(false))
37
+ ```
38
+
39
+ By default, the `JacobianOperator` will compute `JVP`. Use `Base.adjoint` or
40
+ `Base.transpose` to switch to `VJP`.
41
+
42
+ ### Computing the VJP
43
+
44
+ Computing the VJP is done according to the following rules:
45
+
46
+ - If `f` has a `vjp` method, then we use that.
47
+ - If `f` has a `jac` method and no `vjp_autodiff` is provided, then we use `jac * v`.
48
+ - If `vjp_autodiff` is provided we using DifferentiationInterface.jl to compute the VJP.
49
+
50
+ ### Computing the JVP
51
+
52
+ Computing the JVP is done according to the following rules:
53
+
54
+ - If `f` has a `jvp` method, then we use that.
55
+ - If `f` has a `jac` method and no `jvp_autodiff` is provided, then we use `v * jac`.
56
+ - If `jvp_autodiff` is provided we using DifferentiationInterface.jl to compute the JVP.
57
+
58
+ ### Special Case (Number)
59
+
60
+ For Number inputs, VJP and JVP are not distinct. Hence, if either `vjp` or `jvp` is
61
+ provided, then we use that. If neither is provided, then we use `v * jac` if `jac` is
62
+ provided. Finally, we use the respective autodiff methods to compute the derivative
63
+ using DifferentiationInterface.jl and multiply by `v`.
64
+
65
+ ### Methods Provided
66
+
67
+ !!! warning
68
+
69
+ Currently it is expected that `p` during problem construction is same as `p` during
70
+ operator evaluation. This restriction will be lifted in the future.
71
+
72
+ - `(op::JacobianOperator)(v, u, p)`: Computes `∂f(u, p)/∂u * v` or `∂f(u, p)/∂uᵀ * v`.
73
+ - `(op::JacobianOperator)(res, v, u, p)`: Computes `∂f(u, p)/∂u * v` or `∂f(u, p)/∂uᵀ * v`
74
+ and stores the result in `res`.
75
+
76
+ See also [`VecJacOperator`](@ref) and [`JacVecOperator`](@ref).
77
+ """
26
78
@concrete struct JacobianOperator{iip, T <: Real } <: AbstractJacobianOperator{T}
27
79
mode <: AbstractMode
28
80
@@ -65,8 +117,8 @@ function JacobianOperator(prob::AbstractNonlinearProblem, fu, u; jvp_autodiff =
65
117
vjp_op = prepare_vjp (skip_vjp, prob, f, u, fu; autodiff = vjp_autodiff)
66
118
jvp_op = prepare_jvp (skip_jvp, prob, f, u, fu; autodiff = jvp_autodiff)
67
119
68
- output_cache = iip ? similar (fu, T) : nothing
69
- input_cache = iip ? similar (u, T) : nothing
120
+ output_cache = similar (fu, T)
121
+ input_cache = similar (u, T)
70
122
71
123
return JacobianOperator {iip, T} (
72
124
JVP (), jvp_op, vjp_op, (length (fu), length (u)), output_cache, input_cache)
@@ -112,14 +164,106 @@ function (op::JacobianOperator)(Jv, v, u, p)
112
164
end
113
165
end
114
166
167
+ """
168
+ VecJacOperator(args...; autodiff = nothing, kwargs...)
169
+
170
+ Constructs a [`JacobianOperator`](@ref) which only provides the VJP using the
171
+ `vjp_autodiff = autodiff`.
172
+ """
115
173
function VecJacOperator (args... ; autodiff = nothing , kwargs... )
116
174
return JacobianOperator (args... ; kwargs... , skip_jvp = True, vjp_autodiff = autodiff)'
117
175
end
118
176
177
+ """
178
+ JacVecOperator(args...; autodiff = nothing, kwargs...)
179
+
180
+ Constructs a [`JacobianOperator`](@ref) which only provides the JVP using the
181
+ `jvp_autodiff = autodiff`.
182
+ """
119
183
function JacVecOperator (args... ; autodiff = nothing , kwargs... )
120
184
return JacobianOperator (args... ; kwargs... , skip_vjp = True, jvp_autodiff = autodiff)
121
185
end
122
186
187
+ """
188
+ StatefulJacobianOperator(jac_op::JacobianOperator, u, p)
189
+
190
+ Wrapper over a [`JacobianOperator`](@ref) which stores the input `u` and `p` and defines
191
+ `mul!` and `*` for computing VJPs and JVPs.
192
+ """
193
+ @concrete struct StatefulJacobianOperator{M <: AbstractMode , T} < :
194
+ AbstractJacobianOperator{T}
195
+ mode:: M
196
+ jac_op <: JacobianOperator
197
+ u
198
+ p
199
+
200
+ function StatefulJacobianOperator (jac_op:: JacobianOperator , u, p)
201
+ return new{
202
+ typeof (jac_op. mode), eltype (jac_op), typeof (jac_op), typeof (u), typeof (p)}(
203
+ jac_op. mode, jac_op, u, p)
204
+ end
205
+ end
206
+
207
+ Base. size (J:: StatefulJacobianOperator ) = size (J. jac_op)
208
+ Base. size (J:: StatefulJacobianOperator , d:: Integer ) = size (J. jac_op, d)
209
+
210
+ for op in (:adjoint , :transpose )
211
+ @eval function Base. $ (op)(operator:: StatefulJacobianOperator )
212
+ return StatefulJacobianOperator ($ (op)(operator. jac_op), operator. u, operator. p)
213
+ end
214
+ end
215
+
216
+ Base.:* (J:: StatefulJacobianOperator , v:: AbstractArray ) = J. jac_op (v, J. u, J. p)
217
+
218
+ function LinearAlgebra. mul! (
219
+ Jv:: AbstractArray , J:: StatefulJacobianOperator , v:: AbstractArray )
220
+ J. jac_op (Jv, v, J. u, J. p)
221
+ return Jv
222
+ end
223
+
224
+ """
225
+ StatefulJacobianNormalFormOperator(vjp_operator, jvp_operator, cache)
226
+
227
+ This constructs a Normal Form Jacobian Operator, i.e. it constructs the operator
228
+ corresponding to `JᵀJ` where `J` is the Jacobian Operator. This is not meant to be directly
229
+ constructed, rather it is constructed with `*` on two [`StatefulJacobianOperator`](@ref)s.
230
+ """
231
+ @concrete mutable struct StatefulJacobianNormalFormOperator{T} < :
232
+ AbstractJacobianOperator{T}
233
+ vjp_operator <: StatefulJacobianOperator{VJP}
234
+ jvp_operator <: StatefulJacobianOperator{JVP}
235
+ cache
236
+ end
237
+
238
+ function Base. size (J:: StatefulJacobianNormalFormOperator )
239
+ return size (J. vjp_operator, 1 ), size (J. jvp_operator, 2 )
240
+ end
241
+
242
+ function Base.:* (J1:: StatefulJacobianOperator{VJP} , J2:: StatefulJacobianOperator{JVP} )
243
+ cache = J2 * J2. jac_op. input_cache
244
+ T = promote_type (eltype (J1), eltype (J2))
245
+ return StatefulJacobianNormalFormOperator {T} (J1, J2, cache)
246
+ end
247
+
248
+ function LinearAlgebra. mul! (C:: StatefulJacobianNormalFormOperator ,
249
+ A:: StatefulJacobianOperator{VJP} , B:: StatefulJacobianOperator{JVP} )
250
+ C. vjp_operator = A
251
+ C. jvp_operator = B
252
+ return C
253
+ end
254
+
255
+ function Base.:* (JᵀJ:: StatefulJacobianNormalFormOperator , x:: AbstractArray )
256
+ return JᵀJ. vjp_operator * (JᵀJ. jvp_operator * x)
257
+ end
258
+
259
+ function LinearAlgebra. mul! (
260
+ JᵀJx:: AbstractArray , JᵀJ:: StatefulJacobianNormalFormOperator , x:: AbstractArray )
261
+ mul! (JᵀJ. cache, JᵀJ. jvp_operator, x)
262
+ mul! (JᵀJx, JᵀJ. vjp_operator, JᵀJ. cache)
263
+ return JᵀJx
264
+ end
265
+
266
+ # Helper Functions
123
267
prepare_vjp (:: Val{true} , args... ; kwargs... ) = nothing
124
268
125
269
function prepare_vjp (:: Val{false} , prob:: AbstractNonlinearProblem ,
@@ -163,8 +307,8 @@ function prepare_vjp(::Val{false}, prob::AbstractNonlinearProblem,
163
307
DI. pullback! (fₚ, fu_cache, reshape (vJ, size (u)), autodiff, u, v, di_extras)
164
308
end
165
309
else
166
- di_extras = DI. prepare_pullback (f , autodiff, u, fu)
167
- return @closure (v, u, p) -> DI. pullback (f , autodiff, u, v, di_extras)
310
+ di_extras = DI. prepare_pullback (fₚ , autodiff, u, fu)
311
+ return @closure (v, u, p) -> DI. pullback (fₚ , autodiff, u, v, di_extras)
168
312
end
169
313
end
170
314
@@ -206,8 +350,8 @@ function prepare_jvp(::Val{false}, prob::AbstractNonlinearProblem,
206
350
fu_cache = copy (fu)
207
351
di_extras = DI. prepare_pushforward (fₚ, fu_cache, autodiff, u, u)
208
352
return @closure (Jv, v, u, p) -> begin
209
- DI. pushforward! (fₚ, fu_cache, reshape (Jv, size (fu_cache)), autodiff, u, v,
210
- di_extras)
353
+ DI. pushforward! (
354
+ fₚ, fu_cache, reshape (Jv, size (fu_cache)), autodiff, u, v, di_extras)
211
355
return
212
356
end
213
357
else
@@ -231,5 +375,7 @@ function prepare_scalar_op(::Val{false}, prob::AbstractNonlinearProblem,
231
375
end
232
376
233
377
export JacobianOperator, VecJacOperator, JacVecOperator
378
+ export StatefulJacobianOperator
379
+ export StatefulJacobianNormalFormOperator
234
380
235
381
end
0 commit comments