Skip to content
This repository was archived by the owner on Aug 25, 2025. It is now read-only.

Commit 917b21b

Browse files
committed
Enzyme: add runtime activity
1 parent 6043b0c commit 917b21b

File tree

1 file changed

+67
-41
lines changed

1 file changed

+67
-41
lines changed

ext/OptimizationEnzymeExt.jl

Lines changed: 67 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ using Core: Vararg
1717
end
1818
end
1919

20-
function inner_grad(θ, bθ, f, p)
21-
Enzyme.autodiff_deferred(Enzyme.Reverse,
20+
function inner_grad(mode::Mode, θ, bθ, f, p) where Mode
21+
Enzyme.autodiff(mode,
2222
Const(firstapply),
2323
Active,
2424
Const(f),
@@ -28,19 +28,9 @@ function inner_grad(θ, bθ, f, p)
2828
return nothing
2929
end
3030

31-
function inner_grad_primal(θ, bθ, f, p)
32-
Enzyme.autodiff_deferred(Enzyme.ReverseWithPrimal,
33-
Const(firstapply),
34-
Active,
35-
Const(f),
36-
Enzyme.Duplicated(θ, bθ),
37-
Const(p)
38-
)[2]
39-
end
40-
41-
function hv_f2_alloc(x, f, p)
31+
function hv_f2_alloc(mode::Mode, x, f, p) where Mode
4232
dx = Enzyme.make_zero(x)
43-
Enzyme.autodiff_deferred(Enzyme.Reverse,
33+
Enzyme.autodiff(mode,
4434
Const(firstapply),
4535
Active,
4636
Const(f),
@@ -57,9 +47,9 @@ function inner_cons(x, fcons::Function, p::Union{SciMLBase.NullParameters, Nothi
5747
return res[i]
5848
end
5949

60-
function cons_f2(x, dx, fcons, p, num_cons, i)
50+
function cons_f2(mode, x, dx, fcons, p, num_cons, i)
6151
Enzyme.autodiff_deferred(
62-
Enzyme.Reverse, Const(inner_cons), Active, Enzyme.Duplicated(x, dx),
52+
mode, Const(inner_cons), Active, Enzyme.Duplicated(x, dx),
6353
Const(fcons), Const(p), Const(num_cons), Const(i))
6454
return nothing
6555
end
@@ -70,9 +60,9 @@ function inner_cons_oop(
7060
return fcons(x, p)[i]
7161
end
7262

73-
function cons_f2_oop(x, dx, fcons, p, i)
63+
function cons_f2_oop(mode, x, dx, fcons, p, i)
7464
Enzyme.autodiff_deferred(
75-
Enzyme.Reverse, Const(inner_cons_oop), Active, Enzyme.Duplicated(x, dx),
65+
mode, Const(inner_cons_oop), Active, Enzyme.Duplicated(x, dx),
7666
Const(fcons), Const(p), Const(i))
7767
return nothing
7868
end
@@ -83,22 +73,37 @@ function lagrangian(x, _f::Function, cons::Function, p, λ, σ = one(eltype(x)))
8373
return σ * _f(x, p) + dot(λ, res)
8474
end
8575

86-
function lag_grad(x, dx, lagrangian::Function, _f::Function, cons::Function, p, σ, λ)
76+
function lag_grad(mode, x, dx, lagrangian::Function, _f::Function, cons::Function, p, σ, λ)
8777
Enzyme.autodiff_deferred(
88-
Enzyme.Reverse, Const(lagrangian), Active, Enzyme.Duplicated(x, dx),
78+
mode, Const(lagrangian), Active, Enzyme.Duplicated(x, dx),
8979
Const(_f), Const(cons), Const(p), Const(λ), Const(σ))
9080
return nothing
9181
end
9282

83+
84+
# Apply the third type parameter (`RTA`) of the given `Enzyme.Mode` to mode `a`
# by delegating to `Enzyme.set_runtime_activity`. Only the type of the second
# argument is used; its value is ignored.
function set_runtime_activity2(
        a::Mode1, ::Enzyme.Mode{ABI, Err, RTA}) where {Mode1, ABI, Err, RTA}
    return Enzyme.set_runtime_activity(a, RTA)
end
9385
function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
9486
adtype::AutoEnzyme, p, num_cons = 0;
9587
g = false, h = false, hv = false, fg = false, fgh = false,
9688
cons_j = false, cons_vjp = false, cons_jvp = false, cons_h = false,
9789
lag_h = false)
90+
91+
rmode = if adtype.mode isa Nothing
92+
Enzyme.Reverse
93+
else
94+
set_runtime_activity2(Enzyme.Reverse, adtype.mode)
95+
end
96+
97+
fmode = if adtype.mode isa Nothing
98+
Enzyme.Forward
99+
else
100+
set_runtime_activity2(Enzyme.Forward, adtype.mode)
101+
end
102+
98103
if g == true && f.grad === nothing
99104
function grad(res, θ, p = p)
100105
Enzyme.make_zero!(res)
101-
Enzyme.autodiff(Enzyme.Reverse,
106+
Enzyme.autodiff(rmode,
102107
Const(firstapply),
103108
Active,
104109
Const(f.f),
@@ -115,7 +120,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
115120
if fg == true && f.fg === nothing
116121
function fg!(res, θ, p = p)
117122
Enzyme.make_zero!(res)
118-
y = Enzyme.autodiff(Enzyme.ReverseWithPrimal,
123+
y = Enzyme.autodiff(WithPrimal(rmode),
119124
Const(firstapply),
120125
Active,
121126
Const(f.f),
@@ -145,8 +150,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
145150
Enzyme.make_zero!(bθ)
146151
Enzyme.make_zero!.(vdbθ)
147152

148-
Enzyme.autodiff(Enzyme.Forward,
153+
Enzyme.autodiff(fmode,
149154
inner_grad,
155+
Const(rmode),
150156
Enzyme.BatchDuplicated(θ, vdθ),
151157
Enzyme.BatchDuplicatedNoNeed(bθ, vdbθ),
152158
Const(f.f),
@@ -168,8 +174,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
168174
vdθ = Tuple((Array(r) for r in eachrow(I(length(θ)) * one(eltype(θ)))))
169175
vdbθ = Tuple(zeros(eltype(θ), length(θ)) for i in eachindex(θ))
170176

171-
Enzyme.autodiff(Enzyme.Forward,
177+
Enzyme.autodiff(fmode,
172178
inner_grad,
179+
Const(rmode),
173180
Enzyme.BatchDuplicated(θ, vdθ),
174181
Enzyme.BatchDuplicatedNoNeed(G, vdbθ),
175182
Const(f.f),
@@ -189,7 +196,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
189196
if hv == true && f.hv === nothing
190197
function hv!(H, θ, v, p = p)
191198
H .= Enzyme.autodiff(
192-
Enzyme.Forward, hv_f2_alloc, Duplicated(θ, v),
199+
fmode, hv_f2_alloc, Const(rmode), Duplicated(θ, v),
193200
Const(f.f), Const(p)
194201
)[1]
195202
end
@@ -221,7 +228,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
221228
Enzyme.make_zero!(Jaccache[i])
222229
end
223230
Enzyme.make_zero!(y)
224-
Enzyme.autodiff(Enzyme.Forward, f.cons, BatchDuplicated(y, Jaccache),
231+
Enzyme.autodiff(fmode, f.cons, BatchDuplicated(y, Jaccache),
225232
BatchDuplicated(θ, seeds), Const(p))
226233
for i in eachindex(θ)
227234
if J isa Vector
@@ -254,7 +261,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
254261
Enzyme.make_zero!(res)
255262
Enzyme.make_zero!(cons_res)
256263

257-
Enzyme.autodiff(Enzyme.Reverse,
264+
Enzyme.autodiff(rmode,
258265
f.cons,
259266
Const,
260267
Duplicated(cons_res, v),
@@ -275,7 +282,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
275282
Enzyme.make_zero!(res)
276283
Enzyme.make_zero!(cons_res)
277284

278-
Enzyme.autodiff(Enzyme.Forward,
285+
Enzyme.autodiff(fmode,
279286
f.cons,
280287
Duplicated(cons_res, res),
281288
Duplicated(θ, v),
@@ -297,8 +304,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
297304
for i in 1:num_cons
298305
Enzyme.make_zero!(cons_bθ)
299306
Enzyme.make_zero!.(cons_vdbθ)
300-
Enzyme.autodiff(Enzyme.Forward,
307+
Enzyme.autodiff(fmode,
301308
cons_f2,
309+
Const(rmode),
302310
Enzyme.BatchDuplicated(θ, cons_vdθ),
303311
Enzyme.BatchDuplicated(cons_bθ, cons_vdbθ),
304312
Const(f.cons),
@@ -332,8 +340,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
332340
Enzyme.make_zero!(lag_bθ)
333341
Enzyme.make_zero!.(lag_vdbθ)
334342

335-
Enzyme.autodiff(Enzyme.Forward,
343+
Enzyme.autodiff(fmode,
336344
lag_grad,
345+
Const(rmode),
337346
Enzyme.BatchDuplicated(θ, lag_vdθ),
338347
Enzyme.BatchDuplicatedNoNeed(lag_bθ, lag_vdbθ),
339348
Const(lagrangian),
@@ -357,8 +366,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
357366
Enzyme.make_zero!(lag_bθ)
358367
Enzyme.make_zero!.(lag_vdbθ)
359368

360-
Enzyme.autodiff(Enzyme.Forward,
369+
Enzyme.autodiff(fmode,
361370
lag_grad,
371+
Const(rmode),
362372
Enzyme.BatchDuplicated(θ, lag_vdθ),
363373
Enzyme.BatchDuplicatedNoNeed(lag_bθ, lag_vdbθ),
364374
Const(lagrangian),
@@ -410,11 +420,23 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
410420
g = false, h = false, hv = false, fg = false, fgh = false,
411421
cons_j = false, cons_vjp = false, cons_jvp = false, cons_h = false,
412422
lag_h = false)
423+
rmode = if adtype.mode isa Nothing
424+
Enzyme.Reverse
425+
else
426+
set_runtime_activity2(Enzyme.Reverse, adtype.mode)
427+
end
428+
429+
fmode = if adtype.mode isa Nothing
430+
Enzyme.Forward
431+
else
432+
set_runtime_activity2(Enzyme.Forward, adtype.mode)
433+
end
434+
413435
if g == true && f.grad === nothing
414436
res = zeros(eltype(x), size(x))
415437
function grad(θ, p = p)
416438
Enzyme.make_zero!(res)
417-
Enzyme.autodiff(Enzyme.Reverse,
439+
Enzyme.autodiff(rmode,
418440
Const(firstapply),
419441
Active,
420442
Const(f.f),
@@ -433,7 +455,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
433455
res_fg = zeros(eltype(x), size(x))
434456
function fg!(θ, p = p)
435457
Enzyme.make_zero!(res_fg)
436-
y = Enzyme.autodiff(Enzyme.ReverseWithPrimal,
458+
y = Enzyme.autodiff(WithPrimal(rmode),
437459
Const(firstapply),
438460
Active,
439461
Const(f.f),
@@ -457,8 +479,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
457479
Enzyme.make_zero!(bθ)
458480
Enzyme.make_zero!.(vdbθ)
459481

460-
Enzyme.autodiff(Enzyme.Forward,
482+
Enzyme.autodiff(fmode,
461483
inner_grad,
484+
Const(rmode),
462485
Enzyme.BatchDuplicated(θ, vdθ),
463486
Enzyme.BatchDuplicated(bθ, vdbθ),
464487
Const(f.f),
@@ -485,8 +508,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
485508
Enzyme.make_zero!(H_fgh)
486509
Enzyme.make_zero!.(vdbθ_fgh)
487510

488-
Enzyme.autodiff(Enzyme.Forward,
511+
Enzyme.autodiff(fmode,
489512
inner_grad,
513+
Const(rmode),
490514
Enzyme.BatchDuplicated(θ, vdθ_fgh),
491515
Enzyme.BatchDuplicatedNoNeed(G_fgh, vdbθ_fgh),
492516
Const(f.f),
@@ -507,7 +531,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
507531
if hv == true && f.hv === nothing
508532
function hv!(θ, v, p = p)
509533
return Enzyme.autodiff(
510-
Enzyme.Forward, hv_f2_alloc, DuplicatedNoNeed, Duplicated(θ, v),
534+
fmode, hv_f2_alloc, DuplicatedNoNeed, Const(rmode), Duplicated(θ, v),
511535
Const(f.f), Const(p)
512536
)[1]
513537
end
@@ -533,7 +557,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
533557
for i in eachindex(Jaccache)
534558
Enzyme.make_zero!(Jaccache[i])
535559
end
536-
Jaccache, y = Enzyme.autodiff(Enzyme.ForwardWithPrimal, f.cons, Duplicated,
560+
Jaccache, y = Enzyme.autodiff(WithPrimal(fmode), f.cons, Duplicated,
537561
BatchDuplicated(θ, seeds), Const(p))
538562
if size(y, 1) == 1
539563
return reduce(vcat, Jaccache)
@@ -555,7 +579,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
555579
Enzyme.make_zero!(res_vjp)
556580
Enzyme.make_zero!(cons_vjp_res)
557581

558-
Enzyme.autodiff(Enzyme.Reverse,
582+
Enzyme.autodiff(WithPrimal(rmode),
559583
f.cons,
560584
Const,
561585
Duplicated(cons_vjp_res, v),
@@ -578,7 +602,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
578602
Enzyme.make_zero!(res_jvp)
579603
Enzyme.make_zero!(cons_jvp_res)
580604

581-
Enzyme.autodiff(Enzyme.Forward,
605+
Enzyme.autodiff(fmode,
582606
f.cons,
583607
Duplicated(cons_jvp_res, res_jvp),
584608
Duplicated(θ, v),
@@ -601,8 +625,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
601625
return map(1:num_cons) do i
602626
Enzyme.make_zero!(cons_bθ)
603627
Enzyme.make_zero!.(cons_vdbθ)
604-
Enzyme.autodiff(Enzyme.Forward,
628+
Enzyme.autodiff(fmode,
605629
cons_f2_oop,
630+
Const(rmode),
606631
Enzyme.BatchDuplicated(θ, cons_vdθ),
607632
Enzyme.BatchDuplicated(cons_bθ, cons_vdbθ),
608633
Const(f.cons),
@@ -631,8 +656,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
631656
Enzyme.make_zero!(lag_bθ)
632657
Enzyme.make_zero!.(lag_vdbθ)
633658

634-
Enzyme.autodiff(Enzyme.Forward,
659+
Enzyme.autodiff(fmode,
635660
lag_grad,
661+
Const(rmode),
636662
Enzyme.BatchDuplicated(θ, lag_vdθ),
637663
Enzyme.BatchDuplicatedNoNeed(lag_bθ, lag_vdbθ),
638664
Const(lagrangian),

0 commit comments

Comments
 (0)