diff --git a/src/datatypes.jl b/src/datatypes.jl index 674f625..268b6fc 100644 --- a/src/datatypes.jl +++ b/src/datatypes.jl @@ -375,6 +375,33 @@ A type for using indicator constraint approach for linear disjunctive constraint """ struct Indicator <: AbstractReformulationMethod end +""" + MOIDisjunction <: AbstractReformulationMethod +A reformulation type for reformulating disjunctions into their MathOptInterface +set representations [`DisjunctionSet`](@ref) which are then added to the model. +""" +struct MOIDisjunction <: AbstractReformulationMethod end + +""" + DisjunctionSet{S <: MOI.AbstractSet} <: MOI.AbstractVectorSet +A MathOptInterface set for representing disjunctions in the format vector of +functions in set to enable: +```julia +@constraint(model, [funcs...] in DisjunctionSet(n, idxs, sets)) +``` +where the vector of functions `funcs` is a flattened version of all the disjunct +constraint functions where the indicator variable is listed first, `n` is the +length of `[funcs...]`, `idxs` is a vector of the indices tracking where each +disjunct begins (i.e., it stores the indices of the indicator variables in +`[funcs...]`), and `sets` is a uses a vector of vectors structure for the MOI sets +that correspond to all the disjunct constraint functions. +""" +struct DisjunctionSet{S <: _MOI.AbstractSet} <: _MOI.AbstractVectorSet + dimension::Int + disjunct_indices::Vector{Int} + constraint_sets::Vector{Vector{S}} +end + ################################################################################ # GDP Data ################################################################################ diff --git a/src/moi.jl b/src/moi.jl new file mode 100644 index 0000000..259fc48 --- /dev/null +++ b/src/moi.jl @@ -0,0 +1,59 @@ +################################################################################ +# UTILITY METHODS +################################################################################ +# Requred for extensions to MOI.AbstractVectorSet +# function _MOI.Utilities.set_dot(x::AbstractVector, y::AbstractVector, set::DisjunctionSet) +# return LinearAlgebra.dot(x, y) # TODO figure out what we should actually do here +# end + +# TODO create a bridge for `DisjunctionSet` + +# TODO create helper method to unpack DisjunctionSet at the MOI side of things + +################################################################################ +# REFRORMULATION METHODS +################################################################################ +# Helper methods to handle recursively flattening the disjuncts +function _constr_set!(model, funcs, con::JuMP.AbstractConstraint) + append!(funcs, JuMP.jump_function(con)) + return JuMP.moi_set(con) +end +function _constr_set!(model, funcs, con::Disjunction) + inner_funcs, set = _disjunction_to_set(model, con) + append!(funcs, inner_funcs) + return set +end + +# Create the vectors needed for a disjunction vector constraint +function _disjunction_to_set(model::JuMP.Model, d::Disjunction) + # allocate memory for the storage vectors + num_disjuncts = length(d.indicators) + constr_mappings = _indicator_to_constraints(model) + num_constrs = sum(length(constr_mappings[lvref]) for lvref in d.indicators) + funcs = sizehint!(JuMP.AbstractJuMPScalar[], num_disjuncts + num_constrs) + sets = Vector{Vector{_MOI.AbstractSet}}(undef, num_disjuncts) + d_idxs = Vector{Int}(undef, num_disjuncts) + # iterate over the underlying disjuncts to fill in the storage vectors + for (i, lvref) in enumerate(d.indicators) + push!(funcs, _indicator_to_binary(model)[lvref]) + d_idxs[i] = length(funcs) + crefs = constr_mappings(model)[lvref] + sets[i] = map(c -> _constr_set!(model, funcs, JuMP.constraint_object(c)), crefs) + end + # convert the `sets` type to be concrete if possible (TODO benchmark if this is worth it) + SetType = typeof(first(sets)) + if SetType != Vector{_MOI.AbstractSet} && all(s -> s isa SetType, sets) + sets = convert(SetType, sets) + end + return funcs, DisjunctionSet(length(funcs), d_idxs, sets) +end + +# Extend the disjunction reformulation +function reformulate_disjunction( + model::JuMP.Model, + d::Disjunction, + ::MOIDisjunction + ) + funcs, set = _disjunction_to_set(model, d) + return [JuMP.VectorConstraint(funcs, set, JuMP.VectorShape())] +end \ No newline at end of file