Skip to content

Commit 114fa43

Browse files
committed
Add bridge
1 parent 01fa32c commit 114fa43

File tree

6 files changed

+100
-9
lines changed

6 files changed

+100
-9
lines changed

src/Bridges/Variable/Variable.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ function add_all_bridges(model, ::Type{T}) where {T}
3434
MOI.Bridges.add_bridge(model, RSOCtoPSDBridge{T})
3535
MOI.Bridges.add_bridge(model, HermitianToSymmetricPSDBridge{T})
3636
MOI.Bridges.add_bridge(model, ParameterToEqualToBridge{T})
37+
MOI.Bridges.add_bridge(model, DotProductsBridge{T})
3738
return
3839
end
3940

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
struct DotProductsBridge{T,S,V} <: SetMapBridge{T,S,MOI.SetWithDotProducts{S,V}}
2+
variables::Vector{MOI.VariableIndex}
3+
constraint::MOI.ConstraintIndex{MOI.VectorOfVariables, S}
4+
set::MOI.SetWithDotProducts{S,V}
5+
end
6+
7+
function supports_constrained_variable(
8+
::Type{<:DotProductsBridge},
9+
::Type{<:MOI.SetWithDotProducts},
10+
)
11+
return true
12+
end
13+
14+
function concrete_bridge_type(
15+
::Type{<:DotProductsBridge{T}},
16+
::Type{MOI.SetWithDotProducts{S,V}},
17+
) where {T,S,V}
18+
return DotProductsBridge{T,S,V}
19+
end
20+
21+
function bridge_constrained_variable(
22+
BT::Type{DotProductsBridge{T,S,V}},
23+
model::MOI.ModelLike,
24+
set::MOI.SetWithDotProducts{S,V},
25+
) where {T,S,V}
26+
variables, constraint =
27+
_add_constrained_var(model, MOI.Bridges.inverse_map_set(BT, set))
28+
return BT(variables, constraint, set)
29+
end
30+
31+
function MOI.Bridges.map_set(
32+
bridge::DotProductsBridge{T,S},
33+
set::S,
34+
) where {T,S}
35+
return MOI.SetWithDotProducts(set, bridge.vectors)
36+
end
37+
38+
function MOI.Bridges.inverse_map_set(
39+
::Type{<:DotProductsBridge},
40+
set::MOI.SetWithDotProducts,
41+
)
42+
return set.set
43+
end
44+
45+
function MOI.Bridges.map_function(
46+
bridge::DotProductsBridge{T},
47+
func,
48+
i::MOI.Bridges.IndexInVector,
49+
) where {T}
50+
scalars = MOI.Utilities.eachscalar(func)
51+
if i.value in eachindex(bridge.set.vectors)
52+
return MOI.Utilities.set_dot(bridge.set.vectors[i.value], scalars, bridge.set.set)
53+
else
54+
return convert(MOI.ScalarAffineFunction{T}, scalars[i.value - length(bridge.vectors)])
55+
end
56+
end
57+
58+
function MOI.Bridges.inverse_map_function(bridge::DotProductsBridge{T}, func) where {T}
59+
m = length(bridge.set.vectors)
60+
return MOI.Utilities.operate(vcat, T, MOI.Utilities.eachscalar(func)[(m+1):end])
61+
end
62+

