Skip to content

Commit 8815384

Browse files
committed
Merge check_belongs_to_model with moi_function
1 parent 6cfdc73 commit 8815384

File tree

2 files changed

+18
-13
lines changed

2 files changed

+18
-13
lines changed

src/constraints.jl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -760,6 +760,10 @@ function moi_function(constraint::AbstractConstraint)
760760
return moi_function(jump_function(constraint))
761761
end
762762

763+
function moi_function(constraint::AbstractConstraint, model)
764+
return moi_function(jump_function(constraint), model)
765+
end
766+
763767
"""
764768
moi_set(constraint::AbstractConstraint)
765769
@@ -1016,6 +1020,11 @@ function _moi_add_constraint(
10161020
return MOI.add_constraint(model, f, s)
10171021
end
10181022

1023+
function moi_function(f, model)
1024+
check_belongs_to_model(f, model)
1025+
return moi_function(f)
1026+
end
1027+
10191028
"""
10201029
add_constraint(
10211030
model::GenericModel,
@@ -1032,10 +1041,9 @@ function add_constraint(
10321041
name::String = "",
10331042
)
10341043
con = model_convert(model, con)
1044+
func, set = moi_function(con, model), moi_set(con)
10351045
# The type of backend(model) is unknown so we directly redirect to another
10361046
# function.
1037-
check_belongs_to_model(con, model)
1038-
func, set = moi_function(con), moi_set(con)
10391047
cindex = _moi_add_constraint(
10401048
backend(model),
10411049
func,

src/nlp_expr.jl

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -552,24 +552,25 @@ end
552552

553553
moi_function(x::Number) = x
554554

555-
function moi_function(f::GenericNonlinearExpr{V}) where {V}
556-
model = owner_model(f)
557-
cache = isnothing(model) ? nothing : model.subexpressions
558-
if !isnothing(cache) && haskey(cache, f)
555+
function moi_function(f::GenericNonlinearExpr{V}, model::JuMP.GenericModel) where {V}
556+
cache = model.subexpressions
557+
if haskey(cache, f)
559558
return cache[f]
560559
end
561560
ret = MOI.ScalarNonlinearFunction(f.head, similar(f.args))
562561
stack = Tuple{MOI.ScalarNonlinearFunction,Int,GenericNonlinearExpr{V}}[]
563562
for i in length(f.args):-1:1
564563
if f.args[i] isa GenericNonlinearExpr{V}
565564
push!(stack, (ret, i, f.args[i]))
565+
elseif f.args[i] isa AbstractJuMPScalar
566+
ret.args[i] = moi_function(model, f.args[i])
566567
else
567568
ret.args[i] = moi_function(f.args[i])
568569
end
569570
end
570571
while !isempty(stack)
571572
parent, i, arg = pop!(stack)
572-
if !isnothing(cache) && haskey(cache, arg)
573+
if haskey(cache, arg)
573574
parent.args[i] = cache[arg]
574575
continue
575576
end
@@ -582,13 +583,9 @@ function moi_function(f::GenericNonlinearExpr{V}) where {V}
582583
child.args[j] = moi_function(arg.args[j])
583584
end
584585
end
585-
if !isnothing(cache)
586-
cache[arg] = child
587-
end
588-
end
589-
if !isnothing(cache)
590-
cache[f] = ret
586+
cache[arg] = child
591587
end
588+
cache[f] = ret
592589
return ret
593590
end
594591

0 commit comments

Comments
 (0)