@@ -17,8 +17,8 @@ using Core: Vararg
1717 end
1818end
1919
20- function inner_grad (θ, bθ, f, p)
21- Enzyme. autodiff_deferred (Enzyme . Reverse ,
20+ function inner_grad (mode :: Mode , θ, bθ, f, p) where {Mode}
21+ Enzyme. autodiff (mode ,
2222 Const (firstapply),
2323 Active,
2424 Const (f),
@@ -28,19 +28,9 @@ function inner_grad(θ, bθ, f, p)
2828 return nothing
2929end
3030
31- function inner_grad_primal (θ, bθ, f, p)
32- Enzyme. autodiff_deferred (Enzyme. ReverseWithPrimal,
33- Const (firstapply),
34- Active,
35- Const (f),
36- Enzyme. Duplicated (θ, bθ),
37- Const (p)
38- )[2 ]
39- end
40-
41- function hv_f2_alloc (x, f, p)
31+ function hv_f2_alloc (mode:: Mode , x, f, p) where {Mode}
4232 dx = Enzyme. make_zero (x)
43- Enzyme. autodiff_deferred (Enzyme . Reverse ,
33+ Enzyme. autodiff (mode ,
4434 Const (firstapply),
4535 Active,
4636 Const (f),
@@ -57,9 +47,9 @@ function inner_cons(x, fcons::Function, p::Union{SciMLBase.NullParameters, Nothi
5747 return res[i]
5848end
5949
60- function cons_f2 (x, dx, fcons, p, num_cons, i)
50+ function cons_f2 (mode, x, dx, fcons, p, num_cons, i)
6151 Enzyme. autodiff_deferred (
62- Enzyme . Reverse , Const (inner_cons), Active, Enzyme. Duplicated (x, dx),
52+ mode , Const (inner_cons), Active, Enzyme. Duplicated (x, dx),
6353 Const (fcons), Const (p), Const (num_cons), Const (i))
6454 return nothing
6555end
@@ -70,9 +60,9 @@ function inner_cons_oop(
7060 return fcons (x, p)[i]
7161end
7262
73- function cons_f2_oop (x, dx, fcons, p, i)
63+ function cons_f2_oop (mode, x, dx, fcons, p, i)
7464 Enzyme. autodiff_deferred (
75- Enzyme . Reverse , Const (inner_cons_oop), Active, Enzyme. Duplicated (x, dx),
65+ mode , Const (inner_cons_oop), Active, Enzyme. Duplicated (x, dx),
7666 Const (fcons), Const (p), Const (i))
7767 return nothing
7868end
@@ -83,22 +73,38 @@ function lagrangian(x, _f::Function, cons::Function, p, λ, σ = one(eltype(x)))
8373 return σ * _f (x, p) + dot (λ, res)
8474end
8575
86- function lag_grad (x, dx, lagrangian:: Function , _f:: Function , cons:: Function , p, σ, λ)
76+ function lag_grad (mode, x, dx, lagrangian:: Function , _f:: Function , cons:: Function , p, σ, λ)
8777 Enzyme. autodiff_deferred (
88- Enzyme . Reverse , Const (lagrangian), Active, Enzyme. Duplicated (x, dx),
78+ mode , Const (lagrangian), Active, Enzyme. Duplicated (x, dx),
8979 Const (_f), Const (cons), Const (p), Const (λ), Const (σ))
9080 return nothing
9181end
9282
83+ function set_runtime_activity2 (
84+ a:: Mode1 , :: Enzyme.Mode{ABI, Err, RTA} ) where {Mode1, ABI, Err, RTA}
85+ Enzyme. set_runtime_activity (a, RTA)
86+ end
9387function OptimizationBase. instantiate_function (f:: OptimizationFunction{true} , x,
9488 adtype:: AutoEnzyme , p, num_cons = 0 ;
9589 g = false , h = false , hv = false , fg = false , fgh = false ,
9690 cons_j = false , cons_vjp = false , cons_jvp = false , cons_h = false ,
9791 lag_h = false )
92+ rmode = if adtype. mode isa Nothing
93+ Enzyme. Reverse
94+ else
95+ set_runtime_activity2 (Enzyme. Reverse, adtype. mode)
96+ end
97+
98+ fmode = if adtype. mode isa Nothing
99+ Enzyme. Forward
100+ else
101+ set_runtime_activity2 (Enzyme. Forward, adtype. mode)
102+ end
103+
98104 if g == true && f. grad === nothing
99105 function grad (res, θ, p = p)
100106 Enzyme. make_zero! (res)
101- Enzyme. autodiff (Enzyme . Reverse ,
107+ Enzyme. autodiff (rmode ,
102108 Const (firstapply),
103109 Active,
104110 Const (f. f),
@@ -115,7 +121,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
115121 if fg == true && f. fg === nothing
116122 function fg! (res, θ, p = p)
117123 Enzyme. make_zero! (res)
118- y = Enzyme. autodiff (Enzyme . ReverseWithPrimal ,
124+ y = Enzyme. autodiff (WithPrimal (rmode) ,
119125 Const (firstapply),
120126 Active,
121127 Const (f. f),
@@ -145,8 +151,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
145151 Enzyme. make_zero! (bθ)
146152 Enzyme. make_zero! .(vdbθ)
147153
148- Enzyme. autodiff (Enzyme . Forward ,
154+ Enzyme. autodiff (fmode ,
149155 inner_grad,
156+ Const (rmode),
150157 Enzyme. BatchDuplicated (θ, vdθ),
151158 Enzyme. BatchDuplicatedNoNeed (bθ, vdbθ),
152159 Const (f. f),
@@ -168,8 +175,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
168175 vdθ = Tuple ((Array (r) for r in eachrow (I (length (θ)) * one (eltype (θ)))))
169176 vdbθ = Tuple (zeros (eltype (θ), length (θ)) for i in eachindex (θ))
170177
171- Enzyme. autodiff (Enzyme . Forward ,
178+ Enzyme. autodiff (fmode ,
172179 inner_grad,
180+ Const (rmode),
173181 Enzyme. BatchDuplicated (θ, vdθ),
174182 Enzyme. BatchDuplicatedNoNeed (G, vdbθ),
175183 Const (f. f),
@@ -189,7 +197,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
189197 if hv == true && f. hv === nothing
190198 function hv! (H, θ, v, p = p)
191199 H .= Enzyme. autodiff (
192- Enzyme . Forward , hv_f2_alloc, Duplicated (θ, v),
200+ fmode , hv_f2_alloc, Const (rmode) , Duplicated (θ, v),
193201 Const (f. f), Const (p)
194202 )[1 ]
195203 end
@@ -221,7 +229,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
221229 Enzyme. make_zero! (Jaccache[i])
222230 end
223231 Enzyme. make_zero! (y)
224- Enzyme. autodiff (Enzyme . Forward , f. cons, BatchDuplicated (y, Jaccache),
232+ Enzyme. autodiff (fmode , f. cons, BatchDuplicated (y, Jaccache),
225233 BatchDuplicated (θ, seeds), Const (p))
226234 for i in eachindex (θ)
227235 if J isa Vector
@@ -254,7 +262,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
254262 Enzyme. make_zero! (res)
255263 Enzyme. make_zero! (cons_res)
256264
257- Enzyme. autodiff (Enzyme . Reverse ,
265+ Enzyme. autodiff (rmode ,
258266 f. cons,
259267 Const,
260268 Duplicated (cons_res, v),
@@ -275,7 +283,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
275283 Enzyme. make_zero! (res)
276284 Enzyme. make_zero! (cons_res)
277285
278- Enzyme. autodiff (Enzyme . Forward ,
286+ Enzyme. autodiff (fmode ,
279287 f. cons,
280288 Duplicated (cons_res, res),
281289 Duplicated (θ, v),
@@ -297,8 +305,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
297305 for i in 1 : num_cons
298306 Enzyme. make_zero! (cons_bθ)
299307 Enzyme. make_zero! .(cons_vdbθ)
300- Enzyme. autodiff (Enzyme . Forward ,
308+ Enzyme. autodiff (fmode ,
301309 cons_f2,
310+ Const (rmode),
302311 Enzyme. BatchDuplicated (θ, cons_vdθ),
303312 Enzyme. BatchDuplicated (cons_bθ, cons_vdbθ),
304313 Const (f. cons),
@@ -332,8 +341,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
332341 Enzyme. make_zero! (lag_bθ)
333342 Enzyme. make_zero! .(lag_vdbθ)
334343
335- Enzyme. autodiff (Enzyme . Forward ,
344+ Enzyme. autodiff (fmode ,
336345 lag_grad,
346+ Const (rmode),
337347 Enzyme. BatchDuplicated (θ, lag_vdθ),
338348 Enzyme. BatchDuplicatedNoNeed (lag_bθ, lag_vdbθ),
339349 Const (lagrangian),
@@ -357,8 +367,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
357367 Enzyme. make_zero! (lag_bθ)
358368 Enzyme. make_zero! .(lag_vdbθ)
359369
360- Enzyme. autodiff (Enzyme . Forward ,
370+ Enzyme. autodiff (fmode ,
361371 lag_grad,
372+ Const (rmode),
362373 Enzyme. BatchDuplicated (θ, lag_vdθ),
363374 Enzyme. BatchDuplicatedNoNeed (lag_bθ, lag_vdbθ),
364375 Const (lagrangian),
@@ -410,11 +421,23 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
410421 g = false , h = false , hv = false , fg = false , fgh = false ,
411422 cons_j = false , cons_vjp = false , cons_jvp = false , cons_h = false ,
412423 lag_h = false )
424+ rmode = if adtype. mode isa Nothing
425+ Enzyme. Reverse
426+ else
427+ set_runtime_activity2 (Enzyme. Reverse, adtype. mode)
428+ end
429+
430+ fmode = if adtype. mode isa Nothing
431+ Enzyme. Forward
432+ else
433+ set_runtime_activity2 (Enzyme. Forward, adtype. mode)
434+ end
435+
413436 if g == true && f. grad === nothing
414437 res = zeros (eltype (x), size (x))
415438 function grad (θ, p = p)
416439 Enzyme. make_zero! (res)
417- Enzyme. autodiff (Enzyme . Reverse ,
440+ Enzyme. autodiff (rmode ,
418441 Const (firstapply),
419442 Active,
420443 Const (f. f),
@@ -433,7 +456,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
433456 res_fg = zeros (eltype (x), size (x))
434457 function fg! (θ, p = p)
435458 Enzyme. make_zero! (res_fg)
436- y = Enzyme. autodiff (Enzyme . ReverseWithPrimal ,
459+ y = Enzyme. autodiff (WithPrimal (rmode) ,
437460 Const (firstapply),
438461 Active,
439462 Const (f. f),
@@ -457,8 +480,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
457480 Enzyme. make_zero! (bθ)
458481 Enzyme. make_zero! .(vdbθ)
459482
460- Enzyme. autodiff (Enzyme . Forward ,
483+ Enzyme. autodiff (fmode ,
461484 inner_grad,
485+ Const (rmode),
462486 Enzyme. BatchDuplicated (θ, vdθ),
463487 Enzyme. BatchDuplicated (bθ, vdbθ),
464488 Const (f. f),
@@ -485,8 +509,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
485509 Enzyme. make_zero! (H_fgh)
486510 Enzyme. make_zero! .(vdbθ_fgh)
487511
488- Enzyme. autodiff (Enzyme . Forward ,
512+ Enzyme. autodiff (fmode ,
489513 inner_grad,
514+ Const (rmode),
490515 Enzyme. BatchDuplicated (θ, vdθ_fgh),
491516 Enzyme. BatchDuplicatedNoNeed (G_fgh, vdbθ_fgh),
492517 Const (f. f),
@@ -507,7 +532,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
507532 if hv == true && f. hv === nothing
508533 function hv! (θ, v, p = p)
509534 return Enzyme. autodiff (
510- Enzyme . Forward , hv_f2_alloc, DuplicatedNoNeed, Duplicated (θ, v),
535+ fmode , hv_f2_alloc, DuplicatedNoNeed, Const (rmode) , Duplicated (θ, v),
511536 Const (_f), Const (f. f), Const (p)
512537 )[1 ]
513538 end
@@ -533,7 +558,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
533558 for i in eachindex (Jaccache)
534559 Enzyme. make_zero! (Jaccache[i])
535560 end
536- Jaccache, y = Enzyme. autodiff (Enzyme . ForwardWithPrimal , f. cons, Duplicated,
561+ Jaccache, y = Enzyme. autodiff (WithPrimal (fmode) , f. cons, Duplicated,
537562 BatchDuplicated (θ, seeds), Const (p))
538563 if size (y, 1 ) == 1
539564 return reduce (vcat, Jaccache)
@@ -555,7 +580,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
555580 Enzyme. make_zero! (res_vjp)
556581 Enzyme. make_zero! (cons_vjp_res)
557582
558- Enzyme. autodiff (Enzyme . Reverse ,
583+ Enzyme. autodiff (WithPrimal (rmode) ,
559584 f. cons,
560585 Const,
561586 Duplicated (cons_vjp_res, v),
@@ -578,7 +603,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
578603 Enzyme. make_zero! (res_jvp)
579604 Enzyme. make_zero! (cons_jvp_res)
580605
581- Enzyme. autodiff (Enzyme . Forward ,
606+ Enzyme. autodiff (fmode ,
582607 f. cons,
583608 Duplicated (cons_jvp_res, res_jvp),
584609 Duplicated (θ, v),
@@ -601,8 +626,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
601626 return map (1 : num_cons) do i
602627 Enzyme. make_zero! (cons_bθ)
603628 Enzyme. make_zero! .(cons_vdbθ)
604- Enzyme. autodiff (Enzyme . Forward ,
629+ Enzyme. autodiff (fmode ,
605630 cons_f2_oop,
631+ Const (rmode),
606632 Enzyme. BatchDuplicated (θ, cons_vdθ),
607633 Enzyme. BatchDuplicated (cons_bθ, cons_vdbθ),
608634 Const (f. cons),
@@ -631,8 +657,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
631657 Enzyme. make_zero! (lag_bθ)
632658 Enzyme. make_zero! .(lag_vdbθ)
633659
634- Enzyme. autodiff (Enzyme . Forward ,
660+ Enzyme. autodiff (fmode ,
635661 lag_grad,
662+ Const (rmode),
636663 Enzyme. BatchDuplicated (θ, lag_vdθ),
637664 Enzyme. BatchDuplicatedNoNeed (lag_bθ, lag_vdbθ),
638665 Const (lagrangian),
0 commit comments