Skip to content
This repository was archived by the owner on Aug 25, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,11 @@ OptimizationReverseDiffExt = "ReverseDiff"
OptimizationZygoteExt = "Zygote"

[compat]
ADTypes = "1.5"
ADTypes = "1.9"
ArrayInterface = "7.6"
DifferentiationInterface = "0.5"
DifferentiationInterface = "0.6.1"
DocStringExtensions = "0.9"
Enzyme = "0.12.12"
Enzyme = "0.13.2"
FastClosures = "0.3"
FiniteDiff = "2.12"
ForwardDiff = "0.10.26"
Expand Down
92 changes: 46 additions & 46 deletions src/OptimizationDIExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ function instantiate_function(
adtype, soadtype = generate_adtype(adtype)

if g == true && f.grad === nothing
extras_grad = prepare_gradient(_f, adtype, x)
prep_grad = prepare_gradient(_f, adtype, x)
function grad(res, θ)
gradient!(_f, res, adtype, θ, extras_grad)
gradient!(_f, res, prep_grad, adtype, θ)
end
if p !== SciMLBase.NullParameters() && p !== nothing
function grad(res, θ, p)
Expand All @@ -57,10 +57,10 @@ function instantiate_function(

if fg == true && f.fg === nothing
if g == false
extras_grad = prepare_gradient(_f, adtype, x)
prep_grad = prepare_gradient(_f, adtype, x)
end
function fg!(res, θ)
(y, _) = value_and_gradient!(_f, res, adtype, θ, extras_grad)
(y, _) = value_and_gradient!(_f, res, prep_grad, adtype, θ)
return y
end
if p !== SciMLBase.NullParameters() && p !== nothing
Expand All @@ -79,9 +79,9 @@ function instantiate_function(
hess_sparsity = f.hess_prototype
hess_colors = f.hess_colorvec
if h == true && f.hess === nothing
extras_hess = prepare_hessian(_f, soadtype, x)
prep_hess = prepare_hessian(_f, soadtype, x)
function hess(res, θ)
hessian!(_f, res, soadtype, θ, extras_hess)
hessian!(_f, res, prep_hess, soadtype, θ)
end
if p !== SciMLBase.NullParameters() && p !== nothing
function hess(res, θ, p)
Expand All @@ -98,7 +98,7 @@ function instantiate_function(
if fgh == true && f.fgh === nothing
function fgh!(G, H, θ)
(y, _, _) = value_derivative_and_second_derivative!(
_f, G, H, soadtype, θ, extras_hess)
_f, G, H, prep_hess, soadtype, θ)
return y
end
if p !== SciMLBase.NullParameters() && p !== nothing
Expand All @@ -116,14 +116,14 @@ function instantiate_function(
end

if hv == true && f.hv === nothing
extras_hvp = prepare_hvp(_f, soadtype, x, zeros(eltype(x), size(x)))
prep_hvp = prepare_hvp(_f, soadtype, x, (zeros(eltype(x), size(x)),))
function hv!(H, θ, v)
hvp!(_f, H, soadtype, θ, v, extras_hvp)
only(hvp!(_f, (H,), prep_hvp, soadtype, θ, (v,)))
end
if p !== SciMLBase.NullParameters() && p !== nothing
function hv!(H, θ, v, p)
global _p = p
hvp!(_f, H, soadtype, θ, v)
only(hvp!(_f, (H,), soadtype, θ, (v,)))
end
end
elseif hv == true
Expand Down Expand Up @@ -156,9 +156,9 @@ function instantiate_function(
cons_jac_prototype = f.cons_jac_prototype
cons_jac_colorvec = f.cons_jac_colorvec
if cons !== nothing && cons_j == true && f.cons_j === nothing
extras_jac = prepare_jacobian(cons_oop, adtype, x)
prep_jac = prepare_jacobian(cons_oop, adtype, x)
function cons_j!(J, θ)
jacobian!(cons_oop, J, adtype, θ, extras_jac)
jacobian!(cons_oop, J, prep_jac, adtype, θ)
if size(J, 1) == 1
J = vec(J)
end
Expand All @@ -170,9 +170,9 @@ function instantiate_function(
end

if f.cons_vjp === nothing && cons_vjp == true && cons !== nothing
extras_pullback = prepare_pullback(cons_oop, adtype, x, ones(eltype(x), num_cons))
prep_pullback = prepare_pullback(cons_oop, adtype, x, (ones(eltype(x), num_cons),))
function cons_vjp!(J, θ, v)
pullback!(cons_oop, J, adtype, θ, v, extras_pullback)
only(pullback!(cons_oop, (J,), prep_pullback, adtype, θ, (v,)))
end
elseif cons_vjp == true && cons !== nothing
cons_vjp! = (J, θ, v) -> f.cons_vjp(J, θ, v, p)
Expand All @@ -181,10 +181,10 @@ function instantiate_function(
end

if f.cons_jvp === nothing && cons_jvp == true && cons !== nothing
extras_pushforward = prepare_pushforward(
cons_oop, adtype, x, ones(eltype(x), length(x)))
prep_pushforward = prepare_pushforward(
cons_oop, adtype, x, (ones(eltype(x), length(x)),))
function cons_jvp!(J, θ, v)
pushforward!(cons_oop, J, adtype, θ, v, extras_pushforward)
only(pushforward!(cons_oop, (J,), prep_pushforward, adtype, θ, (v,)))
end
elseif cons_jvp == true && cons !== nothing
cons_jvp! = (J, θ, v) -> f.cons_jvp(J, θ, v, p)
Expand All @@ -196,11 +196,11 @@ function instantiate_function(
conshess_colors = f.cons_hess_colorvec
if cons !== nothing && f.cons_h === nothing && cons_h == true
fncs = [(x) -> cons_oop(x)[i] for i in 1:num_cons]
extras_cons_hess = prepare_hessian.(fncs, Ref(soadtype), Ref(x))
prep_cons_hess = prepare_hessian.(fncs, Ref(soadtype), Ref(x))

function cons_h!(H, θ)
for i in 1:num_cons
hessian!(fncs[i], H[i], soadtype, θ, extras_cons_hess[i])
hessian!(fncs[i], H[i], prep_cons_hess[i], soadtype, θ)
end
end
elseif cons_h == true && cons !== nothing
Expand All @@ -212,7 +212,7 @@ function instantiate_function(
lag_hess_prototype = f.lag_hess_prototype

if cons !== nothing && lag_h == true && f.lag_h === nothing
lag_extras = prepare_hessian(
lag_prep = prepare_hessian(
lagrangian, soadtype, vcat(x, [one(eltype(x))], ones(eltype(x), num_cons)))
lag_hess_prototype = zeros(Bool, length(x) + num_cons + 1, length(x) + num_cons + 1)

Expand All @@ -221,13 +221,13 @@ function instantiate_function(
cons_h(H, θ)
H *= λ
else
H .= @view(hessian(lagrangian, soadtype, vcat(θ, [σ], λ), lag_extras)[
H .= @view(hessian(lagrangian, lag_prep, soadtype, vcat(θ, [σ], λ))[
1:length(θ), 1:length(θ)])
end
end

function lag_h!(h::AbstractVector, θ, σ, λ)
H = hessian(lagrangian, soadtype, vcat(θ, [σ], λ), lag_extras)
H = hessian(lagrangian, lag_prep, soadtype, vcat(θ, [σ], λ))
k = 0
for i in 1:length(θ)
for j in 1:i
Expand All @@ -244,14 +244,14 @@ function instantiate_function(
H *= λ
else
global _p = p
H .= @view(hessian(lagrangian, soadtype, vcat(θ, [σ], λ), lag_extras)[
H .= @view(hessian(lagrangian, lag_prep, soadtype, vcat(θ, [σ], λ))[
1:length(θ), 1:length(θ)])
end
end

function lag_h!(h::AbstractVector, θ, σ, λ, p)
global _p = p
H = hessian(lagrangian, soadtype, vcat(θ, [σ], λ), lag_extras)
H = hessian(lagrangian, lag_prep, soadtype, vcat(θ, [σ], λ))
k = 0
for i in 1:length(θ)
for j in 1:i
Expand Down Expand Up @@ -308,9 +308,9 @@ function instantiate_function(
adtype, soadtype = generate_adtype(adtype)

if g == true && f.grad === nothing
extras_grad = prepare_gradient(_f, adtype, x)
prep_grad = prepare_gradient(_f, adtype, x)
function grad(θ)
gradient(_f, adtype, θ, extras_grad)
gradient(_f, prep_grad, adtype, θ)
end
if p !== SciMLBase.NullParameters() && p !== nothing
function grad(θ, p)
Expand All @@ -326,10 +326,10 @@ function instantiate_function(

if fg == true && f.fg === nothing
if g == false
extras_grad = prepare_gradient(_f, adtype, x)
prep_grad = prepare_gradient(_f, adtype, x)
end
function fg!(θ)
(y, res) = value_and_gradient(_f, adtype, θ, extras_grad)
(y, res) = value_and_gradient(_f, prep_grad, adtype, θ)
return y, res
end
if p !== SciMLBase.NullParameters() && p !== nothing
Expand All @@ -348,9 +348,9 @@ function instantiate_function(
hess_sparsity = f.hess_prototype
hess_colors = f.hess_colorvec
if h == true && f.hess === nothing
extras_hess = prepare_hessian(_f, soadtype, x)
prep_hess = prepare_hessian(_f, soadtype, x)
function hess(θ)
hessian(_f, soadtype, θ, extras_hess)
hessian(_f, prep_hess, soadtype, θ)
end
if p !== SciMLBase.NullParameters() && p !== nothing
function hess(θ, p)
Expand All @@ -366,7 +366,7 @@ function instantiate_function(

if fgh == true && f.fgh === nothing
function fgh!(θ)
(y, G, H) = value_derivative_and_second_derivative(_f, adtype, θ, extras_hess)
(y, G, H) = value_derivative_and_second_derivative(_f, prep_hess, adtype, θ)
return y, G, H
end
if p !== SciMLBase.NullParameters() && p !== nothing
Expand All @@ -383,14 +383,14 @@ function instantiate_function(
end

if hv == true && f.hv === nothing
extras_hvp = prepare_hvp(_f, soadtype, x, zeros(eltype(x), size(x)))
prep_hvp = prepare_hvp(_f, soadtype, x, (zeros(eltype(x), size(x)),))
function hv!(θ, v)
hvp(_f, soadtype, θ, v, extras_hvp)
only(hvp(_f, prep_hvp, soadtype, θ, (v)))
end
if p !== SciMLBase.NullParameters() && p !== nothing
function hv!(θ, v, p)
global _p = p
hvp(_f, soadtype, θ, v, extras_hvp)
only(vp(_f, prep_hvp, soadtype, θ, (v,)))
end
end
elseif hv == true
Expand All @@ -417,9 +417,9 @@ function instantiate_function(
cons_jac_prototype = f.cons_jac_prototype
cons_jac_colorvec = f.cons_jac_colorvec
if cons !== nothing && cons_j == true && f.cons_j === nothing
extras_jac = prepare_jacobian(cons, adtype, x)
prep_jac = prepare_jacobian(cons, adtype, x)
function cons_j!(θ)
J = jacobian(cons, adtype, θ, extras_jac)
J = jacobian(cons, prep_jac, adtype, θ)
if size(J, 1) == 1
J = vec(J)
end
Expand All @@ -432,9 +432,9 @@ function instantiate_function(
end

if f.cons_vjp === nothing && cons_vjp == true && cons !== nothing
extras_pullback = prepare_pullback(cons, adtype, x, ones(eltype(x), num_cons))
prep_pullback = prepare_pullback(cons, adtype, x, (ones(eltype(x), num_cons),))
function cons_vjp!(θ, v)
return pullback(cons, adtype, θ, v, extras_pullback)
return only(pullback(cons, prep_pullback, adtype, θ, (v,)))
end
elseif cons_vjp == true && cons !== nothing
cons_vjp! = (θ, v) -> f.cons_vjp(θ, v, p)
Expand All @@ -443,10 +443,10 @@ function instantiate_function(
end

if f.cons_jvp === nothing && cons_jvp == true && cons !== nothing
extras_pushforward = prepare_pushforward(
cons, adtype, x, ones(eltype(x), length(x)))
prep_pushforward = prepare_pushforward(
cons, adtype, x, (ones(eltype(x), length(x)),))
function cons_jvp!(θ, v)
return pushforward(cons, adtype, θ, v, extras_pushforward)
return only(pushforward(cons, prep_pushforward, adtype, θ, (v,)))
end
elseif cons_jvp == true && cons !== nothing
cons_jvp! = (θ, v) -> f.cons_jvp(θ, v, p)
Expand All @@ -458,11 +458,11 @@ function instantiate_function(
conshess_colors = f.cons_hess_colorvec
if cons !== nothing && cons_h == true && f.cons_h === nothing
fncs = [(x) -> cons(x)[i] for i in 1:num_cons]
extras_cons_hess = prepare_hessian.(fncs, Ref(soadtype), Ref(x))
prep_cons_hess = prepare_hessian.(fncs, Ref(soadtype), Ref(x))

function cons_h!(θ)
H = map(1:num_cons) do i
hessian(fncs[i], soadtype, θ, extras_cons_hess[i])
hessian(fncs[i], prep_cons_hess[i], soadtype, θ)
end
return H
end
Expand All @@ -475,15 +475,15 @@ function instantiate_function(
lag_hess_prototype = f.lag_hess_prototype

if cons !== nothing && lag_h == true && f.lag_h === nothing
lag_extras = prepare_hessian(
lag_prep = prepare_hessian(
lagrangian, soadtype, vcat(x, [one(eltype(x))], ones(eltype(x), num_cons)))
lag_hess_prototype = zeros(Bool, length(x) + num_cons + 1, length(x) + num_cons + 1)

function lag_h!(θ, σ, λ)
if σ == zero(eltype(θ))
return λ .* cons_h(θ)
else
return hessian(lagrangian, soadtype, vcat(θ, [σ], λ), lag_extras)[
return hessian(lagrangian, lag_prep, soadtype, vcat(θ, [σ], λ))[
1:length(θ), 1:length(θ)]
end
end
Expand All @@ -494,7 +494,7 @@ function instantiate_function(
return λ .* cons_h(θ)
else
global _p = p
return hessian(lagrangian, soadtype, vcat(θ, [σ], λ), lag_extras)[
return hessian(lagrangian, lag_prep, soadtype, vcat(θ, [σ], λ))[
1:length(θ), 1:length(θ)]
end
end
Expand Down
Loading
Loading