Skip to content
This repository was archived by the owner on Aug 25, 2025. It is now read-only.
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 67 additions & 41 deletions ext/OptimizationEnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ using Core: Vararg
end
end

function inner_grad(θ, bθ, f, p)
Enzyme.autodiff_deferred(Enzyme.Reverse,
function inner_grad(mode::Mode, θ, bθ, f, p) where Mode
Enzyme.autodiff(Mode,
Const(firstapply),
Active,
Const(f),
Expand All @@ -28,19 +28,9 @@ function inner_grad(θ, bθ, f, p)
return nothing
end

function inner_grad_primal(θ, bθ, f, p)
Enzyme.autodiff_deferred(Enzyme.ReverseWithPrimal,
Const(firstapply),
Active,
Const(f),
Enzyme.Duplicated(θ, bθ),
Const(p)
)[2]
end

function hv_f2_alloc(x, f, p)
function hv_f2_alloc(mode::Mode, x, f, p) where Mode
dx = Enzyme.make_zero(x)
Enzyme.autodiff_deferred(Enzyme.Reverse,
Enzyme.autodiff(mode,
Const(firstapply),
Active,
Const(f),
Expand All @@ -57,9 +47,9 @@ function inner_cons(x, fcons::Function, p::Union{SciMLBase.NullParameters, Nothi
return res[i]
end

function cons_f2(x, dx, fcons, p, num_cons, i)
function cons_f2(mode, x, dx, fcons, p, num_cons, i)
Enzyme.autodiff_deferred(
Enzyme.Reverse, Const(inner_cons), Active, Enzyme.Duplicated(x, dx),
mode, Const(inner_cons), Active, Enzyme.Duplicated(x, dx),
Const(fcons), Const(p), Const(num_cons), Const(i))
return nothing
end
Expand All @@ -70,9 +60,9 @@ function inner_cons_oop(
return fcons(x, p)[i]
end

function cons_f2_oop(x, dx, fcons, p, i)
function cons_f2_oop(mode, x, dx, fcons, p, i)
Enzyme.autodiff_deferred(
Enzyme.Reverse, Const(inner_cons_oop), Active, Enzyme.Duplicated(x, dx),
mode, Const(inner_cons_oop), Active, Enzyme.Duplicated(x, dx),
Const(fcons), Const(p), Const(i))
return nothing
end
Expand All @@ -83,22 +73,37 @@ function lagrangian(x, _f::Function, cons::Function, p, λ, σ = one(eltype(x)))
return σ * _f(x, p) + dot(λ, res)
end

function lag_grad(x, dx, lagrangian::Function, _f::Function, cons::Function, p, σ, λ)
function lag_grad(mode, x, dx, lagrangian::Function, _f::Function, cons::Function, p, σ, λ)
Enzyme.autodiff_deferred(
Enzyme.Reverse, Const(lagrangian), Active, Enzyme.Duplicated(x, dx),
mode, Const(lagrangian), Active, Enzyme.Duplicated(x, dx),
Const(_f), Const(cons), Const(p), Const(λ), Const(σ))
return nothing
end


set_runtime_activity2(a::Mode1, ::Enzyme.Mode{ABI, Err, RTA}) where {Mode1, ABI, Err, RTA} = Enzyme.set_runtime_activity(a, RTA)
function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
adtype::AutoEnzyme, p, num_cons = 0;
g = false, h = false, hv = false, fg = false, fgh = false,
cons_j = false, cons_vjp = false, cons_jvp = false, cons_h = false,
lag_h = false)

rmode = if adtype.mode isa Nothing
Enzyme.Reverse
else
set_runtime_activity2(Enzyme.Reverse)
end

fmode = if adtype.mode isa Nothing
Enzyme.Forward
else
set_runtime_activity2(Enzyme.Forward)
end

if g == true && f.grad === nothing
function grad(res, θ, p = p)
Enzyme.make_zero!(res)
Enzyme.autodiff(Enzyme.Reverse,
Enzyme.autodiff(rmode,
Const(firstapply),
Active,
Const(f.f),
Expand All @@ -115,7 +120,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
if fg == true && f.fg === nothing
function fg!(res, θ, p = p)
Enzyme.make_zero!(res)
y = Enzyme.autodiff(Enzyme.ReverseWithPrimal,
y = Enzyme.autodiff(WithPrimal(rmode),
Const(firstapply),
Active,
Const(f.f),
Expand Down Expand Up @@ -145,8 +150,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
Enzyme.make_zero!(bθ)
Enzyme.make_zero!.(vdbθ)

Enzyme.autodiff(Enzyme.Forward,
Enzyme.autodiff(fmode,
inner_grad,
Const(rmode),
Enzyme.BatchDuplicated(θ, vdθ),
Enzyme.BatchDuplicatedNoNeed(bθ, vdbθ),
Const(f.f),
Expand All @@ -168,8 +174,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
vdθ = Tuple((Array(r) for r in eachrow(I(length(θ)) * one(eltype(θ)))))
vdbθ = Tuple(zeros(eltype(θ), length(θ)) for i in eachindex(θ))

Enzyme.autodiff(Enzyme.Forward,
Enzyme.autodiff(fmode,
inner_grad,
Const(rmode),
Enzyme.BatchDuplicated(θ, vdθ),
Enzyme.BatchDuplicatedNoNeed(G, vdbθ),
Const(f.f),
Expand All @@ -189,7 +196,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
if hv == true && f.hv === nothing
function hv!(H, θ, v, p = p)
H .= Enzyme.autodiff(
Enzyme.Forward, hv_f2_alloc, Duplicated(θ, v),
fmode, hv_f2_alloc, Const(rmode), Duplicated(θ, v),
Const(f.f), Const(p)
)[1]
end
Expand Down Expand Up @@ -221,7 +228,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
Enzyme.make_zero!(Jaccache[i])
end
Enzyme.make_zero!(y)
Enzyme.autodiff(Enzyme.Forward, f.cons, BatchDuplicated(y, Jaccache),
Enzyme.autodiff(fmode, f.cons, BatchDuplicated(y, Jaccache),
BatchDuplicated(θ, seeds), Const(p))
for i in eachindex(θ)
if J isa Vector
Expand Down Expand Up @@ -254,7 +261,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
Enzyme.make_zero!(res)
Enzyme.make_zero!(cons_res)

Enzyme.autodiff(Enzyme.Reverse,
Enzyme.autodiff(rmode,
f.cons,
Const,
Duplicated(cons_res, v),
Expand All @@ -275,7 +282,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
Enzyme.make_zero!(res)
Enzyme.make_zero!(cons_res)

Enzyme.autodiff(Enzyme.Forward,
Enzyme.autodiff(fmode,
f.cons,
Duplicated(cons_res, res),
Duplicated(θ, v),
Expand All @@ -297,8 +304,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
for i in 1:num_cons
Enzyme.make_zero!(cons_bθ)
Enzyme.make_zero!.(cons_vdbθ)
Enzyme.autodiff(Enzyme.Forward,
Enzyme.autodiff(fmode,
cons_f2,
Const(rmode),
Enzyme.BatchDuplicated(θ, cons_vdθ),
Enzyme.BatchDuplicated(cons_bθ, cons_vdbθ),
Const(f.cons),
Expand Down Expand Up @@ -332,8 +340,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
Enzyme.make_zero!(lag_bθ)
Enzyme.make_zero!.(lag_vdbθ)

Enzyme.autodiff(Enzyme.Forward,
Enzyme.autodiff(fmode,
lag_grad,
Const(rmode),
Enzyme.BatchDuplicated(θ, lag_vdθ),
Enzyme.BatchDuplicatedNoNeed(lag_bθ, lag_vdbθ),
Const(lagrangian),
Expand All @@ -357,8 +366,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
Enzyme.make_zero!(lag_bθ)
Enzyme.make_zero!.(lag_vdbθ)

Enzyme.autodiff(Enzyme.Forward,
Enzyme.autodiff(fmode,
lag_grad,
Const(rmode),
Enzyme.BatchDuplicated(θ, lag_vdθ),
Enzyme.BatchDuplicatedNoNeed(lag_bθ, lag_vdbθ),
Const(lagrangian),
Expand Down Expand Up @@ -410,11 +420,23 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
g = false, h = false, hv = false, fg = false, fgh = false,
cons_j = false, cons_vjp = false, cons_jvp = false, cons_h = false,
lag_h = false)
rmode = if adtype.mode isa Nothing
Enzyme.Reverse
else
set_runtime_activity2(Enzyme.Reverse)
end

fmode = if adtype.mode isa Nothing
Enzyme.Forward
else
set_runtime_activity2(Enzyme.Forward)
end

if g == true && f.grad === nothing
res = zeros(eltype(x), size(x))
function grad(θ, p = p)
Enzyme.make_zero!(res)
Enzyme.autodiff(Enzyme.Reverse,
Enzyme.autodiff(rmode,
Const(firstapply),
Active,
Const(f.f),
Expand All @@ -433,7 +455,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
res_fg = zeros(eltype(x), size(x))
function fg!(θ, p = p)
Enzyme.make_zero!(res_fg)
y = Enzyme.autodiff(Enzyme.ReverseWithPrimal,
y = Enzyme.autodiff(WithPrimal(rmode),
Const(firstapply),
Active,
Const(f.f),
Expand All @@ -457,8 +479,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
Enzyme.make_zero!(bθ)
Enzyme.make_zero!.(vdbθ)

Enzyme.autodiff(Enzyme.Forward,
Enzyme.autodiff(fmode,
inner_grad,
Const(rmode),
Enzyme.BatchDuplicated(θ, vdθ),
Enzyme.BatchDuplicated(bθ, vdbθ),
Const(f.f),
Expand All @@ -485,8 +508,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
Enzyme.make_zero!(H_fgh)
Enzyme.make_zero!.(vdbθ_fgh)

Enzyme.autodiff(Enzyme.Forward,
Enzyme.autodiff(fmode,
inner_grad,
Const(rmode),
Enzyme.BatchDuplicated(θ, vdθ_fgh),
Enzyme.BatchDuplicatedNoNeed(G_fgh, vdbθ_fgh),
Const(f.f),
Expand All @@ -507,7 +531,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
if hv == true && f.hv === nothing
function hv!(θ, v, p = p)
return Enzyme.autodiff(
Enzyme.Forward, hv_f2_alloc, DuplicatedNoNeed, Duplicated(θ, v),
fmode, hv_f2_alloc, DuplicatedNoNeed, Const(rmode), Duplicated(θ, v),
Const(_f), Const(f.f), Const(p)
)[1]
end
Expand All @@ -533,7 +557,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
for i in eachindex(Jaccache)
Enzyme.make_zero!(Jaccache[i])
end
Jaccache, y = Enzyme.autodiff(Enzyme.ForwardWithPrimal, f.cons, Duplicated,
Jaccache, y = Enzyme.autodiff(WithPrimal(fmode), f.cons, Duplicated,
BatchDuplicated(θ, seeds), Const(p))
if size(y, 1) == 1
return reduce(vcat, Jaccache)
Expand All @@ -555,7 +579,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
Enzyme.make_zero!(res_vjp)
Enzyme.make_zero!(cons_vjp_res)

Enzyme.autodiff(Enzyme.Reverse,
Enzyme.autodiff(WithPrimal(rmode),
f.cons,
Const,
Duplicated(cons_vjp_res, v),
Expand All @@ -578,7 +602,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
Enzyme.make_zero!(res_jvp)
Enzyme.make_zero!(cons_jvp_res)

Enzyme.autodiff(Enzyme.Forward,
Enzyme.autodiff(fmode,
f.cons,
Duplicated(cons_jvp_res, res_jvp),
Duplicated(θ, v),
Expand All @@ -601,8 +625,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
return map(1:num_cons) do i
Enzyme.make_zero!(cons_bθ)
Enzyme.make_zero!.(cons_vdbθ)
Enzyme.autodiff(Enzyme.Forward,
Enzyme.autodiff(fmode,
cons_f2_oop,
Const(rmode),
Enzyme.BatchDuplicated(θ, cons_vdθ),
Enzyme.BatchDuplicated(cons_bθ, cons_vdbθ),
Const(f.cons),
Expand Down Expand Up @@ -631,8 +656,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
Enzyme.make_zero!(lag_bθ)
Enzyme.make_zero!.(lag_vdbθ)

Enzyme.autodiff(Enzyme.Forward,
Enzyme.autodiff(fmode,
lag_grad,
Const(rmode),
Enzyme.BatchDuplicated(θ, lag_vdθ),
Enzyme.BatchDuplicatedNoNeed(lag_bθ, lag_vdbθ),
Const(lagrangian),
Expand Down
Loading