Skip to content

Commit ce419fc

Browse files
committed
Updates
1 parent c126a57 commit ce419fc

File tree

4 files changed

+146
-43
lines changed

4 files changed

+146
-43
lines changed

src/Bridges/Variable/bridges/set_dot.jl

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
struct DotProductsBridge{T,S,V} <: SetMapBridge{T,S,MOI.SetWithDotProducts{S,V}}
22
variables::Vector{MOI.VariableIndex}
3-
constraint::MOI.ConstraintIndex{MOI.VectorOfVariables, S}
3+
constraint::MOI.ConstraintIndex{MOI.VectorOfVariables,S}
44
set::MOI.SetWithDotProducts{S,V}
55
end
66

@@ -28,10 +28,7 @@ function bridge_constrained_variable(
2828
return BT(variables, constraint, set)
2929
end
3030

31-
function MOI.Bridges.map_set(
32-
bridge::DotProductsBridge{T,S},
33-
set::S,
34-
) where {T,S}
31+
function MOI.Bridges.map_set(bridge::DotProductsBridge{T,S}, set::S) where {T,S}
3532
return MOI.SetWithDotProducts(set, bridge.vectors)
3633
end
3734

@@ -49,14 +46,27 @@ function MOI.Bridges.map_function(
4946
) where {T}
5047
scalars = MOI.Utilities.eachscalar(func)
5148
if i.value in eachindex(bridge.set.vectors)
52-
return MOI.Utilities.set_dot(bridge.set.vectors[i.value], scalars, bridge.set.set)
49+
return MOI.Utilities.set_dot(
50+
bridge.set.vectors[i.value],
51+
scalars,
52+
bridge.set.set,
53+
)
5354
else
54-
return convert(MOI.ScalarAffineFunction{T}, scalars[i.value - length(bridge.vectors)])
55+
return convert(
56+
MOI.ScalarAffineFunction{T},
57+
scalars[i.value-length(bridge.vectors)],
58+
)
5559
end
5660
end
5761

58-
function MOI.Bridges.inverse_map_function(bridge::DotProductsBridge{T}, func) where {T}
62+
function MOI.Bridges.inverse_map_function(
63+
bridge::DotProductsBridge{T},
64+
func,
65+
) where {T}
5966
m = length(bridge.set.vectors)
60-
return MOI.Utilities.operate(vcat, T, MOI.Utilities.eachscalar(func)[(m+1):end])
67+
return MOI.Utilities.operate(
68+
vcat,
69+
T,
70+
MOI.Utilities.eachscalar(func)[(m+1):end],
71+
)
6172
end
62-

src/Utilities/functions.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -638,7 +638,8 @@ end
638638
A type that allows iterating over the scalar-functions that comprise an
639639
`AbstractVectorFunction`.
640640
"""
641-
struct ScalarFunctionIterator{F<:MOI.AbstractVectorFunction,C,S} <: AbstractVector{S}
641+
struct ScalarFunctionIterator{F<:MOI.AbstractVectorFunction,C,S} <:
642+
AbstractVector{S}
642643
f::F
643644
# Cache that can be used to store a precomputed datastructure that allows
644645
# an efficient implementation of `getindex`.

src/sets.jl

Lines changed: 88 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1804,73 +1804,129 @@ function Base.getproperty(
18041804
end
18051805

18061806
"""
1807-
SetWithDotProducts(set, vectors::Vector{V})
1807+
SetWithDotProducts(set::MOI.AbstractSet, vectors::AbstractVector)
18081808
1809-
Given a set `set` of dimension `d` and `m` vectors `v_1`, ..., `v_m` given in `vectors`, this is the set:
1810-
``\\{ ((\\langle v_1, x \\rangle, ..., \\langle v_m, x \\rangle, x) \\in \\mathbb{R}^{m + d} : x \\in \\text{set} \\}.``
1809+
Given a set `set` of dimension `d` and `m` vectors `a_1`, ..., `a_m` given in `vectors`, this is the set:
1810+
``\\{ ((\\langle a_1, x \\rangle, ..., \\langle a_m, x \\rangle) \\in \\mathbb{R}^{m} : x \\in \\text{set} \\}.``
18111811
"""
1812-
struct SetWithDotProducts{S,V} <: AbstractVectorSet
1812+
struct SetWithDotProducts{S,A,V<:AbstractVector{A}} <: AbstractVectorSet
18131813
set::S
1814-
vectors::Vector{V}
1814+
vectors::V
18151815
end
18161816

18171817
function Base.copy(s::SetWithDotProducts)
18181818
return SetWithDotProducts(copy(s.set), copy(s.vectors))
18191819
end
18201820

1821-
function dimension(s::SetWithDotProducts)
1822-
return length(s.vectors) + dimension(s.set)
1823-
end
1821+
dimension(s::SetWithDotProducts) = length(s.vectors)
18241822

18251823
function dual_set(s::SetWithDotProducts)
18261824
return LinearCombinationInSet(s.set, s.vectors)
18271825
end
18281826

1829-
function dual_set_type(::Type{SetWithDotProducts{S,V}}) where {S,V}
1830-
return LinearCombinationInSet{S,V}
1827+
function dual_set_type(::Type{SetWithDotProducts{S,A,V}}) where {S,A,V}
1828+
return LinearCombinationInSet{S,A,V}
18311829
end
18321830

18331831
"""
1834-
LinearCombinationInSet{S,V}(set::S, matrices::Vector{V})
1832+
LinearCombinationInSet(set::MOI.AbstractSet, matrices::AbstractVector)
18351833
1836-
Given a set `set` of dimension `d` and `m` vectors `v_1`, ..., `v_m` given in `vectors`, this is the set:
1837-
``\\{ ((y, x) \\in \\mathbb{R}^{m + d} : \\sum_{i=1}^m y_i v_i + x \\in \\text{set} \\}.``
1834+
Given a set `set` of dimension `d` and `m` vectors `a_1`, ..., `a_m` given in `vectors`, this is the set:
1835+
``\\{ (y \\in \\mathbb{R}^{m} : \\sum_{i=1}^m y_i a_i \\in \\text{set} \\}.``
18381836
"""
1839-
struct LinearCombinationInSet{S,V} <: AbstractVectorSet
1837+
struct LinearCombinationInSet{S,A,V<:AbstractVector{A}} <: AbstractVectorSet
18401838
set::S
1841-
vectors::Vector{V}
1839+
vectors::V
18421840
end
18431841

1844-
dimension(s::LinearCombinationInSet) = length(s.vectors) + simension(s.set)
1842+
dimension(s::LinearCombinationInSet) = length(s.vectors)
18451843

18461844
function dual_set(s::LinearCombinationInSet)
1847-
return SetWithDotProducts(s.side_dimension, s.matrices)
1845+
return SetWithDotProducts(s.side_dimension, s.vectors)
1846+
end
1847+
1848+
function dual_set_type(::Type{LinearCombinationInSet{S,A,V}}) where {S,A,V}
1849+
return SetWithDotProducts{S,A,V}
18481850
end
18491851

1850-
function dual_set_type(::Type{LinearCombinationInSet{S,V}}) where {S,V}
1851-
return SetWithDotProducts{S,V}
1852+
abstract type AbstractFactorization{T,F} <: AbstractMatrix{T} end
1853+
1854+
function Base.size(m::AbstractFactorization)
1855+
n = size(m.factor, 1)
1856+
return (n, n)
18521857
end
18531858

18541859
"""
1855-
struct LowRankMatrix{T}
1856-
diagonal::Vector{T}
1857-
factor::Matrix{T}
1860+
struct Factorization{
1861+
T,
1862+
F<:Union{AbstractVector{T},AbstractMatrix{T}},
1863+
D<:Union{T,AbstractVector{T}},
1864+
} <: AbstractMatrix{T}
1865+
factor::F
1866+
scaling::D
18581867
end
18591868
1860-
`factor * Diagonal(diagonal) * factor'`.
1861-
"""
1862-
struct LowRankMatrix{T} <: AbstractMatrix{T}
1863-
diagonal::Vector{T}
1864-
factor::Matrix{T}
1869+
Matrix corresponding to `factor * Diagonal(diagonal) * factor'`.
1870+
If `factor` is a vector and `diagonal` is a scalar, this corresponds to
1871+
the matrix `diagonal * factor * factor'`.
1872+
If `factor` is a matrix and `diagonal` is a vector, this corresponds to
1873+
the matrix `factor * Diagonal(scaling) * factor'`.
1874+
"""
1875+
struct Factorization{
1876+
T,
1877+
F<:Union{AbstractVector{T},AbstractMatrix{T}},
1878+
D<:Union{T,AbstractVector{T}},
1879+
} <: AbstractFactorization{T,F}
1880+
factor::F
1881+
scaling::D
1882+
function Factorization(
1883+
factor::AbstractMatrix{T},
1884+
scaling::AbstractVector{T},
1885+
) where {T}
1886+
if length(scaling) != size(factor, 2)
1887+
error(
1888+
"Length `$(length(scaling))` of diagonal does not match number of columns `$(size(factor, 2))` of factor",
1889+
)
1890+
end
1891+
return new{T,typeof(factor),typeof(scaling)}(factor, scaling)
1892+
end
1893+
function Factorization(factor::AbstractVector{T}, scaling::T) where {T}
1894+
return new{T,typeof(factor),typeof(scaling)}(factor, scaling)
1895+
end
18651896
end
18661897

1867-
function Base.size(m::LowRankMatrix)
1868-
n = size(m.factor, 1)
1869-
return (n, n)
1898+
function Base.getindex(m::Factorization, i::Int, j::Int)
1899+
return sum(
1900+
m.factor[i, k] * m.scaling[k] * m.factor[j, k]' for
1901+
k in eachindex(m.scaling)
1902+
)
1903+
end
1904+
1905+
"""
1906+
struct Factorization{
1907+
T,
1908+
F<:Union{AbstractVector{T},AbstractMatrix{T}},
1909+
D<:Union{T,AbstractVector{T}},
1910+
} <: AbstractMatrix{T}
1911+
factor::F
1912+
scaling::D
1913+
end
1914+
1915+
Matrix corresponding to `factor * Diagonal(diagonal) * factor'`.
1916+
If `factor` is a vector and `diagonal` is a scalar, this corresponds to
1917+
the matrix `diagonal * factor * factor'`.
1918+
If `factor` is a matrix and `diagonal` is a vector, this corresponds to
1919+
the matrix `factor * Diagonal(scaling) * factor'`.
1920+
"""
1921+
struct PositiveSemidefiniteFactorization{
1922+
T,
1923+
F<:Union{AbstractVector{T},AbstractMatrix{T}},
1924+
} <: AbstractFactorization{T,F}
1925+
factor::F
18701926
end
18711927

1872-
function Base.getindex(m::LowRankMatrix, i::Int, j::Int)
1873-
return sum(m.factor[i, k] * m.diagonal[k] * m.factor[j, k]' for k in eachindex(m.diagonal))
1928+
function Base.getindex(m::PositiveSemidefiniteFactorization, i::Int, j::Int)
1929+
return sum(m.factor[i, k] * m.factor[j, k]' for k in axes(m.factor, 2))
18741930
end
18751931

18761932
struct TriangleVectorization{T,M<:AbstractMatrix{T}} <: AbstractVector{T}

test/sets.jl

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ module TestSets
88

99
using Test
1010
import MathOptInterface as MOI
11+
import LinearAlgebra
1112

1213
include("dummy.jl")
1314

@@ -353,6 +354,41 @@ function test_sets_reified()
353354
return
354355
end
355356

357+
function _test_factorization(A, B)
358+
@test size(A) == size(B)
359+
@test A B
360+
d = LinearAlgebra.checksquare(A)
361+
n = div(d * (d + 1), 2)
362+
vA = MOI.TriangleVectorization(A)
363+
@test length(vA) == n
364+
@test eachindex(vA) == Base.OneTo(n)
365+
vB = MOI.TriangleVectorization(B)
366+
@test length(vB) == n
367+
@test eachindex(vA) == Base.OneTo(n)
368+
k = 0
369+
for j in 1:d
370+
for i in 1:j
371+
k += 1
372+
@test vA[k] == vB[k]
373+
@test vA[k] == A[i, j]
374+
end
375+
end
376+
return
377+
end
378+
379+
function test_factorizations()
380+
f = [1, 2]
381+
_test_factorization(f * f', MOI.PositiveSemidefiniteFactorization(f))
382+
_test_factorization(2 * f * f', MOI.Factorization(f, 2))
383+
F = [1 2; 3 4; 5 6]
384+
d = [7, 8]
385+
_test_factorization(F * F', MOI.PositiveSemidefiniteFactorization(F))
386+
return _test_factorization(
387+
F * LinearAlgebra.Diagonal(d) * F',
388+
MOI.Factorization(F, d),
389+
)
390+
end
391+
356392
function runtests()
357393
for name in names(@__MODULE__; all = true)
358394
if startswith("$name", "test_")

0 commit comments

Comments
 (0)