Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ext/IpoptMathOptInterfaceExt/MOI_wrapper.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1285,6 +1285,7 @@ function _setup_model(model::Optimizer)
eval_grad_f_cb,
eval_jac_g_cb,
has_hessian ? eval_h_cb : nothing,
nothing, # we could use the model in the future but it is a breaking change for the signature of the callback
)
if model.sense == MOI.MIN_SENSE
Ipopt.AddIpoptNumOption(model.inner, "obj_scaling_factor", 1.0)
Expand Down
93 changes: 71 additions & 22 deletions src/C_wrapper.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# Use of this source code is governed by an MIT-style license that can be found
# in the LICENSE.md file or at https://opensource.org/licenses/MIT.

mutable struct IpoptProblem
mutable struct IpoptProblem{M}
ipopt_problem::Ptr{Cvoid} # Reference to the internal data structure
n::Int # Num vars
m::Int # Num cons
Expand All @@ -21,6 +21,8 @@ mutable struct IpoptProblem
eval_jac_g::Function
eval_h::Union{Function,Nothing}
intermediate::Union{Function,Nothing}
# User data
user_model::M
end

Base.unsafe_convert(::Type{Ptr{Cvoid}}, p::IpoptProblem) = p.ipopt_problem
Expand All @@ -37,7 +39,11 @@ function _Eval_F_CB(
if x_new == Cint(1)
prob.x .= x
end
new_obj = convert(Float64, prob.eval_f(x))::Float64
if isnothing(prob.user_model)
new_obj = convert(Float64, prob.eval_f(x))::Float64
else
new_obj = convert(Float64, prob.eval_f(prob.user_model, x))::Float64
end
unsafe_store!(obj_value, new_obj)
return Cint(1)
end
Expand All @@ -53,7 +59,11 @@ function _Eval_Grad_F_CB(
prob = unsafe_pointer_to_objref(user_data)::IpoptProblem
new_grad_f = unsafe_wrap(Array, grad_f, Int(n))
x = unsafe_wrap(Array, x_ptr, Int(n))
prob.eval_grad_f(x, new_grad_f)
if isnothing(prob.user_model)
prob.eval_grad_f(x, new_grad_f)
else
prob.eval_grad_f(prob.user_model, x, new_grad_f)
end
return Cint(1)
end

Expand All @@ -71,7 +81,11 @@ function _Eval_G_CB(
if x_new == Cint(1)
prob.x .= x
end
prob.eval_g(x, new_g)
if isnothing(prob.user_model)
prob.eval_g(x, new_g)
else
prob.eval_g(prob.user_model, x, new_g)
end
return Cint(1)
end

Expand All @@ -91,10 +105,18 @@ function _Eval_Jac_G_CB(
rows = unsafe_wrap(Array, iRow, Int(nele_jac))
cols = unsafe_wrap(Array, jCol, Int(nele_jac))
if values_ptr == C_NULL
prob.eval_jac_g(x, rows, cols, nothing)
if isnothing(prob.user_model)
prob.eval_jac_g(x, rows, cols, nothing)
else
prob.eval_jac_g(prob.user_model, x, rows, cols, nothing)
end
else
values = unsafe_wrap(Array, values_ptr, Int(nele_jac))
prob.eval_jac_g(x, rows, cols, values)
if isnothing(prob.user_model)
prob.eval_jac_g(x, rows, cols, values)
else
prob.eval_jac_g(prob.user_model, x, rows, cols, values)
end
end
return Cint(1)
end
Expand Down Expand Up @@ -123,10 +145,18 @@ function _Eval_H_CB(
rows = unsafe_wrap(Array, iRow, Int(nele_hess))
cols = unsafe_wrap(Array, jCol, Int(nele_hess))
if values_ptr == C_NULL
prob.eval_h(x, rows, cols, obj_factor, lambda, nothing)
if isnothing(prob.user_model)
prob.eval_h(x, rows, cols, obj_factor, lambda, nothing)
else
prob.eval_h(prob.user_model, x, rows, cols, obj_factor, lambda, nothing)
end
else
values = unsafe_wrap(Array, values_ptr, Int(nele_hess))
prob.eval_h(x, rows, cols, obj_factor, lambda, values)
if isnothing(prob.user_model)
prob.eval_h(x, rows, cols, obj_factor, lambda, values)
else
prob.eval_h(prob.user_model, x, rows, cols, obj_factor, lambda, values)
end
end
return Cint(1) # Return TRUE for success.
end
Expand All @@ -148,19 +178,36 @@ function _Intermediate_CB(
try
return reenable_sigint() do
prob = unsafe_pointer_to_objref(user_data)::IpoptProblem
return prob.intermediate(
alg_mod,
iter_count,
obj_value,
inf_pr,
inf_du,
mu,
d_norm,
regularization_size,
alpha_du,
alpha_pr,
ls_trials,
)
if isnothing(prob.user_model)
return prob.intermediate(
alg_mod,
iter_count,
obj_value,
inf_pr,
inf_du,
mu,
d_norm,
regularization_size,
alpha_du,
alpha_pr,
ls_trials,
)
else
return prob.intermediate(
prob.user_model,
alg_mod,
iter_count,
obj_value,
inf_pr,
inf_du,
mu,
d_norm,
regularization_size,
alpha_du,
alpha_pr,
ls_trials,
)
end
end
catch err
if !(err isa InterruptException)
Expand All @@ -184,6 +231,7 @@ function CreateIpoptProblem(
eval_grad_f,
eval_jac_g,
eval_h,
user_model=nothing,
)
@assert n == length(x_L) == length(x_U)
@assert m == length(g_L) == length(g_U)
Expand Down Expand Up @@ -272,13 +320,14 @@ function CreateIpoptProblem(
zeros(Float64, n),
zeros(Float64, n),
0.0,
0,
Cint(0),
eval_f,
eval_g,
eval_grad_f,
eval_jac_g,
eval_h,
nothing,
user_model,
)
finalizer(FreeIpoptProblem, prob)
return prob
Expand Down
Loading