diff --git a/src/JuMP.jl b/src/JuMP.jl index 823896ff4f3..27e06d6313e 100644 --- a/src/JuMP.jl +++ b/src/JuMP.jl @@ -138,6 +138,8 @@ mutable struct GenericModel{T<:Real} <: AbstractModel # A dictionary to store timing information from the JuMP macros. enable_macro_timing::Bool macro_times::Dict{Tuple{LineNumberNode,String},Float64} + # We use `Any` as key because we haven't defined `GenericNonlinearExpr` yet + subexpressions::Dict{Any,MOI.ScalarNonlinearFunction} end value_type(::Type{GenericModel{T}}) where {T} = T @@ -251,6 +253,7 @@ function direct_generic_model( Dict{Any,MOI.ConstraintIndex}(), false, Dict{Tuple{LineNumberNode,String},Float64}(), + Dict{Any,MOI.ScalarNonlinearFunction}(), ) end diff --git a/src/constraints.jl b/src/constraints.jl index 157907ccd5e..1f9561612f5 100644 --- a/src/constraints.jl +++ b/src/constraints.jl @@ -760,6 +760,10 @@ function moi_function(constraint::AbstractConstraint) return moi_function(jump_function(constraint)) end +function moi_function(constraint::AbstractConstraint, model) + return moi_function(jump_function(constraint), model) +end + """ moi_set(constraint::AbstractConstraint) @@ -1016,6 +1020,17 @@ function _moi_add_constraint( return MOI.add_constraint(model, f, s) end +function check_belongs_to_model(f::Vector, model) + for func in f + check_belongs_to_model(func, model) + end +end + +function moi_function(f, model) + check_belongs_to_model(f, model) + return moi_function(f) +end + """ add_constraint( model::GenericModel, @@ -1032,10 +1047,9 @@ function add_constraint( name::String = "", ) con = model_convert(model, con) + func, set = moi_function(con, model), moi_set(con) # The type of backend(model) is unknown so we directly redirect to another # function. - check_belongs_to_model(con, model) - func, set = moi_function(con), moi_set(con) cindex = _moi_add_constraint( backend(model), func, diff --git a/src/nlp_expr.jl b/src/nlp_expr.jl index a3d777b8923..f398bdff7be 100644 --- a/src/nlp_expr.jl +++ b/src/nlp_expr.jl @@ -552,18 +552,31 @@ end moi_function(x::Number) = x -function moi_function(f::GenericNonlinearExpr{V}) where {V} +function moi_function( + f::GenericNonlinearExpr{V}, + model::JuMP.GenericModel, +) where {V} + cache = model.subexpressions + if haskey(cache, f) + return cache[f] + end ret = MOI.ScalarNonlinearFunction(f.head, similar(f.args)) stack = Tuple{MOI.ScalarNonlinearFunction,Int,GenericNonlinearExpr{V}}[] for i in length(f.args):-1:1 if f.args[i] isa GenericNonlinearExpr{V} push!(stack, (ret, i, f.args[i])) + elseif f.args[i] isa AbstractJuMPScalar + ret.args[i] = moi_function(model, f.args[i]) else ret.args[i] = moi_function(f.args[i]) end end while !isempty(stack) parent, i, arg = pop!(stack) + if haskey(cache, arg) + parent.args[i] = cache[arg] + continue + end child = MOI.ScalarNonlinearFunction(arg.head, similar(arg.args)) parent.args[i] = child for j in length(arg.args):-1:1 @@ -573,7 +586,9 @@ function moi_function(f::GenericNonlinearExpr{V}) where {V} child.args[j] = moi_function(arg.args[j]) end end + cache[arg] = child end + cache[f] = ret return ret end