src/Bridges/Variable/set_map.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ function MOI.get(
171171
if any(isnothing, value)
172172
return nothing
173173
end
174-
return MOI.Bridges.map_function(typeof(bridge), value, i)
174+
return MOI.Bridges.map_function(bridge, value, i)
175175
end
176176

177177
function MOI.supports(
@@ -203,7 +203,7 @@ function MOI.Bridges.bridged_function(
203203
i::MOI.Bridges.IndexInVector,
204204
) where {T}
205205
func = MOI.Bridges.map_function(
206-
typeof(bridge),
206+
bridge,
207207
MOI.VectorOfVariables(bridge.variables),
208208
i,
209209
)
@@ -212,7 +212,7 @@ end
212212

213213
function unbridged_map(bridge::SetMapBridge{T}, vi::MOI.VariableIndex) where {T}
214214
F = MOI.ScalarAffineFunction{T}
215-
mapped = MOI.Bridges.inverse_map_function(typeof(bridge), vi)
215+
mapped = MOI.Bridges.inverse_map_function(bridge, vi)
216216
return Pair{MOI.VariableIndex,F}[bridge.variable=>mapped]
217217
end
218218

@@ -222,9 +222,10 @@ function unbridged_map(
222222
) where {T}
223223
F = MOI.ScalarAffineFunction{T}
224224
func = MOI.VectorOfVariables(vis)
225-
funcs = MOI.Bridges.inverse_map_function(typeof(bridge), func)
225+
funcs = MOI.Bridges.inverse_map_function(bridge, func)
226226
scalars = MOI.Utilities.eachscalar(funcs)
227+
# FIXME not correct for SetWithDotProducts, it won't recover the dot product variables
227228
return Pair{MOI.VariableIndex,F}[
228-
bridge.variables[i] => scalars[i] for i in eachindex(vis)
229+
bridge.variables[i] => scalars[i] for i in eachindex(bridge.variables)
229230
]
230231
end

src/Bridges/set_map.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,10 @@ function map_function(::Type{BT}, func, i::IndexInVector) where {BT}
8080
return MOI.Utilities.eachscalar(map_function(BT, func))[i.value]
8181
end
8282

83+
function map_function(bridge::AbstractBridge, func, i::IndexInVector)
84+
return map_function(typeof(bridge), func, i)
85+
end
86+
8387
"""
8488
inverse_map_function(bridge::MOI.Bridges.AbstractBridge, func)
8589
inverse_map_function(::Type{BT}, func) where {BT}

src/Utilities/functions.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -638,18 +638,23 @@ end
638638
A type that allows iterating over the scalar-functions that comprise an
639639
`AbstractVectorFunction`.
640640
"""
641-
struct ScalarFunctionIterator{F<:MOI.AbstractVectorFunction,C}
641+
struct ScalarFunctionIterator{F<:MOI.AbstractVectorFunction,C,S} <: AbstractVector{S}
642642
f::F
643643
# Cache that can be used to store a precomputed datastructure that allows
644644
# an efficient implementation of `getindex`.
645645
cache::C
646+
function ScalarFunctionIterator(f::MOI.AbstractVectorFunction, cache)
647+
return new{typeof(f),typeof(cache),scalar_type(typeof(f))}(f, cache)
648+
end
646649
end
647650

648651
function ScalarFunctionIterator(func::MOI.AbstractVectorFunction)
649652
return ScalarFunctionIterator(func, scalar_iterator_cache(func))
650653
end
651654

652-
scalar_iterator_cache(func::MOI.AbstractVectorFunction) = nothing
655+
Base.size(s::ScalarFunctionIterator) = (MOI.output_dimension(s.f),)
656+
657+
scalar_iterator_cache(::MOI.AbstractVectorFunction) = nothing
653658

654659
function output_index_iterator(terms::AbstractVector, output_dimension)
655660
start = zeros(Int, output_dimension)

src/sets.jl

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1855,15 +1855,33 @@ end
18551855
18561856
`factor * Diagonal(diagonal) * factor'`.
18571857
"""
1858-
struct LowRankMatrix{T}
1858+
struct LowRankMatrix{T} <: AbstractMatrix{T}
18591859
diagonal::Vector{T}
18601860
factor::Matrix{T}
18611861
end
18621862

1863-
struct TriangleVectorization{M}
1863+
function Base.size(m::LowRankMatrix)
1864+
n = size(m.factor, 1)
1865+
return (n, n)
1866+
end
1867+
1868+
function Base.getindex(m::LowRankMatrix, i::Int, j::Int)
1869+
return sum(m.factor[i, k] * m.diagonal[k] * m.factor[j, k]' for k in eachindex(m.diagonal))
1870+
end
1871+
1872+
struct TriangleVectorization{T,M<:AbstractMatrix{T}} <: AbstractVector{T}
18641873
matrix::M
18651874
end
18661875

1876+
function Base.size(v::TriangleVectorization)
1877+
n = size(v.matrix, 1)
1878+
return (Utilities.trimap(n, n),)
1879+
end
1880+
1881+
function Base.getindex(v::TriangleVectorization, k::Int)
1882+
return getindex(v.matrix, Utilities.inverse_trimap(k)...)
1883+
end
1884+
18671885
"""
18681886
SOS1{T<:Real}(weights::Vector{T})
18691887

0 commit comments

Comments
 (0)