@@ -13,6 +13,8 @@ const DI = DifferentiationInterface
13
13
const True = Val (true )
14
14
const False = Val (false )
15
15
16
+ abstract type AbstractJacobianOperator{T} <: AbstractSciMLOperator{T} end
17
+
16
18
abstract type AbstractMode end
17
19
18
20
struct VJP <: AbstractMode end
@@ -21,17 +23,20 @@ struct JVP <: AbstractMode end
21
23
flip_mode (:: VJP ) = JVP ()
22
24
flip_mode (:: JVP ) = VJP ()
23
25
24
- @concrete struct JacobianOperator{iip, T <: Real } <: AbstractSciMLOperator {T}
26
+ @concrete struct JacobianOperator{iip, T <: Real } <: AbstractJacobianOperator {T}
25
27
mode <: AbstractMode
26
28
27
29
jvp_op
28
30
vjp_op
29
31
30
32
size
31
- jvp_extras
32
- vjp_extras
33
+
34
+ output_cache
35
+ input_cache
33
36
end
34
37
38
+ SciMLBase. isinplace (:: JacobianOperator{iip} ) where {iip} = iip
39
+
35
40
function ConstructionBase. constructorof (:: Type{<:JacobianOperator{iip, T}} ) where {iip, T}
36
41
return JacobianOperator{iip, T}
37
42
end
@@ -42,6 +47,9 @@ Base.size(J::JacobianOperator, d::Integer) = J.size[d]
42
47
for op in (:adjoint , :transpose )
43
48
@eval function Base. $ (op)(operator:: JacobianOperator )
44
49
@set! operator. mode = flip_mode (operator. mode)
50
+ (; output_cache, input_cache) = operator
51
+ @set! operator. output_cache = input_cache
52
+ @set! operator. input_cache = output_cache
45
53
return operator
46
54
end
47
55
end
@@ -53,16 +61,66 @@ function JacobianOperator(prob::AbstractNonlinearProblem, fu, u; jvp_autodiff =
53
61
f = prob. f
54
62
iip = SciMLBase. isinplace (prob)
55
63
T = promote_type (eltype (u), eltype (fu))
56
- fₚ = SciMLBase. JacobianWrapper {iip} (f, prob. p)
57
64
58
- vjp_op, vjp_extras = prepare_vjp (skip_vjp, prob, f, u, fu; autodiff = vjp_autodiff)
59
- jvp_op, jvp_extras = prepare_jvp (skip_jvp, prob, f, u, fu; autodiff = jvp_autodiff)
65
+ vjp_op = prepare_vjp (skip_vjp, prob, f, u, fu; autodiff = vjp_autodiff)
66
+ jvp_op = prepare_jvp (skip_jvp, prob, f, u, fu; autodiff = jvp_autodiff)
67
+
68
+ output_cache = iip ? similar (fu, T) : nothing
69
+ input_cache = iip ? similar (u, T) : nothing
60
70
61
71
return JacobianOperator {iip, T} (
62
- JVP (), jvp_op, vjp_op, (length (fu), length (u)), jvp_extras, vjp_extras )
72
+ JVP (), jvp_op, vjp_op, (length (fu), length (u)), output_cache, input_cache )
63
73
end
64
74
65
- prepare_vjp (:: Val{true} , args... ; kwargs... ) = nothing , nothing
75
+ function (op:: JacobianOperator )(v, u, p)
76
+ if op. mode isa VJP
77
+ if SciMLBase. isinplace (op)
78
+ res = zero (op. output_cache)
79
+ op. vjp_op (res, v, u, p)
80
+ return res
81
+ end
82
+ return op. vjp_op (v, u, p)
83
+ else
84
+ if SciMLBase. isinplace (op)
85
+ res = zero (op. output_cache)
86
+ op. jvp_op (res, v, u, p)
87
+ return res
88
+ end
89
+ return op. jvp_op (v, u, p)
90
+ end
91
+ end
92
+
93
+ function (op:: JacobianOperator )(:: Number , :: Number , _, __)
94
+ error (" Inplace Jacobian Operator not possible for scalars." )
95
+ end
96
+
97
+ function (op:: JacobianOperator )(Jv, v, u, p)
98
+ if op. mode isa VJP
99
+ if SciMLBase. isinplace (op)
100
+ op. vjp_op (Jv, v, u, p)
101
+ return
102
+ end
103
+ copyto! (Jv, op. vjp_op (v, u, p))
104
+ return
105
+ else
106
+ if SciMLBase. isinplace (op)
107
+ op. jvp_op (Jv, v, u, p)
108
+ return
109
+ end
110
+ copyto! (Jv, op. jvp_op (v, u, p))
111
+ return
112
+ end
113
+ end
114
+
115
+ function VecJacOperator (args... ; autodiff = nothing , kwargs... )
116
+ return JacobianOperator (args... ; kwargs... , skip_jvp = True, vjp_autodiff = autodiff)'
117
+ end
118
+
119
+ function JacVecOperator (args... ; autodiff = nothing , kwargs... )
120
+ return JacobianOperator (args... ; kwargs... , skip_vjp = True, jvp_autodiff = autodiff)
121
+ end
122
+
123
+ prepare_vjp (:: Val{true} , args... ; kwargs... ) = nothing
66
124
67
125
function prepare_vjp (:: Val{false} , prob:: AbstractNonlinearProblem ,
68
126
f:: AbstractNonlinearFunction , u:: Number , fu:: Number ; autodiff = nothing )
71
129
72
130
function prepare_vjp (:: Val{false} , prob:: AbstractNonlinearProblem ,
73
131
f:: AbstractNonlinearFunction , u, fu; autodiff = nothing )
74
- SciMLBase. has_vjp (f) && return f. vjp, nothing
132
+ SciMLBase. has_vjp (f) && return f. vjp
75
133
76
134
if autodiff === nothing && SciMLBase. has_jac (f)
77
135
if SciMLBase. isinplace (f)
78
- vjp_extras = (; jac_cache = similar (u, eltype (fu), length (fu), length (u) ))
79
- vjp_op = @closure (vJ, v, u, p, extras ) -> begin
80
- f. jac (extras . jac_cache, u, p)
81
- mul! (vec (vJ), extras . jac_cache' , vec (v))
136
+ jac_cache = similar (u, eltype (fu), length (fu), length (u))
137
+ return @closure (vJ, v, u, p) -> begin
138
+ f. jac (jac_cache, u, p)
139
+ mul! (vec (vJ), jac_cache' , vec (v))
82
140
return
83
141
end
84
142
return vjp_op, vjp_extras
85
143
else
86
- vjp_op = @closure (v, u, p, _) -> reshape (f. jac (u, p)' * vec (v), size (u))
87
- return vjp_op, nothing
144
+ return @closure (v, u, p) -> reshape (f. jac (u, p)' * vec (v), size (u))
88
145
end
89
146
end
90
147
@@ -102,21 +159,16 @@ function prepare_vjp(::Val{false}, prob::AbstractNonlinearProblem,
102
159
fu_cache = copy (fu)
103
160
v_fake = copy (fu)
104
161
di_extras = DI. prepare_pullback (fₚ, fu_cache, autodiff, u, v_fake)
105
- vjp_op = @closure (vJ, v, u, p, extras) -> begin
106
- DI. pullback! (
107
- fₚ, extras. fu_cache, reshape (vJ, size (u)), autodiff, u, v, extras. di_extras)
162
+ return @closure (vJ, v, u, p) -> begin
163
+ DI. pullback! (fₚ, fu_cache, reshape (vJ, size (u)), autodiff, u, v, di_extras)
108
164
end
109
- return vjp_op, (; di_extras, fu_cache)
110
165
else
111
166
di_extras = DI. prepare_pullback (f, autodiff, u, fu)
112
- vjp_op = @closure (v, u, p, extras) -> begin
113
- return DI. pullback (f, autodiff, u, v, extras. di_extras)
114
- end
115
- return vjp_op, (; di_extras)
167
+ return @closure (v, u, p) -> DI. pullback (f, autodiff, u, v, di_extras)
116
168
end
117
169
end
118
170
119
- prepare_jvp (skip:: Val{true} , args... ; kwargs... ) = nothing , nothing
171
+ prepare_jvp (skip:: Val{true} , args... ; kwargs... ) = nothing
120
172
121
173
function prepare_jvp (:: Val{false} , prob:: AbstractNonlinearProblem ,
122
174
f:: AbstractNonlinearFunction , u:: Number , fu:: Number ; autodiff = nothing )
@@ -125,20 +177,18 @@ end
125
177
126
178
function prepare_jvp (:: Val{false} , prob:: AbstractNonlinearProblem ,
127
179
f:: AbstractNonlinearFunction , u, fu; autodiff = nothing )
128
- SciMLBase. has_vjp (f) && return f. vjp, nothing
180
+ SciMLBase. has_vjp (f) && return f. vjp
129
181
130
182
if autodiff === nothing && SciMLBase. has_jac (f)
131
183
if SciMLBase. isinplace (f)
132
- jvp_extras = (; jac_cache = similar (u, eltype (fu), length (fu), length (u) ))
133
- jvp_op = @closure (Jv, v, u, p, extras ) -> begin
134
- f. jac (extras . jac_cache, u, p)
135
- mul! (vec (Jv), extras . jac_cache, vec (v))
184
+ jac_cache = similar (u, eltype (fu), length (fu), length (u))
185
+ return @closure (Jv, v, u, p) -> begin
186
+ f. jac (jac_cache, u, p)
187
+ mul! (vec (Jv), jac_cache, vec (v))
136
188
return
137
189
end
138
- return jvp_op, jvp_extras
139
190
else
140
- jvp_op = @closure (v, u, p, _) -> reshape (f. jac (u, p) * vec (v), size (u))
141
- return jvp_op, nothing
191
+ return @closure (v, u, p, _) -> reshape (f. jac (u, p) * vec (v), size (u))
142
192
end
143
193
end
144
194
@@ -155,43 +205,29 @@ function prepare_jvp(::Val{false}, prob::AbstractNonlinearProblem,
155
205
if SciMLBase. isinplace (f)
156
206
fu_cache = copy (fu)
157
207
di_extras = DI. prepare_pushforward (fₚ, fu_cache, autodiff, u, u)
158
- jvp_op = @closure (Jv, v, u, p, extras) -> begin
159
- DI. pushforward! (fₚ, extras. fu_cache, reshape (Jv, size (extras. fu_cache)),
160
- autodiff, u, v, extras. di_extras)
208
+ return @closure (Jv, v, u, p) -> begin
209
+ DI. pushforward! (fₚ, fu_cache, reshape (Jv, size (fu_cache)), autodiff, u, v,
210
+ di_extras)
211
+ return
161
212
end
162
- return jvp_op, (; di_extras, fu_cache)
163
213
else
164
- di_extras = DI. prepare_pushforward (f, autodiff, u, u)
165
- jvp_op = @closure (v, u, p, extras) -> begin
166
- return DI. pushforward (f, autodiff, u, v, extras. di_extras)
167
- end
168
- return jvp_op, (; di_extras)
214
+ di_extras = DI. prepare_pushforward (fₚ, autodiff, u, u)
215
+ return @closure (v, u, p) -> DI. pushforward (fₚ, autodiff, u, v, di_extras)
169
216
end
170
217
end
171
218
172
219
function prepare_scalar_op (:: Val{false} , prob:: AbstractNonlinearProblem ,
173
220
f:: AbstractNonlinearFunction , u:: Number , fu:: Number ; autodiff = nothing )
174
- SciMLBase. has_vjp (f) && return f. vjp, nothing
175
- SciMLBase. has_jvp (f) && return f. jvp, nothing
176
- SciMLBase. has_jac (f) && return @closure ((v, u, p, _ )-> f. jac (u, p) * v), nothing
221
+ SciMLBase. has_vjp (f) && return f. vjp
222
+ SciMLBase. has_jvp (f) && return f. jvp
223
+ SciMLBase. has_jac (f) && return @closure ((v, u, p)-> f. jac (u, p) * v)
177
224
178
225
@assert autodiff!= = nothing " `autodiff` must be provided if `f` doesn't have \
179
226
analytic `vjp` or `jvp` or `jac`."
180
227
# TODO : Once DI supports const params we can use `p`
181
228
fₚ = Base. Fix2 (f, prob. p)
182
229
di_extras = DI. prepare_derivative (fₚ, autodiff, u)
183
- op = @closure (v, u, p, extras) -> begin
184
- return DI. derivative (fₚ, autodiff, u, extras. di_extras) * v
185
- end
186
- return op, (; di_extras)
187
- end
188
-
189
- function VecJacOperator (args... ; autodiff = nothing , kwargs... )
190
- return JacobianOperator (args... ; kwargs... , skip_jvp = True, vjp_autodiff = autodiff)'
191
- end
192
-
193
- function JacVecOperator (args... ; autodiff = nothing , kwargs... )
194
- return JacobianOperator (args... ; kwargs... , skip_vjp = True, jvp_autodiff = autodiff)
230
+ return @closure (v, u, p) -> DI. derivative (fₚ, autodiff, u, di_extras) * v
195
231
end
196
232
197
233
export JacobianOperator, VecJacOperator, JacVecOperator
0 commit comments