diff --git a/ext/IpoptMathOptInterfaceExt/MOI_wrapper.jl b/ext/IpoptMathOptInterfaceExt/MOI_wrapper.jl index 314d213..8d99ddb 100644 --- a/ext/IpoptMathOptInterfaceExt/MOI_wrapper.jl +++ b/ext/IpoptMathOptInterfaceExt/MOI_wrapper.jl @@ -1285,6 +1285,7 @@ function _setup_model(model::Optimizer) eval_grad_f_cb, eval_jac_g_cb, has_hessian ? eval_h_cb : nothing, + nothing, # we could use the model in the future but it is a breaking change for the signature of the callback ) if model.sense == MOI.MIN_SENSE Ipopt.AddIpoptNumOption(model.inner, "obj_scaling_factor", 1.0) diff --git a/src/C_wrapper.jl b/src/C_wrapper.jl index 9f28c3a..da13d10 100644 --- a/src/C_wrapper.jl +++ b/src/C_wrapper.jl @@ -3,7 +3,7 @@ # Use of this source code is governed by an MIT-style license that can be found # in the LICENSE.md file or at https://opensource.org/licenses/MIT. -mutable struct IpoptProblem +mutable struct IpoptProblem{M} ipopt_problem::Ptr{Cvoid} # Reference to the internal data structure n::Int # Num vars m::Int # Num cons @@ -21,6 +21,8 @@ mutable struct IpoptProblem eval_jac_g::Function eval_h::Union{Function,Nothing} intermediate::Union{Function,Nothing} + # User data + user_model::M end Base.unsafe_convert(::Type{Ptr{Cvoid}}, p::IpoptProblem) = p.ipopt_problem @@ -37,7 +39,11 @@ function _Eval_F_CB( if x_new == Cint(1) prob.x .= x end - new_obj = convert(Float64, prob.eval_f(x))::Float64 + if isnothing(prob.user_model) + new_obj = convert(Float64, prob.eval_f(x))::Float64 + else + new_obj = convert(Float64, prob.eval_f(prob.user_model, x))::Float64 + end unsafe_store!(obj_value, new_obj) return Cint(1) end @@ -53,7 +59,11 @@ function _Eval_Grad_F_CB( prob = unsafe_pointer_to_objref(user_data)::IpoptProblem new_grad_f = unsafe_wrap(Array, grad_f, Int(n)) x = unsafe_wrap(Array, x_ptr, Int(n)) - prob.eval_grad_f(x, new_grad_f) + if isnothing(prob.user_model) + prob.eval_grad_f(x, new_grad_f) + else + prob.eval_grad_f(prob.user_model, x, new_grad_f) + end return Cint(1) end @@ -71,7 +81,11 @@ function _Eval_G_CB( if x_new == Cint(1) prob.x .= x end - prob.eval_g(x, new_g) + if isnothing(prob.user_model) + prob.eval_g(x, new_g) + else + prob.eval_g(prob.user_model, x, new_g) + end return Cint(1) end @@ -91,10 +105,18 @@ function _Eval_Jac_G_CB( rows = unsafe_wrap(Array, iRow, Int(nele_jac)) cols = unsafe_wrap(Array, jCol, Int(nele_jac)) if values_ptr == C_NULL - prob.eval_jac_g(x, rows, cols, nothing) + if isnothing(prob.user_model) + prob.eval_jac_g(x, rows, cols, nothing) + else + prob.eval_jac_g(prob.user_model, x, rows, cols, nothing) + end else values = unsafe_wrap(Array, values_ptr, Int(nele_jac)) - prob.eval_jac_g(x, rows, cols, values) + if isnothing(prob.user_model) + prob.eval_jac_g(x, rows, cols, values) + else + prob.eval_jac_g(prob.user_model, x, rows, cols, values) + end end return Cint(1) end @@ -123,10 +145,18 @@ function _Eval_H_CB( rows = unsafe_wrap(Array, iRow, Int(nele_hess)) cols = unsafe_wrap(Array, jCol, Int(nele_hess)) if values_ptr == C_NULL - prob.eval_h(x, rows, cols, obj_factor, lambda, nothing) + if isnothing(prob.user_model) + prob.eval_h(x, rows, cols, obj_factor, lambda, nothing) + else + prob.eval_h(prob.user_model, x, rows, cols, obj_factor, lambda, nothing) + end else values = unsafe_wrap(Array, values_ptr, Int(nele_hess)) - prob.eval_h(x, rows, cols, obj_factor, lambda, values) + if isnothing(prob.user_model) + prob.eval_h(x, rows, cols, obj_factor, lambda, values) + else + prob.eval_h(prob.user_model, x, rows, cols, obj_factor, lambda, values) + end end return Cint(1) # Return TRUE for success. end @@ -148,19 +178,36 @@ function _Intermediate_CB( try return reenable_sigint() do prob = unsafe_pointer_to_objref(user_data)::IpoptProblem - return prob.intermediate( - alg_mod, - iter_count, - obj_value, - inf_pr, - inf_du, - mu, - d_norm, - regularization_size, - alpha_du, - alpha_pr, - ls_trials, - ) + if isnothing(prob.user_model) + return prob.intermediate( + alg_mod, + iter_count, + obj_value, + inf_pr, + inf_du, + mu, + d_norm, + regularization_size, + alpha_du, + alpha_pr, + ls_trials, + ) + else + return prob.intermediate( + prob.user_model, + alg_mod, + iter_count, + obj_value, + inf_pr, + inf_du, + mu, + d_norm, + regularization_size, + alpha_du, + alpha_pr, + ls_trials, + ) + end end catch err if !(err isa InterruptException) @@ -184,6 +231,7 @@ function CreateIpoptProblem( eval_grad_f, eval_jac_g, eval_h, + user_model=nothing, ) @assert n == length(x_L) == length(x_U) @assert m == length(g_L) == length(g_U) @@ -272,13 +320,14 @@ function CreateIpoptProblem( zeros(Float64, n), zeros(Float64, n), 0.0, - 0, + Cint(0), eval_f, eval_g, eval_grad_f, eval_jac_g, eval_h, nothing, + user_model, ) finalizer(FreeIpoptProblem, prob) return prob