diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 932b4d9..8793288 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -23,6 +23,7 @@ jobs: fail-fast: false matrix: version: + - '1.10' - '1' os: - ubuntu-latest diff --git a/Project.toml b/Project.toml index 7f3356c..05bcd60 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,6 @@ uuid = "bca83a33-5cc9-4baa-983d-23429ab6bcbb" authors = ["Vaibhav Dixit and contributors"] version = "2.3.0" - [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" diff --git a/ext/OptimizationEnzymeExt.jl b/ext/OptimizationEnzymeExt.jl index 7adae59..a19bfb9 100644 --- a/ext/OptimizationEnzymeExt.jl +++ b/ext/OptimizationEnzymeExt.jl @@ -17,8 +17,8 @@ using Core: Vararg end end -function inner_grad(θ, bθ, f, p) - Enzyme.autodiff_deferred(Enzyme.Reverse, +function inner_grad(mode::Mode, θ, bθ, f, p) where {Mode} + Enzyme.autodiff(mode, Const(firstapply), Active, Const(f), @@ -28,19 +28,9 @@ function inner_grad(θ, bθ, f, p) return nothing end -function inner_grad_primal(θ, bθ, f, p) - Enzyme.autodiff_deferred(Enzyme.ReverseWithPrimal, - Const(firstapply), - Active, - Const(f), - Enzyme.Duplicated(θ, bθ), - Const(p) - )[2] -end - -function hv_f2_alloc(x, f, p) +function hv_f2_alloc(mode::Mode, x, f, p) where {Mode} dx = Enzyme.make_zero(x) - Enzyme.autodiff_deferred(Enzyme.Reverse, + Enzyme.autodiff(mode, Const(firstapply), Active, Const(f), @@ -57,9 +47,9 @@ function inner_cons(x, fcons::Function, p::Union{SciMLBase.NullParameters, Nothi return res[i] end -function cons_f2(x, dx, fcons, p, num_cons, i) +function cons_f2(mode, x, dx, fcons, p, num_cons, i) Enzyme.autodiff_deferred( - Enzyme.Reverse, Const(inner_cons), Active, Enzyme.Duplicated(x, dx), + mode, Const(inner_cons), Active, Enzyme.Duplicated(x, dx), Const(fcons), Const(p), Const(num_cons), Const(i)) return nothing end @@ -70,9 +60,9 @@ function inner_cons_oop( return fcons(x, p)[i] end -function cons_f2_oop(x, dx, fcons, p, i) +function cons_f2_oop(mode, x, dx, fcons, p, i) Enzyme.autodiff_deferred( - Enzyme.Reverse, Const(inner_cons_oop), Active, Enzyme.Duplicated(x, dx), + mode, Const(inner_cons_oop), Active, Enzyme.Duplicated(x, dx), Const(fcons), Const(p), Const(i)) return nothing end @@ -83,22 +73,38 @@ function lagrangian(x, _f::Function, cons::Function, p, λ, σ = one(eltype(x))) return σ * _f(x, p) + dot(λ, res) end -function lag_grad(x, dx, lagrangian::Function, _f::Function, cons::Function, p, σ, λ) +function lag_grad(mode, x, dx, lagrangian::Function, _f::Function, cons::Function, p, σ, λ) Enzyme.autodiff_deferred( - Enzyme.Reverse, Const(lagrangian), Active, Enzyme.Duplicated(x, dx), + mode, Const(lagrangian), Active, Enzyme.Duplicated(x, dx), Const(_f), Const(cons), Const(p), Const(λ), Const(σ)) return nothing end +function set_runtime_activity2( + a::Mode1, ::Enzyme.Mode{ABI, Err, RTA}) where {Mode1, ABI, Err, RTA} + Enzyme.set_runtime_activity(a, RTA) +end function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, adtype::AutoEnzyme, p, num_cons = 0; g = false, h = false, hv = false, fg = false, fgh = false, cons_j = false, cons_vjp = false, cons_jvp = false, cons_h = false, lag_h = false) + rmode = if adtype.mode isa Nothing + Enzyme.Reverse + else + set_runtime_activity2(Enzyme.Reverse, adtype.mode) + end + + fmode = if adtype.mode isa Nothing + Enzyme.Forward + else + set_runtime_activity2(Enzyme.Forward, adtype.mode) + end + if g == true && f.grad === nothing function grad(res, θ, p = p) Enzyme.make_zero!(res) - Enzyme.autodiff(Enzyme.Reverse, + Enzyme.autodiff(rmode, Const(firstapply), Active, Const(f.f), @@ -115,7 +121,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, if fg == true && f.fg === nothing function fg!(res, θ, p = p) Enzyme.make_zero!(res) - y = Enzyme.autodiff(Enzyme.ReverseWithPrimal, + y = Enzyme.autodiff(WithPrimal(rmode), Const(firstapply), Active, Const(f.f), @@ -145,8 +151,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, Enzyme.make_zero!(bθ) Enzyme.make_zero!.(vdbθ) - Enzyme.autodiff(Enzyme.Forward, + Enzyme.autodiff(fmode, inner_grad, + Const(rmode), Enzyme.BatchDuplicated(θ, vdθ), Enzyme.BatchDuplicatedNoNeed(bθ, vdbθ), Const(f.f), @@ -168,8 +175,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, vdθ = Tuple((Array(r) for r in eachrow(I(length(θ)) * one(eltype(θ))))) vdbθ = Tuple(zeros(eltype(θ), length(θ)) for i in eachindex(θ)) - Enzyme.autodiff(Enzyme.Forward, + Enzyme.autodiff(fmode, inner_grad, + Const(rmode), Enzyme.BatchDuplicated(θ, vdθ), Enzyme.BatchDuplicatedNoNeed(G, vdbθ), Const(f.f), @@ -189,7 +197,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, if hv == true && f.hv === nothing function hv!(H, θ, v, p = p) H .= Enzyme.autodiff( - Enzyme.Forward, hv_f2_alloc, Duplicated(θ, v), + fmode, hv_f2_alloc, Const(rmode), Duplicated(θ, v), Const(f.f), Const(p) )[1] end @@ -221,7 +229,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, Enzyme.make_zero!(Jaccache[i]) end Enzyme.make_zero!(y) - Enzyme.autodiff(Enzyme.Forward, f.cons, BatchDuplicated(y, Jaccache), + Enzyme.autodiff(fmode, f.cons, BatchDuplicated(y, Jaccache), BatchDuplicated(θ, seeds), Const(p)) for i in eachindex(θ) if J isa Vector @@ -254,7 +262,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, Enzyme.make_zero!(res) Enzyme.make_zero!(cons_res) - Enzyme.autodiff(Enzyme.Reverse, + Enzyme.autodiff(rmode, f.cons, Const, Duplicated(cons_res, v), @@ -275,7 +283,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, Enzyme.make_zero!(res) Enzyme.make_zero!(cons_res) - Enzyme.autodiff(Enzyme.Forward, + Enzyme.autodiff(fmode, f.cons, Duplicated(cons_res, res), Duplicated(θ, v), @@ -297,8 +305,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, for i in 1:num_cons Enzyme.make_zero!(cons_bθ) Enzyme.make_zero!.(cons_vdbθ) - Enzyme.autodiff(Enzyme.Forward, + Enzyme.autodiff(fmode, cons_f2, + Const(rmode), Enzyme.BatchDuplicated(θ, cons_vdθ), Enzyme.BatchDuplicated(cons_bθ, cons_vdbθ), Const(f.cons), @@ -332,8 +341,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, Enzyme.make_zero!(lag_bθ) Enzyme.make_zero!.(lag_vdbθ) - Enzyme.autodiff(Enzyme.Forward, + Enzyme.autodiff(fmode, lag_grad, + Const(rmode), Enzyme.BatchDuplicated(θ, lag_vdθ), Enzyme.BatchDuplicatedNoNeed(lag_bθ, lag_vdbθ), Const(lagrangian), @@ -357,8 +367,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, Enzyme.make_zero!(lag_bθ) Enzyme.make_zero!.(lag_vdbθ) - Enzyme.autodiff(Enzyme.Forward, + Enzyme.autodiff(fmode, lag_grad, + Const(rmode), Enzyme.BatchDuplicated(θ, lag_vdθ), Enzyme.BatchDuplicatedNoNeed(lag_bθ, lag_vdbθ), Const(lagrangian), @@ -410,11 +421,23 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x g = false, h = false, hv = false, fg = false, fgh = false, cons_j = false, cons_vjp = false, cons_jvp = false, cons_h = false, lag_h = false) + rmode = if adtype.mode isa Nothing + Enzyme.Reverse + else + set_runtime_activity2(Enzyme.Reverse, adtype.mode) + end + + fmode = if adtype.mode isa Nothing + Enzyme.Forward + else + set_runtime_activity2(Enzyme.Forward, adtype.mode) + end + if g == true && f.grad === nothing res = zeros(eltype(x), size(x)) function grad(θ, p = p) Enzyme.make_zero!(res) - Enzyme.autodiff(Enzyme.Reverse, + Enzyme.autodiff(rmode, Const(firstapply), Active, Const(f.f), @@ -433,7 +456,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x res_fg = zeros(eltype(x), size(x)) function fg!(θ, p = p) Enzyme.make_zero!(res_fg) - y = Enzyme.autodiff(Enzyme.ReverseWithPrimal, + y = Enzyme.autodiff(WithPrimal(rmode), Const(firstapply), Active, Const(f.f), @@ -457,8 +480,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x Enzyme.make_zero!(bθ) Enzyme.make_zero!.(vdbθ) - Enzyme.autodiff(Enzyme.Forward, + Enzyme.autodiff(fmode, inner_grad, + Const(rmode), Enzyme.BatchDuplicated(θ, vdθ), Enzyme.BatchDuplicated(bθ, vdbθ), Const(f.f), @@ -485,8 +509,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x Enzyme.make_zero!(H_fgh) Enzyme.make_zero!.(vdbθ_fgh) - Enzyme.autodiff(Enzyme.Forward, + Enzyme.autodiff(fmode, inner_grad, + Const(rmode), Enzyme.BatchDuplicated(θ, vdθ_fgh), Enzyme.BatchDuplicatedNoNeed(G_fgh, vdbθ_fgh), Const(f.f), @@ -507,7 +532,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x if hv == true && f.hv === nothing function hv!(θ, v, p = p) return Enzyme.autodiff( - Enzyme.Forward, hv_f2_alloc, DuplicatedNoNeed, Duplicated(θ, v), + fmode, hv_f2_alloc, DuplicatedNoNeed, Const(rmode), Duplicated(θ, v), Const(_f), Const(f.f), Const(p) )[1] end @@ -533,7 +558,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x for i in eachindex(Jaccache) Enzyme.make_zero!(Jaccache[i]) end - Jaccache, y = Enzyme.autodiff(Enzyme.ForwardWithPrimal, f.cons, Duplicated, + Jaccache, y = Enzyme.autodiff(WithPrimal(fmode), f.cons, Duplicated, BatchDuplicated(θ, seeds), Const(p)) if size(y, 1) == 1 return reduce(vcat, Jaccache) @@ -555,7 +580,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x Enzyme.make_zero!(res_vjp) Enzyme.make_zero!(cons_vjp_res) - Enzyme.autodiff(Enzyme.Reverse, + Enzyme.autodiff(WithPrimal(rmode), f.cons, Const, Duplicated(cons_vjp_res, v), @@ -578,7 +603,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x Enzyme.make_zero!(res_jvp) Enzyme.make_zero!(cons_jvp_res) - Enzyme.autodiff(Enzyme.Forward, + Enzyme.autodiff(fmode, f.cons, Duplicated(cons_jvp_res, res_jvp), Duplicated(θ, v), @@ -601,8 +626,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x return map(1:num_cons) do i Enzyme.make_zero!(cons_bθ) Enzyme.make_zero!.(cons_vdbθ) - Enzyme.autodiff(Enzyme.Forward, + Enzyme.autodiff(fmode, cons_f2_oop, + Const(rmode), Enzyme.BatchDuplicated(θ, cons_vdθ), Enzyme.BatchDuplicated(cons_bθ, cons_vdbθ), Const(f.cons), @@ -631,8 +657,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x Enzyme.make_zero!(lag_bθ) Enzyme.make_zero!.(lag_vdbθ) - Enzyme.autodiff(Enzyme.Forward, + Enzyme.autodiff(fmode, lag_grad, + Const(rmode), Enzyme.BatchDuplicated(θ, lag_vdθ), Enzyme.BatchDuplicatedNoNeed(lag_bθ, lag_vdbθ), Const(lagrangian), diff --git a/test/adtests.jl b/test/adtests.jl index fc85ad1..6fe4eea 100644 --- a/test/adtests.jl +++ b/test/adtests.jl @@ -1172,17 +1172,18 @@ using MLUtils optf = OptimizationFunction(loss, AutoEnzyme()) optf = OptimizationBase.instantiate_function( - optf, rand(3), AutoEnzyme(), iterate(data)[1], g = true, fg = true) + optf, rand(3), AutoEnzyme(mode = set_runtime_activity(Reverse)), + iterate(data)[1], g = true, fg = true) G0 = zeros(3) - @test_broken optf.grad(G0, ones(3), (x, y)) + optf.grad(G0, ones(3), (x0, y0)) stochgrads = [] - # for (x,y) in data - # G = zeros(3) - # optf.grad(G, ones(3), (x,y)) - # push!(stochgrads, copy(G)) - # G1 = zeros(3) - # optf.fg(G1, ones(3), (x,y)) - # @test G ≈ G1 rtol=1e-6 - # end - # @test G0 ≈ sum(stochgrads)/length(stochgrads) rtol=1e-1 + for (x, y) in data + G = zeros(3) + optf.grad(G, ones(3), (x, y)) + push!(stochgrads, copy(G)) + G1 = zeros(3) + optf.fg(G1, ones(3), (x, y)) + @test G≈G1 rtol=1e-6 + end + @test G0≈sum(stochgrads) rtol=1e-1 end