@@ -17,8 +17,8 @@ using Core: Vararg
1717 end
1818end
1919
20- function inner_grad (θ, bθ, f, p)
21- Enzyme. autodiff_deferred (Enzyme . Reverse ,
20+ function inner_grad (mode :: Mode , θ, bθ, f, p) where Mode
21+ Enzyme. autodiff (Mode ,
2222 Const (firstapply),
2323 Active,
2424 Const (f),
@@ -28,19 +28,9 @@ function inner_grad(θ, bθ, f, p)
2828 return nothing
2929end
3030
31- function inner_grad_primal (θ, bθ, f, p)
32- Enzyme. autodiff_deferred (Enzyme. ReverseWithPrimal,
33- Const (firstapply),
34- Active,
35- Const (f),
36- Enzyme. Duplicated (θ, bθ),
37- Const (p)
38- )[2 ]
39- end
40-
41- function hv_f2_alloc (x, f, p)
31+ function hv_f2_alloc (mode:: Mode , x, f, p) where Mode
4232 dx = Enzyme. make_zero (x)
43- Enzyme. autodiff_deferred (Enzyme . Reverse ,
33+ Enzyme. autodiff (mode ,
4434 Const (firstapply),
4535 Active,
4636 Const (f),
@@ -57,9 +47,9 @@ function inner_cons(x, fcons::Function, p::Union{SciMLBase.NullParameters, Nothi
5747 return res[i]
5848end
5949
60- function cons_f2 (x, dx, fcons, p, num_cons, i)
50+ function cons_f2 (mode, x, dx, fcons, p, num_cons, i)
6151 Enzyme. autodiff_deferred (
62- Enzyme . Reverse , Const (inner_cons), Active, Enzyme. Duplicated (x, dx),
52+ mode , Const (inner_cons), Active, Enzyme. Duplicated (x, dx),
6353 Const (fcons), Const (p), Const (num_cons), Const (i))
6454 return nothing
6555end
@@ -70,9 +60,9 @@ function inner_cons_oop(
7060 return fcons (x, p)[i]
7161end
7262
73- function cons_f2_oop (x, dx, fcons, p, i)
63+ function cons_f2_oop (mode, x, dx, fcons, p, i)
7464 Enzyme. autodiff_deferred (
75- Enzyme . Reverse , Const (inner_cons_oop), Active, Enzyme. Duplicated (x, dx),
65+ mode , Const (inner_cons_oop), Active, Enzyme. Duplicated (x, dx),
7666 Const (fcons), Const (p), Const (i))
7767 return nothing
7868end
@@ -83,22 +73,37 @@ function lagrangian(x, _f::Function, cons::Function, p, λ, σ = one(eltype(x)))
8373 return σ * _f (x, p) + dot (λ, res)
8474end
8575
86- function lag_grad (x, dx, lagrangian:: Function , _f:: Function , cons:: Function , p, σ, λ)
76+ function lag_grad (mode, x, dx, lagrangian:: Function , _f:: Function , cons:: Function , p, σ, λ)
8777 Enzyme. autodiff_deferred (
88- Enzyme . Reverse , Const (lagrangian), Active, Enzyme. Duplicated (x, dx),
78+ mode , Const (lagrangian), Active, Enzyme. Duplicated (x, dx),
8979 Const (_f), Const (cons), Const (p), Const (λ), Const (σ))
9080 return nothing
9181end
9282
83+
84+ set_runtime_activity2 (a:: Mode1 , :: Enzyme.Mode{ABI, Err, RTA} ) where {Mode1, ABI, Err, RTA} = Enzyme. set_runtime_activity (a, RTA)
9385function OptimizationBase. instantiate_function (f:: OptimizationFunction{true} , x,
9486 adtype:: AutoEnzyme , p, num_cons = 0 ;
9587 g = false , h = false , hv = false , fg = false , fgh = false ,
9688 cons_j = false , cons_vjp = false , cons_jvp = false , cons_h = false ,
9789 lag_h = false )
90+
91+ rmode = if adtype. mode isa Nothing
92+ Enzyme. Reverse
93+ else
94+ set_runtime_activity2 (Enzyme. Reverse)
95+ end
96+
97+ fmode = if adtype. mode isa Nothing
98+ Enzyme. Forward
99+ else
100+ set_runtime_activity2 (Enzyme. Forward)
101+ end
102+
98103 if g == true && f. grad === nothing
99104 function grad (res, θ, p = p)
100105 Enzyme. make_zero! (res)
101- Enzyme. autodiff (Enzyme . Reverse ,
106+ Enzyme. autodiff (rmode ,
102107 Const (firstapply),
103108 Active,
104109 Const (f. f),
@@ -115,7 +120,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
115120 if fg == true && f. fg === nothing
116121 function fg! (res, θ, p = p)
117122 Enzyme. make_zero! (res)
118- y = Enzyme. autodiff (Enzyme . ReverseWithPrimal ,
123+ y = Enzyme. autodiff (WithPrimal (rmode) ,
119124 Const (firstapply),
120125 Active,
121126 Const (f. f),
@@ -145,8 +150,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
145150 Enzyme. make_zero! (bθ)
146151 Enzyme. make_zero! .(vdbθ)
147152
148- Enzyme. autodiff (Enzyme . Forward ,
153+ Enzyme. autodiff (fmode ,
149154 inner_grad,
155+ Const (rmode),
150156 Enzyme. BatchDuplicated (θ, vdθ),
151157 Enzyme. BatchDuplicatedNoNeed (bθ, vdbθ),
152158 Const (f. f),
@@ -168,8 +174,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
168174 vdθ = Tuple ((Array (r) for r in eachrow (I (length (θ)) * one (eltype (θ)))))
169175 vdbθ = Tuple (zeros (eltype (θ), length (θ)) for i in eachindex (θ))
170176
171- Enzyme. autodiff (Enzyme . Forward ,
177+ Enzyme. autodiff (fmode ,
172178 inner_grad,
179+ Const (rmode)
173180 Enzyme. BatchDuplicated (θ, vdθ),
174181 Enzyme. BatchDuplicatedNoNeed (G, vdbθ),
175182 Const (f. f),
@@ -189,7 +196,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
189196 if hv == true && f. hv === nothing
190197 function hv! (H, θ, v, p = p)
191198 H .= Enzyme. autodiff (
192- Enzyme . Forward , hv_f2_alloc, Duplicated (θ, v),
199+ fmode , hv_f2_alloc, Const (rmode) , Duplicated (θ, v),
193200 Const (f. f), Const (p)
194201 )[1 ]
195202 end
@@ -221,7 +228,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
221228 Enzyme. make_zero! (Jaccache[i])
222229 end
223230 Enzyme. make_zero! (y)
224- Enzyme. autodiff (Enzyme . Forward , f. cons, BatchDuplicated (y, Jaccache),
231+ Enzyme. autodiff (fmode , f. cons, BatchDuplicated (y, Jaccache),
225232 BatchDuplicated (θ, seeds), Const (p))
226233 for i in eachindex (θ)
227234 if J isa Vector
@@ -254,7 +261,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
254261 Enzyme. make_zero! (res)
255262 Enzyme. make_zero! (cons_res)
256263
257- Enzyme. autodiff (Enzyme . Reverse ,
264+ Enzyme. autodiff (rmode ,
258265 f. cons,
259266 Const,
260267 Duplicated (cons_res, v),
@@ -275,7 +282,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
275282 Enzyme. make_zero! (res)
276283 Enzyme. make_zero! (cons_res)
277284
278- Enzyme. autodiff (Enzyme . Forward ,
285+ Enzyme. autodiff (fmode ,
279286 f. cons,
280287 Duplicated (cons_res, res),
281288 Duplicated (θ, v),
@@ -297,8 +304,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
297304 for i in 1 : num_cons
298305 Enzyme. make_zero! (cons_bθ)
299306 Enzyme. make_zero! .(cons_vdbθ)
300- Enzyme. autodiff (Enzyme . Forward ,
307+ Enzyme. autodiff (fmode ,
301308 cons_f2,
309+ Const (rmode),
302310 Enzyme. BatchDuplicated (θ, cons_vdθ),
303311 Enzyme. BatchDuplicated (cons_bθ, cons_vdbθ),
304312 Const (f. cons),
@@ -332,8 +340,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
332340 Enzyme. make_zero! (lag_bθ)
333341 Enzyme. make_zero! .(lag_vdbθ)
334342
335- Enzyme. autodiff (Enzyme . Forward ,
343+ Enzyme. autodiff (fmode ,
336344 lag_grad,
345+ Const (rmode),
337346 Enzyme. BatchDuplicated (θ, lag_vdθ),
338347 Enzyme. BatchDuplicatedNoNeed (lag_bθ, lag_vdbθ),
339348 Const (lagrangian),
@@ -357,8 +366,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
357366 Enzyme. make_zero! (lag_bθ)
358367 Enzyme. make_zero! .(lag_vdbθ)
359368
360- Enzyme. autodiff (Enzyme . Forward ,
369+ Enzyme. autodiff (fmode ,
361370 lag_grad,
371+ Const (rmode),
362372 Enzyme. BatchDuplicated (θ, lag_vdθ),
363373 Enzyme. BatchDuplicatedNoNeed (lag_bθ, lag_vdbθ),
364374 Const (lagrangian),
@@ -410,11 +420,23 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
410420 g = false , h = false , hv = false , fg = false , fgh = false ,
411421 cons_j = false , cons_vjp = false , cons_jvp = false , cons_h = false ,
412422 lag_h = false )
423+ rmode = if adtype. mode isa Nothing
424+ Enzyme. Reverse
425+ else
426+ set_runtime_activity2 (Enzyme. Reverse)
427+ end
428+
429+ fmode = if adtype. mode isa Nothing
430+ Enzyme. Forward
431+ else
432+ set_runtime_activity2 (Enzyme. Forward)
433+ end
434+
413435 if g == true && f. grad === nothing
414436 res = zeros (eltype (x), size (x))
415437 function grad (θ, p = p)
416438 Enzyme. make_zero! (res)
417- Enzyme. autodiff (Enzyme . Reverse ,
439+ Enzyme. autodiff (rmode ,
418440 Const (firstapply),
419441 Active,
420442 Const (f. f),
@@ -433,7 +455,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
433455 res_fg = zeros (eltype (x), size (x))
434456 function fg! (θ, p = p)
435457 Enzyme. make_zero! (res_fg)
436- y = Enzyme. autodiff (Enzyme . ReverseWithPrimal ,
458+ y = Enzyme. autodiff (WithPrimal (rmode) ,
437459 Const (firstapply),
438460 Active,
439461 Const (f. f),
@@ -457,8 +479,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
457479 Enzyme. make_zero! (bθ)
458480 Enzyme. make_zero! .(vdbθ)
459481
460- Enzyme. autodiff (Enzyme . Forward ,
482+ Enzyme. autodiff (fmode ,
461483 inner_grad,
484+ Const (rmode),
462485 Enzyme. BatchDuplicated (θ, vdθ),
463486 Enzyme. BatchDuplicated (bθ, vdbθ),
464487 Const (f. f),
@@ -485,8 +508,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
485508 Enzyme. make_zero! (H_fgh)
486509 Enzyme. make_zero! .(vdbθ_fgh)
487510
488- Enzyme. autodiff (Enzyme . Forward ,
511+ Enzyme. autodiff (fmode ,
489512 inner_grad,
513+ Const (rmode),
490514 Enzyme. BatchDuplicated (θ, vdθ_fgh),
491515 Enzyme. BatchDuplicatedNoNeed (G_fgh, vdbθ_fgh),
492516 Const (f. f),
@@ -507,7 +531,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
507531 if hv == true && f. hv === nothing
508532 function hv! (θ, v, p = p)
509533 return Enzyme. autodiff (
510- Enzyme . Forward , hv_f2_alloc, DuplicatedNoNeed, Duplicated (θ, v),
534+ fmode , hv_f2_alloc, DuplicatedNoNeed, Const (rmode) , Duplicated (θ, v),
511535 Const (_f), Const (f. f), Const (p)
512536 )[1 ]
513537 end
@@ -533,7 +557,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
533557 for i in eachindex (Jaccache)
534558 Enzyme. make_zero! (Jaccache[i])
535559 end
536- Jaccache, y = Enzyme. autodiff (Enzyme . ForwardWithPrimal , f. cons, Duplicated,
560+ Jaccache, y = Enzyme. autodiff (WithPrimal (fmode) , f. cons, Duplicated,
537561 BatchDuplicated (θ, seeds), Const (p))
538562 if size (y, 1 ) == 1
539563 return reduce (vcat, Jaccache)
@@ -555,7 +579,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
555579 Enzyme. make_zero! (res_vjp)
556580 Enzyme. make_zero! (cons_vjp_res)
557581
558- Enzyme. autodiff (Enzyme . Reverse ,
582+ Enzyme. autodiff (WithPrimal (rmode) ,
559583 f. cons,
560584 Const,
561585 Duplicated (cons_vjp_res, v),
@@ -578,7 +602,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
578602 Enzyme. make_zero! (res_jvp)
579603 Enzyme. make_zero! (cons_jvp_res)
580604
581- Enzyme. autodiff (Enzyme . Forward ,
605+ Enzyme. autodiff (fmode ,
582606 f. cons,
583607 Duplicated (cons_jvp_res, res_jvp),
584608 Duplicated (θ, v),
@@ -601,8 +625,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
601625 return map (1 : num_cons) do i
602626 Enzyme. make_zero! (cons_bθ)
603627 Enzyme. make_zero! .(cons_vdbθ)
604- Enzyme. autodiff (Enzyme . Forward ,
628+ Enzyme. autodiff (fmode ,
605629 cons_f2_oop,
630+ Const (rmode),
606631 Enzyme. BatchDuplicated (θ, cons_vdθ),
607632 Enzyme. BatchDuplicated (cons_bθ, cons_vdbθ),
608633 Const (f. cons),
@@ -631,8 +656,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
631656 Enzyme. make_zero! (lag_bθ)
632657 Enzyme. make_zero! .(lag_vdbθ)
633658
634- Enzyme. autodiff (Enzyme . Forward ,
659+ Enzyme. autodiff (fmode ,
635660 lag_grad,
661+ Const (rmode),
636662 Enzyme. BatchDuplicated (θ, lag_vdθ),
637663 Enzyme. BatchDuplicatedNoNeed (lag_bθ, lag_vdbθ),
638664 Const (lagrangian),
0 commit comments