Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/JuMP.jl
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,8 @@ mutable struct GenericModel{T<:Real} <: AbstractModel
# A dictionary to store timing information from the JuMP macros.
enable_macro_timing::Bool
macro_times::Dict{Tuple{LineNumberNode,String},Float64}
# We use `Any` as key because we haven't defined `GenericNonlinearExpr` yet
subexpressions::Dict{Any,MOI.ScalarNonlinearFunction}
end

value_type(::Type{GenericModel{T}}) where {T} = T
Expand Down Expand Up @@ -251,6 +253,7 @@ function direct_generic_model(
Dict{Any,MOI.ConstraintIndex}(),
false,
Dict{Tuple{LineNumberNode,String},Float64}(),
Dict{Any,MOI.ScalarNonlinearFunction}(),
)
end

Expand Down
18 changes: 16 additions & 2 deletions src/constraints.jl
Original file line number Diff line number Diff line change
Expand Up @@ -760,6 +760,10 @@ function moi_function(constraint::AbstractConstraint)
return moi_function(jump_function(constraint))
end

function moi_function(constraint::AbstractConstraint, model)
return moi_function(jump_function(constraint), model)
end

"""
moi_set(constraint::AbstractConstraint)

Expand Down Expand Up @@ -1016,6 +1020,17 @@ function _moi_add_constraint(
return MOI.add_constraint(model, f, s)
end

function check_belongs_to_model(f::Vector, model)
for func in f
check_belongs_to_model(func, model)
end
end

function moi_function(f, model)
check_belongs_to_model(f, model)
return moi_function(f)
end

"""
add_constraint(
model::GenericModel,
Expand All @@ -1032,10 +1047,9 @@ function add_constraint(
name::String = "",
)
con = model_convert(model, con)
func, set = moi_function(con, model), moi_set(con)
# The type of backend(model) is unknown so we directly redirect to another
# function.
check_belongs_to_model(con, model)
func, set = moi_function(con), moi_set(con)
cindex = _moi_add_constraint(
backend(model),
func,
Expand Down
17 changes: 16 additions & 1 deletion src/nlp_expr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -552,18 +552,31 @@ end

moi_function(x::Number) = x

function moi_function(f::GenericNonlinearExpr{V}) where {V}
function moi_function(
f::GenericNonlinearExpr{V},
model::JuMP.GenericModel,
) where {V}
cache = model.subexpressions
if haskey(cache, f)
return cache[f]
end
ret = MOI.ScalarNonlinearFunction(f.head, similar(f.args))
stack = Tuple{MOI.ScalarNonlinearFunction,Int,GenericNonlinearExpr{V}}[]
for i in length(f.args):-1:1
if f.args[i] isa GenericNonlinearExpr{V}
push!(stack, (ret, i, f.args[i]))
elseif f.args[i] isa AbstractJuMPScalar
ret.args[i] = moi_function(model, f.args[i])
else
ret.args[i] = moi_function(f.args[i])
end
end
while !isempty(stack)
parent, i, arg = pop!(stack)
if haskey(cache, arg)
parent.args[i] = cache[arg]
continue
end
child = MOI.ScalarNonlinearFunction(arg.head, similar(arg.args))
parent.args[i] = child
for j in length(arg.args):-1:1
Expand All @@ -573,7 +586,9 @@ function moi_function(f::GenericNonlinearExpr{V}) where {V}
child.args[j] = moi_function(arg.args[j])
end
end
cache[arg] = child
end
cache[f] = ret
return ret
end

Expand Down
Loading