diff --git a/src/Utilities/functions.jl b/src/Utilities/functions.jl index 5826319ec4..3e7ba824cd 100644 --- a/src/Utilities/functions.jl +++ b/src/Utilities/functions.jl @@ -1054,6 +1054,29 @@ function canonical(f::MOI.AbstractFunction) return g end +# Workaround: both `is_canonical` and `canonicalize!` would be slow otherwise +canonical(f::MOI.ScalarNonlinearFunction) = f + +function canonical(f::MOI.ScalarNonlinearFunction) + cache = Dict{MOI.AbstractScalarFunction,MOI.AbstractScalarFunction}() + # Don't use recursion here. This gets called for all scalar nonlinear + # constraints. + stack = Any[arg for arg in f.args] + while !isempty(stack) + arg = pop!(stack) + if arg isa MOI.ScalarNonlinearFunction + for a in arg.args + push!(stack, a) + end + else + if !is_canonical(arg) + return false + end + end + end + return true +end + canonicalize!(f::Union{MOI.VectorOfVariables,MOI.VariableIndex}) = f """ diff --git a/src/Utilities/vector_of_constraints.jl b/src/Utilities/vector_of_constraints.jl index b12b97bd85..edadfa0a2d 100644 --- a/src/Utilities/vector_of_constraints.jl +++ b/src/Utilities/vector_of_constraints.jl @@ -73,7 +73,7 @@ function MOI.add_constraint( ) where {F<:MOI.AbstractFunction,S<:MOI.AbstractSet} # We canonicalize the constraint so that solvers can avoid having to # canonicalize it most of the time (they can check if they need to with - # `is_canonical`. + # `is_canonical`). # Note that the canonicalization is not guaranteed if for instance # `modify` is called and adds a new term. # See https://github.com/jump-dev/MathOptInterface.jl/pull/1118 @@ -103,7 +103,14 @@ function MOI.get( ) where {F,S} MOI.throw_if_not_valid(v, ci) f, _ = v.constraints[ci]::Tuple{F,S} - return copy(f) + # Since `MA.mutability(MOI.ScalarNonlinearFunction)` is `MA.IsNotMutable`, + # this does not copy `MOI.ScalarNonlinearFunction`. This is important if the + # function share aliases of the same subexpression at different parts of + # it's expression graph or the expression graph of other functions of the + # model. If we `copy`, they won't be aliases of the same subexpression + # anymore hence `MOI.Nonlinear.ReverseAD` won't detect them as common + # subexpressions. + return MA.copy_if_mutable(f) end function MOI.get( diff --git a/src/functions.jl b/src/functions.jl index c1a5c36b53..214effee5c 100644 --- a/src/functions.jl +++ b/src/functions.jl @@ -360,7 +360,7 @@ function Base.copy(f::ScalarNonlinearFunction) # We need some sort of hint so that the next time we see this on the # stack we evaluate it using the args in `result_stack`. One option # would be a custom type. Or we can just wrap in (,) and then check - # for a Tuple, which isn't (curretly) a valid argument. + # for a Tuple, which isn't (currently) a valid argument. push!(stack, (arg,)) for child in arg.args push!(stack, child)