diff --git a/ext/OptimizationEnzymeExt.jl b/ext/OptimizationEnzymeExt.jl index 7adae59..ad4ab2b 100644 --- a/ext/OptimizationEnzymeExt.jl +++ b/ext/OptimizationEnzymeExt.jl @@ -17,8 +17,8 @@ using Core: Vararg end end -function inner_grad(θ, bθ, f, p) - Enzyme.autodiff_deferred(Enzyme.Reverse, +function inner_grad(mode::Mode, θ, bθ, f, p) where Mode + Enzyme.autodiff(Mode, Const(firstapply), Active, Const(f), @@ -28,19 +28,9 @@ function inner_grad(θ, bθ, f, p) return nothing end -function inner_grad_primal(θ, bθ, f, p) - Enzyme.autodiff_deferred(Enzyme.ReverseWithPrimal, - Const(firstapply), - Active, - Const(f), - Enzyme.Duplicated(θ, bθ), - Const(p) - )[2] -end - -function hv_f2_alloc(x, f, p) +function hv_f2_alloc(mode::Mode, x, f, p) where Mode dx = Enzyme.make_zero(x) - Enzyme.autodiff_deferred(Enzyme.Reverse, + Enzyme.autodiff(mode, Const(firstapply), Active, Const(f), @@ -57,9 +47,9 @@ function inner_cons(x, fcons::Function, p::Union{SciMLBase.NullParameters, Nothi return res[i] end -function cons_f2(x, dx, fcons, p, num_cons, i) +function cons_f2(mode, x, dx, fcons, p, num_cons, i) Enzyme.autodiff_deferred( - Enzyme.Reverse, Const(inner_cons), Active, Enzyme.Duplicated(x, dx), + mode, Const(inner_cons), Active, Enzyme.Duplicated(x, dx), Const(fcons), Const(p), Const(num_cons), Const(i)) return nothing end @@ -70,9 +60,9 @@ function inner_cons_oop( return fcons(x, p)[i] end -function cons_f2_oop(x, dx, fcons, p, i) +function cons_f2_oop(mode, x, dx, fcons, p, i) Enzyme.autodiff_deferred( - Enzyme.Reverse, Const(inner_cons_oop), Active, Enzyme.Duplicated(x, dx), + mode, Const(inner_cons_oop), Active, Enzyme.Duplicated(x, dx), Const(fcons), Const(p), Const(i)) return nothing end @@ -83,22 +73,37 @@ function lagrangian(x, _f::Function, cons::Function, p, λ, σ = one(eltype(x))) return σ * _f(x, p) + dot(λ, res) end -function lag_grad(x, dx, lagrangian::Function, _f::Function, cons::Function, p, σ, λ) +function lag_grad(mode, x, dx, lagrangian::Function, _f::Function, cons::Function, p, σ, λ) Enzyme.autodiff_deferred( - Enzyme.Reverse, Const(lagrangian), Active, Enzyme.Duplicated(x, dx), + mode, Const(lagrangian), Active, Enzyme.Duplicated(x, dx), Const(_f), Const(cons), Const(p), Const(λ), Const(σ)) return nothing end + +set_runtime_activity2(a::Mode1, ::Enzyme.Mode{ABI, Err, RTA}) where {Mode1, ABI, Err, RTA} = Enzyme.set_runtime_activity(a, RTA) function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, adtype::AutoEnzyme, p, num_cons = 0; g = false, h = false, hv = false, fg = false, fgh = false, cons_j = false, cons_vjp = false, cons_jvp = false, cons_h = false, lag_h = false) + + rmode = if adtype.mode isa Nothing + Enzyme.Reverse + else + set_runtime_activity2(Enzyme.Reverse) + end + + fmode = if adtype.mode isa Nothing + Enzyme.Forward + else + set_runtime_activity2(Enzyme.Forward) + end + if g == true && f.grad === nothing function grad(res, θ, p = p) Enzyme.make_zero!(res) - Enzyme.autodiff(Enzyme.Reverse, + Enzyme.autodiff(rmode, Const(firstapply), Active, Const(f.f), @@ -115,7 +120,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, if fg == true && f.fg === nothing function fg!(res, θ, p = p) Enzyme.make_zero!(res) - y = Enzyme.autodiff(Enzyme.ReverseWithPrimal, + y = Enzyme.autodiff(WithPrimal(rmode), Const(firstapply), Active, Const(f.f), @@ -145,8 +150,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, Enzyme.make_zero!(bθ) Enzyme.make_zero!.(vdbθ) - Enzyme.autodiff(Enzyme.Forward, + Enzyme.autodiff(fmode, inner_grad, + Const(rmode), Enzyme.BatchDuplicated(θ, vdθ), Enzyme.BatchDuplicatedNoNeed(bθ, vdbθ), Const(f.f), @@ -168,8 +174,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, vdθ = Tuple((Array(r) for r in eachrow(I(length(θ)) * one(eltype(θ))))) vdbθ = Tuple(zeros(eltype(θ), length(θ)) for i in eachindex(θ)) - Enzyme.autodiff(Enzyme.Forward, + Enzyme.autodiff(fmode, inner_grad, + Const(rmode), Enzyme.BatchDuplicated(θ, vdθ), Enzyme.BatchDuplicatedNoNeed(G, vdbθ), Const(f.f), @@ -189,7 +196,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, if hv == true && f.hv === nothing function hv!(H, θ, v, p = p) H .= Enzyme.autodiff( - Enzyme.Forward, hv_f2_alloc, Duplicated(θ, v), + fmode, hv_f2_alloc, Const(rmode), Duplicated(θ, v), Const(f.f), Const(p) )[1] end @@ -221,7 +228,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, Enzyme.make_zero!(Jaccache[i]) end Enzyme.make_zero!(y) - Enzyme.autodiff(Enzyme.Forward, f.cons, BatchDuplicated(y, Jaccache), + Enzyme.autodiff(fmode, f.cons, BatchDuplicated(y, Jaccache), BatchDuplicated(θ, seeds), Const(p)) for i in eachindex(θ) if J isa Vector @@ -254,7 +261,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, Enzyme.make_zero!(res) Enzyme.make_zero!(cons_res) - Enzyme.autodiff(Enzyme.Reverse, + Enzyme.autodiff(rmode, f.cons, Const, Duplicated(cons_res, v), @@ -275,7 +282,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, Enzyme.make_zero!(res) Enzyme.make_zero!(cons_res) - Enzyme.autodiff(Enzyme.Forward, + Enzyme.autodiff(fmode, f.cons, Duplicated(cons_res, res), Duplicated(θ, v), @@ -297,8 +304,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, for i in 1:num_cons Enzyme.make_zero!(cons_bθ) Enzyme.make_zero!.(cons_vdbθ) - Enzyme.autodiff(Enzyme.Forward, + Enzyme.autodiff(fmode, cons_f2, + Const(rmode), Enzyme.BatchDuplicated(θ, cons_vdθ), Enzyme.BatchDuplicated(cons_bθ, cons_vdbθ), Const(f.cons), @@ -332,8 +340,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, Enzyme.make_zero!(lag_bθ) Enzyme.make_zero!.(lag_vdbθ) - Enzyme.autodiff(Enzyme.Forward, + Enzyme.autodiff(fmode, lag_grad, + Const(rmode), Enzyme.BatchDuplicated(θ, lag_vdθ), Enzyme.BatchDuplicatedNoNeed(lag_bθ, lag_vdbθ), Const(lagrangian), @@ -357,8 +366,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, Enzyme.make_zero!(lag_bθ) Enzyme.make_zero!.(lag_vdbθ) - Enzyme.autodiff(Enzyme.Forward, + Enzyme.autodiff(fmode, lag_grad, + Const(rmode), Enzyme.BatchDuplicated(θ, lag_vdθ), Enzyme.BatchDuplicatedNoNeed(lag_bθ, lag_vdbθ), Const(lagrangian), @@ -410,11 +420,23 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x g = false, h = false, hv = false, fg = false, fgh = false, cons_j = false, cons_vjp = false, cons_jvp = false, cons_h = false, lag_h = false) + rmode = if adtype.mode isa Nothing + Enzyme.Reverse + else + set_runtime_activity2(Enzyme.Reverse) + end + + fmode = if adtype.mode isa Nothing + Enzyme.Forward + else + set_runtime_activity2(Enzyme.Forward) + end + if g == true && f.grad === nothing res = zeros(eltype(x), size(x)) function grad(θ, p = p) Enzyme.make_zero!(res) - Enzyme.autodiff(Enzyme.Reverse, + Enzyme.autodiff(rmode, Const(firstapply), Active, Const(f.f), @@ -433,7 +455,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x res_fg = zeros(eltype(x), size(x)) function fg!(θ, p = p) Enzyme.make_zero!(res_fg) - y = Enzyme.autodiff(Enzyme.ReverseWithPrimal, + y = Enzyme.autodiff(WithPrimal(rmode), Const(firstapply), Active, Const(f.f), @@ -457,8 +479,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x Enzyme.make_zero!(bθ) Enzyme.make_zero!.(vdbθ) - Enzyme.autodiff(Enzyme.Forward, + Enzyme.autodiff(fmode, inner_grad, + Const(rmode), Enzyme.BatchDuplicated(θ, vdθ), Enzyme.BatchDuplicated(bθ, vdbθ), Const(f.f), @@ -485,8 +508,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x Enzyme.make_zero!(H_fgh) Enzyme.make_zero!.(vdbθ_fgh) - Enzyme.autodiff(Enzyme.Forward, + Enzyme.autodiff(fmode, inner_grad, + Const(rmode), Enzyme.BatchDuplicated(θ, vdθ_fgh), Enzyme.BatchDuplicatedNoNeed(G_fgh, vdbθ_fgh), Const(f.f), @@ -507,7 +531,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x if hv == true && f.hv === nothing function hv!(θ, v, p = p) return Enzyme.autodiff( - Enzyme.Forward, hv_f2_alloc, DuplicatedNoNeed, Duplicated(θ, v), + fmode, hv_f2_alloc, DuplicatedNoNeed, Const(rmode), Duplicated(θ, v), Const(_f), Const(f.f), Const(p) )[1] end @@ -533,7 +557,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x for i in eachindex(Jaccache) Enzyme.make_zero!(Jaccache[i]) end - Jaccache, y = Enzyme.autodiff(Enzyme.ForwardWithPrimal, f.cons, Duplicated, + Jaccache, y = Enzyme.autodiff(WithPrimal(fmode), f.cons, Duplicated, BatchDuplicated(θ, seeds), Const(p)) if size(y, 1) == 1 return reduce(vcat, Jaccache) @@ -555,7 +579,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x Enzyme.make_zero!(res_vjp) Enzyme.make_zero!(cons_vjp_res) - Enzyme.autodiff(Enzyme.Reverse, + Enzyme.autodiff(WithPrimal(rmode), f.cons, Const, Duplicated(cons_vjp_res, v), @@ -578,7 +602,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x Enzyme.make_zero!(res_jvp) Enzyme.make_zero!(cons_jvp_res) - Enzyme.autodiff(Enzyme.Forward, + Enzyme.autodiff(fmode, f.cons, Duplicated(cons_jvp_res, res_jvp), Duplicated(θ, v), @@ -601,8 +625,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x return map(1:num_cons) do i Enzyme.make_zero!(cons_bθ) Enzyme.make_zero!.(cons_vdbθ) - Enzyme.autodiff(Enzyme.Forward, + Enzyme.autodiff(fmode, cons_f2_oop, + Const(rmode), Enzyme.BatchDuplicated(θ, cons_vdθ), Enzyme.BatchDuplicated(cons_bθ, cons_vdbθ), Const(f.cons), @@ -631,8 +656,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x Enzyme.make_zero!(lag_bθ) Enzyme.make_zero!.(lag_vdbθ) - Enzyme.autodiff(Enzyme.Forward, + Enzyme.autodiff(fmode, lag_grad, + Const(rmode), Enzyme.BatchDuplicated(θ, lag_vdθ), Enzyme.BatchDuplicatedNoNeed(lag_bθ, lag_vdbθ), Const(lagrangian),