1
1
module SciMLJacobianOperators
2
2
3
- using ADTypes: ADTypes, AutoSparse
3
+ using ADTypes: ADTypes, AutoSparse, AutoEnzyme
4
4
using ConcreteStructs: @concrete
5
5
using ConstructionBase: ConstructionBase
6
6
using DifferentiationInterface: DifferentiationInterface
7
+ using EnzymeCore: EnzymeCore
7
8
using FastClosures: @closure
8
9
using LinearAlgebra: LinearAlgebra
9
10
using SciMLBase: SciMLBase, AbstractNonlinearProblem, AbstractNonlinearFunction
@@ -115,10 +116,10 @@ function JacobianOperator(prob::AbstractNonlinearProblem, fu, u; jvp_autodiff =
115
116
iip = SciMLBase. isinplace (prob)
116
117
T = promote_type (eltype (u), eltype (fu))
117
118
118
- vjp_autodiff = get_dense_ad (vjp_autodiff)
119
+ vjp_autodiff = set_function_as_const ( get_dense_ad (vjp_autodiff) )
119
120
vjp_op = prepare_vjp (skip_vjp, prob, f, u, fu; autodiff = vjp_autodiff)
120
121
121
- jvp_autodiff = get_dense_ad (jvp_autodiff)
122
+ jvp_autodiff = set_function_as_const ( get_dense_ad (jvp_autodiff) )
122
123
jvp_op = prepare_jvp (skip_jvp, prob, f, u, fu; autodiff = jvp_autodiff)
123
124
124
125
output_cache = fu isa Number ? T (fu) : similar (fu, T)
259
260
function Base.:* (JᵀJ:: StatefulJacobianNormalFormOperator , x:: AbstractArray )
260
261
return JᵀJ. vjp_operator * (JᵀJ. jvp_operator * x)
261
262
end
263
+ function Base.:* (JᵀJ:: StatefulJacobianNormalFormOperator , x:: Number )
264
+ return JᵀJ. vjp_operator * (JᵀJ. jvp_operator * x)
265
+ end
262
266
263
267
function LinearAlgebra. mul! (
264
268
JᵀJx:: AbstractArray , JᵀJ:: StatefulJacobianNormalFormOperator , x:: AbstractArray )
@@ -284,7 +288,7 @@ function prepare_vjp(::Val{false}, prob::AbstractNonlinearProblem,
284
288
jac_cache = similar (u, eltype (fu), length (fu), length (u))
285
289
return @closure (vJ, v, u, p) -> begin
286
290
f. jac (jac_cache, u, p)
287
- mul! (vec (vJ), jac_cache' , vec (v))
291
+ LinearAlgebra . mul! (vec (vJ), jac_cache' , vec (v))
288
292
return
289
293
end
290
294
return vjp_op, vjp_extras
@@ -298,6 +302,8 @@ function prepare_vjp(::Val{false}, prob::AbstractNonlinearProblem,
298
302
# TODO : Once DI supports const params we can use `p`
299
303
fₚ = SciMLBase. JacobianWrapper {SciMLBase.isinplace(f)} (f, prob. p)
300
304
if SciMLBase. isinplace (f)
305
+ @assert DI. check_twoarg (autodiff) " Backend: $(autodiff) doesn't support in-place \
306
+ problems."
301
307
fu_cache = copy (fu)
302
308
v_fake = copy (fu)
303
309
di_extras = DI. prepare_pullback (fₚ, fu_cache, autodiff, u, v_fake)
@@ -326,11 +332,11 @@ function prepare_jvp(::Val{false}, prob::AbstractNonlinearProblem,
326
332
jac_cache = similar (u, eltype (fu), length (fu), length (u))
327
333
return @closure (Jv, v, u, p) -> begin
328
334
f. jac (jac_cache, u, p)
329
- mul! (vec (Jv), jac_cache, vec (v))
335
+ LinearAlgebra . mul! (vec (Jv), jac_cache, vec (v))
330
336
return
331
337
end
332
338
else
333
- return @closure (v, u, p, _ ) -> reshape (f. jac (u, p) * vec (v), size (u))
339
+ return @closure (v, u, p) -> reshape (f. jac (u, p) * vec (v), size (u))
334
340
end
335
341
end
336
342
@@ -339,6 +345,8 @@ function prepare_jvp(::Val{false}, prob::AbstractNonlinearProblem,
339
345
# TODO : Once DI supports const params we can use `p`
340
346
fₚ = SciMLBase. JacobianWrapper {SciMLBase.isinplace(f)} (f, prob. p)
341
347
if SciMLBase. isinplace (f)
348
+ @assert DI. check_twoarg (autodiff) " Backend: $(autodiff) doesn't support in-place \
349
+ problems."
342
350
fu_cache = copy (fu)
343
351
di_extras = DI. prepare_pushforward (fₚ, fu_cache, autodiff, u, u)
344
352
return @closure (Jv, v, u, p) -> begin
@@ -375,6 +383,12 @@ function get_dense_ad(ad::AutoSparse)
375
383
return dense_ad
376
384
end
377
385
386
+ # In our case we know that it is safe to mark the function as const
387
+ set_function_as_const (ad) = ad
388
+ function set_function_as_const (ad:: AutoEnzyme{M, Nothing} ) where {M}
389
+ return AutoEnzyme (; ad. mode, function_annotation = EnzymeCore. Const)
390
+ end
391
+
378
392
export JacobianOperator, VecJacOperator, JacVecOperator
379
393
export StatefulJacobianOperator
380
394
export StatefulJacobianNormalFormOperator
0 commit comments