Skip to content
This repository was archived by the owner on Aug 25, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ jobs:
fail-fast: false
matrix:
version:
- '1.10'
- '1'
os:
- ubuntu-latest
Expand Down
1 change: 0 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ uuid = "bca83a33-5cc9-4baa-983d-23429ab6bcbb"
authors = ["Vaibhav Dixit <[email protected]> and contributors"]
version = "2.3.0"


[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
Expand Down
109 changes: 68 additions & 41 deletions ext/OptimizationEnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ using Core: Vararg
end
end

function inner_grad(θ, bθ, f, p)
Enzyme.autodiff_deferred(Enzyme.Reverse,
function inner_grad(mode::Mode, θ, bθ, f, p) where {Mode}
Enzyme.autodiff(mode,
Const(firstapply),
Active,
Const(f),
Expand All @@ -28,19 +28,9 @@ function inner_grad(θ, bθ, f, p)
return nothing
end

function inner_grad_primal(θ, bθ, f, p)
Enzyme.autodiff_deferred(Enzyme.ReverseWithPrimal,
Const(firstapply),
Active,
Const(f),
Enzyme.Duplicated(θ, bθ),
Const(p)
)[2]
end

function hv_f2_alloc(x, f, p)
function hv_f2_alloc(mode::Mode, x, f, p) where {Mode}
dx = Enzyme.make_zero(x)
Enzyme.autodiff_deferred(Enzyme.Reverse,
Enzyme.autodiff(mode,
Const(firstapply),
Active,
Const(f),
Expand All @@ -57,9 +47,9 @@ function inner_cons(x, fcons::Function, p::Union{SciMLBase.NullParameters, Nothi
return res[i]
end

function cons_f2(x, dx, fcons, p, num_cons, i)
function cons_f2(mode, x, dx, fcons, p, num_cons, i)
Enzyme.autodiff_deferred(
Enzyme.Reverse, Const(inner_cons), Active, Enzyme.Duplicated(x, dx),
mode, Const(inner_cons), Active, Enzyme.Duplicated(x, dx),
Const(fcons), Const(p), Const(num_cons), Const(i))
return nothing
end
Expand All @@ -70,9 +60,9 @@ function inner_cons_oop(
return fcons(x, p)[i]
end

function cons_f2_oop(x, dx, fcons, p, i)
function cons_f2_oop(mode, x, dx, fcons, p, i)
Enzyme.autodiff_deferred(
Enzyme.Reverse, Const(inner_cons_oop), Active, Enzyme.Duplicated(x, dx),
mode, Const(inner_cons_oop), Active, Enzyme.Duplicated(x, dx),
Const(fcons), Const(p), Const(i))
return nothing
end
Expand All @@ -83,22 +73,38 @@ function lagrangian(x, _f::Function, cons::Function, p, λ, σ = one(eltype(x)))
return σ * _f(x, p) + dot(λ, res)
end

function lag_grad(x, dx, lagrangian::Function, _f::Function, cons::Function, p, σ, λ)
function lag_grad(mode, x, dx, lagrangian::Function, _f::Function, cons::Function, p, σ, λ)
Enzyme.autodiff_deferred(
Enzyme.Reverse, Const(lagrangian), Active, Enzyme.Duplicated(x, dx),
mode, Const(lagrangian), Active, Enzyme.Duplicated(x, dx),
Const(_f), Const(cons), Const(p), Const(λ), Const(σ))
return nothing
end

function set_runtime_activity2(
a::Mode1, ::Enzyme.Mode{ABI, Err, RTA}) where {Mode1, ABI, Err, RTA}
Enzyme.set_runtime_activity(a, RTA)
end
function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
adtype::AutoEnzyme, p, num_cons = 0;
g = false, h = false, hv = false, fg = false, fgh = false,
cons_j = false, cons_vjp = false, cons_jvp = false, cons_h = false,
lag_h = false)
rmode = if adtype.mode isa Nothing
Enzyme.Reverse
else
set_runtime_activity2(Enzyme.Reverse, adtype.mode)
end

fmode = if adtype.mode isa Nothing
Enzyme.Forward
else
set_runtime_activity2(Enzyme.Forward, adtype.mode)
end

if g == true && f.grad === nothing
function grad(res, θ, p = p)
Enzyme.make_zero!(res)
Enzyme.autodiff(Enzyme.Reverse,
Enzyme.autodiff(rmode,
Const(firstapply),
Active,
Const(f.f),
Expand All @@ -115,7 +121,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
if fg == true && f.fg === nothing
function fg!(res, θ, p = p)
Enzyme.make_zero!(res)
y = Enzyme.autodiff(Enzyme.ReverseWithPrimal,
y = Enzyme.autodiff(WithPrimal(rmode),
Const(firstapply),
Active,
Const(f.f),
Expand Down Expand Up @@ -145,8 +151,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
Enzyme.make_zero!(bθ)
Enzyme.make_zero!.(vdbθ)

Enzyme.autodiff(Enzyme.Forward,
Enzyme.autodiff(fmode,
inner_grad,
Const(rmode),
Enzyme.BatchDuplicated(θ, vdθ),
Enzyme.BatchDuplicatedNoNeed(bθ, vdbθ),
Const(f.f),
Expand All @@ -168,8 +175,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
vdθ = Tuple((Array(r) for r in eachrow(I(length(θ)) * one(eltype(θ)))))
vdbθ = Tuple(zeros(eltype(θ), length(θ)) for i in eachindex(θ))

Enzyme.autodiff(Enzyme.Forward,
Enzyme.autodiff(fmode,
inner_grad,
Const(rmode),
Enzyme.BatchDuplicated(θ, vdθ),
Enzyme.BatchDuplicatedNoNeed(G, vdbθ),
Const(f.f),
Expand All @@ -189,7 +197,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
if hv == true && f.hv === nothing
function hv!(H, θ, v, p = p)
H .= Enzyme.autodiff(
Enzyme.Forward, hv_f2_alloc, Duplicated(θ, v),
fmode, hv_f2_alloc, Const(rmode), Duplicated(θ, v),
Const(f.f), Const(p)
)[1]
end
Expand Down Expand Up @@ -221,7 +229,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
Enzyme.make_zero!(Jaccache[i])
end
Enzyme.make_zero!(y)
Enzyme.autodiff(Enzyme.Forward, f.cons, BatchDuplicated(y, Jaccache),
Enzyme.autodiff(fmode, f.cons, BatchDuplicated(y, Jaccache),
BatchDuplicated(θ, seeds), Const(p))
for i in eachindex(θ)
if J isa Vector
Expand Down Expand Up @@ -254,7 +262,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
Enzyme.make_zero!(res)
Enzyme.make_zero!(cons_res)

Enzyme.autodiff(Enzyme.Reverse,
Enzyme.autodiff(rmode,
f.cons,
Const,
Duplicated(cons_res, v),
Expand All @@ -275,7 +283,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
Enzyme.make_zero!(res)
Enzyme.make_zero!(cons_res)

Enzyme.autodiff(Enzyme.Forward,
Enzyme.autodiff(fmode,
f.cons,
Duplicated(cons_res, res),
Duplicated(θ, v),
Expand All @@ -297,8 +305,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
for i in 1:num_cons
Enzyme.make_zero!(cons_bθ)
Enzyme.make_zero!.(cons_vdbθ)
Enzyme.autodiff(Enzyme.Forward,
Enzyme.autodiff(fmode,
cons_f2,
Const(rmode),
Enzyme.BatchDuplicated(θ, cons_vdθ),
Enzyme.BatchDuplicated(cons_bθ, cons_vdbθ),
Const(f.cons),
Expand Down Expand Up @@ -332,8 +341,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
Enzyme.make_zero!(lag_bθ)
Enzyme.make_zero!.(lag_vdbθ)

Enzyme.autodiff(Enzyme.Forward,
Enzyme.autodiff(fmode,
lag_grad,
Const(rmode),
Enzyme.BatchDuplicated(θ, lag_vdθ),
Enzyme.BatchDuplicatedNoNeed(lag_bθ, lag_vdbθ),
Const(lagrangian),
Expand All @@ -357,8 +367,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
Enzyme.make_zero!(lag_bθ)
Enzyme.make_zero!.(lag_vdbθ)

Enzyme.autodiff(Enzyme.Forward,
Enzyme.autodiff(fmode,
lag_grad,
Const(rmode),
Enzyme.BatchDuplicated(θ, lag_vdθ),
Enzyme.BatchDuplicatedNoNeed(lag_bθ, lag_vdbθ),
Const(lagrangian),
Expand Down Expand Up @@ -410,11 +421,23 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
g = false, h = false, hv = false, fg = false, fgh = false,
cons_j = false, cons_vjp = false, cons_jvp = false, cons_h = false,
lag_h = false)
rmode = if adtype.mode isa Nothing
Enzyme.Reverse
else
set_runtime_activity2(Enzyme.Reverse, adtype.mode)
end

fmode = if adtype.mode isa Nothing
Enzyme.Forward
else
set_runtime_activity2(Enzyme.Forward, adtype.mode)
end

if g == true && f.grad === nothing
res = zeros(eltype(x), size(x))
function grad(θ, p = p)
Enzyme.make_zero!(res)
Enzyme.autodiff(Enzyme.Reverse,
Enzyme.autodiff(rmode,
Const(firstapply),
Active,
Const(f.f),
Expand All @@ -433,7 +456,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
res_fg = zeros(eltype(x), size(x))
function fg!(θ, p = p)
Enzyme.make_zero!(res_fg)
y = Enzyme.autodiff(Enzyme.ReverseWithPrimal,
y = Enzyme.autodiff(WithPrimal(rmode),
Const(firstapply),
Active,
Const(f.f),
Expand All @@ -457,8 +480,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
Enzyme.make_zero!(bθ)
Enzyme.make_zero!.(vdbθ)

Enzyme.autodiff(Enzyme.Forward,
Enzyme.autodiff(fmode,
inner_grad,
Const(rmode),
Enzyme.BatchDuplicated(θ, vdθ),
Enzyme.BatchDuplicated(bθ, vdbθ),
Const(f.f),
Expand All @@ -485,8 +509,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
Enzyme.make_zero!(H_fgh)
Enzyme.make_zero!.(vdbθ_fgh)

Enzyme.autodiff(Enzyme.Forward,
Enzyme.autodiff(fmode,
inner_grad,
Const(rmode),
Enzyme.BatchDuplicated(θ, vdθ_fgh),
Enzyme.BatchDuplicatedNoNeed(G_fgh, vdbθ_fgh),
Const(f.f),
Expand All @@ -507,7 +532,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
if hv == true && f.hv === nothing
function hv!(θ, v, p = p)
return Enzyme.autodiff(
Enzyme.Forward, hv_f2_alloc, DuplicatedNoNeed, Duplicated(θ, v),
fmode, hv_f2_alloc, DuplicatedNoNeed, Const(rmode), Duplicated(θ, v),
Const(_f), Const(f.f), Const(p)
)[1]
end
Expand All @@ -533,7 +558,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
for i in eachindex(Jaccache)
Enzyme.make_zero!(Jaccache[i])
end
Jaccache, y = Enzyme.autodiff(Enzyme.ForwardWithPrimal, f.cons, Duplicated,
Jaccache, y = Enzyme.autodiff(WithPrimal(fmode), f.cons, Duplicated,
BatchDuplicated(θ, seeds), Const(p))
if size(y, 1) == 1
return reduce(vcat, Jaccache)
Expand All @@ -555,7 +580,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
Enzyme.make_zero!(res_vjp)
Enzyme.make_zero!(cons_vjp_res)

Enzyme.autodiff(Enzyme.Reverse,
Enzyme.autodiff(WithPrimal(rmode),
f.cons,
Const,
Duplicated(cons_vjp_res, v),
Expand All @@ -578,7 +603,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
Enzyme.make_zero!(res_jvp)
Enzyme.make_zero!(cons_jvp_res)

Enzyme.autodiff(Enzyme.Forward,
Enzyme.autodiff(fmode,
f.cons,
Duplicated(cons_jvp_res, res_jvp),
Duplicated(θ, v),
Expand All @@ -601,8 +626,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
return map(1:num_cons) do i
Enzyme.make_zero!(cons_bθ)
Enzyme.make_zero!.(cons_vdbθ)
Enzyme.autodiff(Enzyme.Forward,
Enzyme.autodiff(fmode,
cons_f2_oop,
Const(rmode),
Enzyme.BatchDuplicated(θ, cons_vdθ),
Enzyme.BatchDuplicated(cons_bθ, cons_vdbθ),
Const(f.cons),
Expand Down Expand Up @@ -631,8 +657,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
Enzyme.make_zero!(lag_bθ)
Enzyme.make_zero!.(lag_vdbθ)

Enzyme.autodiff(Enzyme.Forward,
Enzyme.autodiff(fmode,
lag_grad,
Const(rmode),
Enzyme.BatchDuplicated(θ, lag_vdθ),
Enzyme.BatchDuplicatedNoNeed(lag_bθ, lag_vdbθ),
Const(lagrangian),
Expand Down
23 changes: 12 additions & 11 deletions test/adtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1172,17 +1172,18 @@ using MLUtils

optf = OptimizationFunction(loss, AutoEnzyme())
optf = OptimizationBase.instantiate_function(
optf, rand(3), AutoEnzyme(), iterate(data)[1], g = true, fg = true)
optf, rand(3), AutoEnzyme(mode = set_runtime_activity(Reverse)),
iterate(data)[1], g = true, fg = true)
G0 = zeros(3)
@test_broken optf.grad(G0, ones(3), (x, y))
optf.grad(G0, ones(3), (x0, y0))
stochgrads = []
# for (x,y) in data
# G = zeros(3)
# optf.grad(G, ones(3), (x,y))
# push!(stochgrads, copy(G))
# G1 = zeros(3)
# optf.fg(G1, ones(3), (x,y))
# @test GG1 rtol=1e-6
# end
# @test G0sum(stochgrads)/length(stochgrads) rtol=1e-1
for (x, y) in data
G = zeros(3)
optf.grad(G, ones(3), (x, y))
push!(stochgrads, copy(G))
G1 = zeros(3)
optf.fg(G1, ones(3), (x, y))
@test GG1 rtol=1e-6
end
@test G0sum(stochgrads) rtol=1e-1
end
Loading