Skip to content

Commit d8f2ab1

Browse files
committed
Add tests
1 parent 8c2c275 commit d8f2ab1

File tree

2 files changed

+90
-6
lines changed

2 files changed

+90
-6
lines changed

src/Utilities/sparse_matrix.jl

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -183,19 +183,27 @@ function Base.convert(
183183
)
184184
end
185185

186+
_indexing(A::MutableSparseMatrixCSC) = A.indexing
187+
_indexing(::SparseArrays.SparseMatrixCSC) = OneBasedIndexing()
188+
189+
const _SparseMatrixCSC{Tv,Ti} = Union{
190+
MutableSparseMatrixCSC{Tv,Ti},
191+
SparseArrays.SparseMatrixCSC{Tv,Ti}
192+
}
193+
186194
function _first_in_column(
187-
A::MutableSparseMatrixCSC,
195+
A::_SparseMatrixCSC,
188196
row::Integer,
189197
col::Integer,
190198
)
191199
range = SparseArrays.nzrange(A, col)
192-
row = _shift(row, OneBasedIndexing(), A.indexing)
200+
row = _shift(row, OneBasedIndexing(), _indexing(A))
193201
idx = searchsortedfirst(view(A.rowval, range), row)
194202
return get(range, idx, last(range) + 1)
195203
end
196204

197205
function extract_function(
198-
A::MutableSparseMatrixCSC{T},
206+
A::_SparseMatrixCSC{T},
199207
row::Integer,
200208
constant::T,
201209
) where {T}
@@ -205,7 +213,7 @@ function extract_function(
205213
if idx > last(SparseArrays.nzrange(A, col))
206214
continue
207215
end
208-
r = _shift(A.rowval[idx], A.indexing, OneBasedIndexing())
216+
r = _shift(A.rowval[idx], _indexing(A), OneBasedIndexing())
209217
if r == row
210218
push!(
211219
func.terms,
@@ -217,7 +225,7 @@ function extract_function(
217225
end
218226

219227
function extract_function(
220-
A::MutableSparseMatrixCSC{T},
228+
A::_SparseMatrixCSC{T},
221229
rows::UnitRange,
222230
constants::Vector{T},
223231
) where {T}
@@ -231,7 +239,7 @@ function extract_function(
231239
if idx[col] > last(SparseArrays.nzrange(A, col))
232240
continue
233241
end
234-
row = _shift(A.rowval[idx[col]], A.indexing, OneBasedIndexing())
242+
row = _shift(A.rowval[idx[col]], _indexing(A), OneBasedIndexing())
235243
if row != rows[output_index]
236244
continue
237245
end

test/Utilities/matrix_of_constraints.jl

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -714,6 +714,82 @@ function test_unsupported_constraint()
714714
return
715715
end
716716

717+
MOI.Utilities.@product_of_sets(
718+
_EqualTos,
719+
MOI.EqualTo{T},
720+
)
721+
722+
function _equality_constraints(A::AbstractMatrix{T}, b::AbstractVector{T}) where {T}
723+
sets = _EqualTos{T}()
724+
for _ in eachindex(b)
725+
MOI.Utilities.add_set(sets, MOI.Utilities.set_index(sets, MOI.EqualTo{T}))
726+
end
727+
MOI.Utilities.final_touch(sets)
728+
constants = MOI.Utilities.Hyperrectangle(b, b)
729+
model = MOI.Utilities.MatrixOfConstraints{T}(A, constants, sets)
730+
model.final_touch = true
731+
return model
732+
end
733+
734+
735+
# Inspired from MatrixOfConstraints
736+
function test_lp_standard_form()
737+
s = """
738+
variables: x1, x2
739+
cx1: x1 >= 0.0
740+
cx2: x2 >= 0.0
741+
c1: 1x1 == 5.0
742+
c2: 3x1 + 4x2 == 6.0
743+
minobjective: 7x1 + 8x2
744+
"""
745+
expected = MOI.Utilities.Model{Float64}()
746+
MOI.Utilities.loadfromstring!(expected, s)
747+
748+
var_names = ["x1", "x2"]
749+
con_names = ["c1", "c2"]
750+
751+
A = SparseArrays.sparse([
752+
1.0 0.0
753+
3.0 4.0
754+
])
755+
b = [5.0, 6.0]
756+
form = MOI.Utilities.GenericModel{Float16}(
757+
expected.objective,
758+
expected.variables,
759+
_equality_constraints(A, b),
760+
)
761+
762+
model = MOI.Utilities.Model{Float64}()
763+
MOI.copy_to(
764+
MOI.Bridges.Constraint.Scalarize{Float64}(model),
765+
form,
766+
)
767+
MOI.set(
768+
model,
769+
MOI.VariableName(),
770+
MOI.VariableIndex.(eachindex(var_names)),
771+
var_names,
772+
)
773+
MOI.set(
774+
model,
775+
MOI.ConstraintName(),
776+
MOI.ConstraintIndex{
777+
MOI.ScalarAffineFunction{Float64},
778+
MOI.EqualTo{Float64},
779+
}.(eachindex(con_names)),
780+
con_names,
781+
)
782+
783+
MOI.Test.util_test_models_equal(
784+
model,
785+
expected,
786+
var_names,
787+
con_names,
788+
)
789+
790+
return
791+
end
792+
717793
end
718794

719795
TestMatrixOfConstraints.runtests()

0 commit comments

Comments
 (0)