Skip to content

Commit 1b312df

Browse files
committed
Updates
1 parent 7b782f5 commit 1b312df

File tree

4 files changed

+146
-43
lines changed

4 files changed

+146
-43
lines changed

src/Bridges/Variable/bridges/set_dot.jl

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
struct DotProductsBridge{T,S,V} <: SetMapBridge{T,S,MOI.SetWithDotProducts{S,V}}
22
variables::Vector{MOI.VariableIndex}
3-
constraint::MOI.ConstraintIndex{MOI.VectorOfVariables, S}
3+
constraint::MOI.ConstraintIndex{MOI.VectorOfVariables,S}
44
set::MOI.SetWithDotProducts{S,V}
55
end
66

@@ -28,10 +28,7 @@ function bridge_constrained_variable(
2828
return BT(variables, constraint, set)
2929
end
3030

31-
function MOI.Bridges.map_set(
32-
bridge::DotProductsBridge{T,S},
33-
set::S,
34-
) where {T,S}
31+
function MOI.Bridges.map_set(bridge::DotProductsBridge{T,S}, set::S) where {T,S}
3532
return MOI.SetWithDotProducts(set, bridge.vectors)
3633
end
3734

@@ -49,14 +46,27 @@ function MOI.Bridges.map_function(
4946
) where {T}
5047
scalars = MOI.Utilities.eachscalar(func)
5148
if i.value in eachindex(bridge.set.vectors)
52-
return MOI.Utilities.set_dot(bridge.set.vectors[i.value], scalars, bridge.set.set)
49+
return MOI.Utilities.set_dot(
50+
bridge.set.vectors[i.value],
51+
scalars,
52+
bridge.set.set,
53+
)
5354
else
54-
return convert(MOI.ScalarAffineFunction{T}, scalars[i.value - length(bridge.vectors)])
55+
return convert(
56+
MOI.ScalarAffineFunction{T},
57+
scalars[i.value-length(bridge.vectors)],
58+
)
5559
end
5660
end
5761

58-
function MOI.Bridges.inverse_map_function(bridge::DotProductsBridge{T}, func) where {T}
62+
function MOI.Bridges.inverse_map_function(
63+
bridge::DotProductsBridge{T},
64+
func,
65+
) where {T}
5966
m = length(bridge.set.vectors)
60-
return MOI.Utilities.operate(vcat, T, MOI.Utilities.eachscalar(func)[(m+1):end])
67+
return MOI.Utilities.operate(
68+
vcat,
69+
T,
70+
MOI.Utilities.eachscalar(func)[(m+1):end],
71+
)
6172
end
62-

src/Utilities/functions.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -638,7 +638,8 @@ end
638638
A type that allows iterating over the scalar-functions that comprise an
639639
`AbstractVectorFunction`.
640640
"""
641-
struct ScalarFunctionIterator{F<:MOI.AbstractVectorFunction,C,S} <: AbstractVector{S}
641+
struct ScalarFunctionIterator{F<:MOI.AbstractVectorFunction,C,S} <:
642+
AbstractVector{S}
642643
f::F
643644
# Cache that can be used to store a precomputed datastructure that allows
644645
# an efficient implementation of `getindex`.

src/sets.jl

Lines changed: 88 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1792,73 +1792,129 @@ function Base.getproperty(
17921792
end
17931793

17941794
"""
1795-
SetWithDotProducts(set, vectors::Vector{V})
1795+
SetWithDotProducts(set::MOI.AbstractSet, vectors::AbstractVector)
17961796
1797-
Given a set `set` of dimension `d` and `m` vectors `v_1`, ..., `v_m` given in `vectors`, this is the set:
1798-
``\\{ ((\\langle v_1, x \\rangle, ..., \\langle v_m, x \\rangle, x) \\in \\mathbb{R}^{m + d} : x \\in \\text{set} \\}.``
1797+
Given a set `set` of dimension `d` and `m` vectors `a_1`, ..., `a_m` given in `vectors`, this is the set:
1798+
``\\{ ((\\langle a_1, x \\rangle, ..., \\langle a_m, x \\rangle) \\in \\mathbb{R}^{m} : x \\in \\text{set} \\}.``
17991799
"""
1800-
struct SetWithDotProducts{S,V} <: AbstractVectorSet
1800+
struct SetWithDotProducts{S,A,V<:AbstractVector{A}} <: AbstractVectorSet
18011801
set::S
1802-
vectors::Vector{V}
1802+
vectors::V
18031803
end
18041804

18051805
function Base.copy(s::SetWithDotProducts)
18061806
return SetWithDotProducts(copy(s.set), copy(s.vectors))
18071807
end
18081808

1809-
function dimension(s::SetWithDotProducts)
1810-
return length(s.vectors) + dimension(s.set)
1811-
end
1809+
dimension(s::SetWithDotProducts) = length(s.vectors)
18121810

18131811
function dual_set(s::SetWithDotProducts)
18141812
return LinearCombinationInSet(s.set, s.vectors)
18151813
end
18161814

1817-
function dual_set_type(::Type{SetWithDotProducts{S,V}}) where {S,V}
1818-
return LinearCombinationInSet{S,V}
1815+
function dual_set_type(::Type{SetWithDotProducts{S,A,V}}) where {S,A,V}
1816+
return LinearCombinationInSet{S,A,V}
18191817
end
18201818

18211819
"""
1822-
LinearCombinationInSet{S,V}(set::S, matrices::Vector{V})
1820+
LinearCombinationInSet(set::MOI.AbstractSet, matrices::AbstractVector)
18231821
1824-
Given a set `set` of dimension `d` and `m` vectors `v_1`, ..., `v_m` given in `vectors`, this is the set:
1825-
``\\{ ((y, x) \\in \\mathbb{R}^{m + d} : \\sum_{i=1}^m y_i v_i + x \\in \\text{set} \\}.``
1822+
Given a set `set` of dimension `d` and `m` vectors `a_1`, ..., `a_m` given in `vectors`, this is the set:
1823+
``\\{ (y \\in \\mathbb{R}^{m} : \\sum_{i=1}^m y_i a_i \\in \\text{set} \\}.``
18261824
"""
1827-
struct LinearCombinationInSet{S,V} <: AbstractVectorSet
1825+
struct LinearCombinationInSet{S,A,V<:AbstractVector{A}} <: AbstractVectorSet
18281826
set::S
1829-
vectors::Vector{V}
1827+
vectors::V
18301828
end
18311829

1832-
dimension(s::LinearCombinationInSet) = length(s.vectors) + simension(s.set)
1830+
dimension(s::LinearCombinationInSet) = length(s.vectors)
18331831

18341832
function dual_set(s::LinearCombinationInSet)
1835-
return SetWithDotProducts(s.side_dimension, s.matrices)
1833+
return SetWithDotProducts(s.side_dimension, s.vectors)
1834+
end
1835+
1836+
function dual_set_type(::Type{LinearCombinationInSet{S,A,V}}) where {S,A,V}
1837+
return SetWithDotProducts{S,A,V}
18361838
end
18371839

1838-
function dual_set_type(::Type{LinearCombinationInSet{S,V}}) where {S,V}
1839-
return SetWithDotProducts{S,V}
1840+
abstract type AbstractFactorization{T,F} <: AbstractMatrix{T} end
1841+
1842+
function Base.size(m::AbstractFactorization)
1843+
n = size(m.factor, 1)
1844+
return (n, n)
18401845
end
18411846

18421847
"""
1843-
struct LowRankMatrix{T}
1844-
diagonal::Vector{T}
1845-
factor::Matrix{T}
1848+
struct Factorization{
1849+
T,
1850+
F<:Union{AbstractVector{T},AbstractMatrix{T}},
1851+
D<:Union{T,AbstractVector{T}},
1852+
} <: AbstractMatrix{T}
1853+
factor::F
1854+
scaling::D
18461855
end
18471856
1848-
`factor * Diagonal(diagonal) * factor'`.
1849-
"""
1850-
struct LowRankMatrix{T} <: AbstractMatrix{T}
1851-
diagonal::Vector{T}
1852-
factor::Matrix{T}
1857+
Matrix corresponding to `factor * Diagonal(diagonal) * factor'`.
1858+
If `factor` is a vector and `diagonal` is a scalar, this corresponds to
1859+
the matrix `diagonal * factor * factor'`.
1860+
If `factor` is a matrix and `diagonal` is a vector, this corresponds to
1861+
the matrix `factor * Diagonal(scaling) * factor'`.
1862+
"""
1863+
struct Factorization{
1864+
T,
1865+
F<:Union{AbstractVector{T},AbstractMatrix{T}},
1866+
D<:Union{T,AbstractVector{T}},
1867+
} <: AbstractFactorization{T,F}
1868+
factor::F
1869+
scaling::D
1870+
function Factorization(
1871+
factor::AbstractMatrix{T},
1872+
scaling::AbstractVector{T},
1873+
) where {T}
1874+
if length(scaling) != size(factor, 2)
1875+
error(
1876+
"Length `$(length(scaling))` of diagonal does not match number of columns `$(size(factor, 2))` of factor",
1877+
)
1878+
end
1879+
return new{T,typeof(factor),typeof(scaling)}(factor, scaling)
1880+
end
1881+
function Factorization(factor::AbstractVector{T}, scaling::T) where {T}
1882+
return new{T,typeof(factor),typeof(scaling)}(factor, scaling)
1883+
end
18531884
end
18541885

1855-
function Base.size(m::LowRankMatrix)
1856-
n = size(m.factor, 1)
1857-
return (n, n)
1886+
function Base.getindex(m::Factorization, i::Int, j::Int)
1887+
return sum(
1888+
m.factor[i, k] * m.scaling[k] * m.factor[j, k]' for
1889+
k in eachindex(m.scaling)
1890+
)
1891+
end
1892+
1893+
"""
1894+
struct Factorization{
1895+
T,
1896+
F<:Union{AbstractVector{T},AbstractMatrix{T}},
1897+
D<:Union{T,AbstractVector{T}},
1898+
} <: AbstractMatrix{T}
1899+
factor::F
1900+
scaling::D
1901+
end
1902+
1903+
Matrix corresponding to `factor * Diagonal(diagonal) * factor'`.
1904+
If `factor` is a vector and `diagonal` is a scalar, this corresponds to
1905+
the matrix `diagonal * factor * factor'`.
1906+
If `factor` is a matrix and `diagonal` is a vector, this corresponds to
1907+
the matrix `factor * Diagonal(scaling) * factor'`.
1908+
"""
1909+
struct PositiveSemidefiniteFactorization{
1910+
T,
1911+
F<:Union{AbstractVector{T},AbstractMatrix{T}},
1912+
} <: AbstractFactorization{T,F}
1913+
factor::F
18581914
end
18591915

1860-
function Base.getindex(m::LowRankMatrix, i::Int, j::Int)
1861-
return sum(m.factor[i, k] * m.diagonal[k] * m.factor[j, k]' for k in eachindex(m.diagonal))
1916+
function Base.getindex(m::PositiveSemidefiniteFactorization, i::Int, j::Int)
1917+
return sum(m.factor[i, k] * m.factor[j, k]' for k in axes(m.factor, 2))
18621918
end
18631919

18641920
struct TriangleVectorization{T,M<:AbstractMatrix{T}} <: AbstractVector{T}

test/sets.jl

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ module TestSets
88

99
using Test
1010
import MathOptInterface as MOI
11+
import LinearAlgebra
1112

1213
include("dummy.jl")
1314

@@ -353,6 +354,41 @@ function test_sets_reified()
353354
return
354355
end
355356

357+
function _test_factorization(A, B)
358+
@test size(A) == size(B)
359+
@test A B
360+
d = LinearAlgebra.checksquare(A)
361+
n = div(d * (d + 1), 2)
362+
vA = MOI.TriangleVectorization(A)
363+
@test length(vA) == n
364+
@test eachindex(vA) == Base.OneTo(n)
365+
vB = MOI.TriangleVectorization(B)
366+
@test length(vB) == n
367+
@test eachindex(vA) == Base.OneTo(n)
368+
k = 0
369+
for j in 1:d
370+
for i in 1:j
371+
k += 1
372+
@test vA[k] == vB[k]
373+
@test vA[k] == A[i, j]
374+
end
375+
end
376+
return
377+
end
378+
379+
function test_factorizations()
380+
f = [1, 2]
381+
_test_factorization(f * f', MOI.PositiveSemidefiniteFactorization(f))
382+
_test_factorization(2 * f * f', MOI.Factorization(f, 2))
383+
F = [1 2; 3 4; 5 6]
384+
d = [7, 8]
385+
_test_factorization(F * F', MOI.PositiveSemidefiniteFactorization(F))
386+
return _test_factorization(
387+
F * LinearAlgebra.Diagonal(d) * F',
388+
MOI.Factorization(F, d),
389+
)
390+
end
391+
356392
function runtests()
357393
for name in names(@__MODULE__; all = true)
358394
if startswith("$name", "test_")

0 commit comments

Comments
 (0)