Skip to content

Commit cc6a9b0

Browse files
authored
Add VcatAtom (#607)
1 parent 0bd470b commit cc6a9b0

File tree

3 files changed

+134
-64
lines changed

3 files changed

+134
-64
lines changed

src/atoms/affine/HcatAtom.jl

Lines changed: 11 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -29,41 +29,15 @@ monotonicity(x::HcatAtom) = ntuple(_ -> Nondecreasing(), length(x.children))
2929

3030
curvature(::HcatAtom) = ConstVexity()
3131

32-
evaluate(x::HcatAtom) = hcat(map(evaluate, x.children)...)
32+
evaluate(x::HcatAtom) = reduce(hcat, collect(map(evaluate, x.children)))
3333

3434
function new_conic_form!(context::Context{T}, x::HcatAtom) where {T}
35-
objectives = map(c -> conic_form!(context, c), AbstractTrees.children(x))
36-
# Suppose the child objectives for two children e1 (2 x 1) and e2 (2 x 2)
37-
# look something like
38-
# e1: x => 1 2 3
39-
# 4 5 6
40-
# y => 2 4
41-
# 7 8
42-
# e2: x => 1 1 1
43-
# 2 2 2
44-
# 3 3 3
45-
# 4 4 4
46-
# The objective of [e1 e2] will look like
47-
# x => 1 2 3
48-
# 4 5 6
49-
# 1 1 1
50-
# 2 2 2
51-
# 3 3 3
52-
# 4 4 4
53-
# y => 2 4
54-
# 7 8
55-
# 0 0
56-
# 0 0
57-
# 0 0
58-
# 0 0
59-
# builds the objective by aggregating a list of coefficients for each
60-
# variable from each child objective, and then vertically concatenating them
61-
return operate(vcat, T, sign(x), objectives...)
35+
args = map(c -> conic_form!(context, c), AbstractTrees.children(x))
36+
# MOI represents matrices by concatenating their columns, so even though
37+
# this is an HcatAtom, we built the conic form by vcat'ing the arguments.
38+
return operate(vcat, T, sign(x), args...)
6239
end
63-
# TODO: fix piracy!
6440

65-
# * `Value` is not owned by Convex.jl
66-
# * splatting creates zero-argument functions, which again are not owned by Convex.jl
6741
Base.hcat(args::AbstractExpr...) = HcatAtom(args...)
6842

6943
function Base.hcat(args::Union{AbstractExpr,Value}...)
@@ -73,26 +47,15 @@ function Base.hcat(args::Union{AbstractExpr,Value}...)
7347
return HcatAtom(args...)
7448
end
7549

76-
# TODO: implement vertical concatenation in a more efficient way
77-
Base.vcat(args::AbstractExpr...) = transpose(HcatAtom(map(transpose, args)...))
78-
79-
function Base.vcat(args::Union{AbstractExpr,Value}...)
80-
if all(Base.Fix2(isa, Value), args)
81-
return Base.cat(args..., dims = Val(1))
82-
end
83-
return transpose(HcatAtom(map(transpose, args)...))
84-
end
85-
8650
function Base.hvcat(
8751
rows::Tuple{Vararg{Int}},
8852
args::Union{AbstractExpr,Value}...,
8953
)
90-
nbr = length(rows)
91-
rs = Vector{Any}(undef, nbr)
92-
a = 1
93-
for i in 1:nbr
94-
rs[i] = HcatAtom(args[a:a-1+rows[i]]...)
95-
a += rows[i]
54+
output_rows = Vector{HcatAtom}(undef, length(rows))
55+
offset = 0
56+
for (i, n) in enumerate(rows)
57+
output_rows[i] = HcatAtom(args[offset.+(1:n)]...)
58+
offset += n
9659
end
97-
return vcat(rs...)
60+
return vcat(output_rows...)
9861
end

src/atoms/affine/VcatAtom.jl

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
# Copyright (c) 2014: Madeleine Udell and contributors
2+
#
3+
# Use of this source code is governed by a BSD-style license that can be found
4+
# in the LICENSE file or at https://opensource.org/license/bsd-2-clause
5+
6+
mutable struct VcatAtom <: AbstractExpr
7+
children::Tuple
8+
size::Tuple{Int,Int}
9+
10+
function VcatAtom(args...)
11+
args = convert.(AbstractExpr, args)
12+
num_rows, num_cols = 0, args[1].size[2]
13+
for arg in args
14+
if arg.size[2] != num_cols
15+
msg = "[VcatAtom] cannot stack expressions of incompatible size. Got $(arg.size[2]) expected $num_cols."
16+
throw(DimensionMismatch(msg))
17+
end
18+
num_rows += arg.size[1]
19+
end
20+
return new(args, (num_rows, num_cols))
21+
end
22+
end
23+
24+
head(io::IO, ::VcatAtom) = print(io, "vcat")
25+
26+
Base.sign(x::VcatAtom) = sum(map(sign, x.children))
27+
28+
monotonicity(x::VcatAtom) = ntuple(_ -> Nondecreasing(), length(x.children))
29+
30+
curvature(::VcatAtom) = ConstVexity()
31+
32+
evaluate(x::VcatAtom) = reduce(vcat, collect(map(evaluate, x.children)))
33+
34+
function new_conic_form!(context::Context{T}, x::VcatAtom) where {T}
35+
# Converting a VcatAtom to conic form is non-trivial. Consider two matrices:
36+
# x = [1 3; 2 4]
37+
# y = [5 7; 6 8]
38+
# with VcatAtom(x, y). The desired outcome is [1, 2, 5, 6, 3, 4, 7, 8].
39+
# If we naively convert the children to conic form and then vcat, we will
40+
# get:
41+
# vcat([1, 2, 3, 4], [5, 6, 7, 8]) = [1, 2, 3, 4, 5, 6, 7, 8]
42+
# which is not what we are after. We need to first transpose each child to
43+
# get:
44+
# x^T, y^T = [1 2; 3 4], [5 6; 7 8])
45+
# then hcat them to get:
46+
# hcat(x^T, y^T) = [1 2 5 6; 3 4 7 8]
47+
# then transpose this to get:
48+
# hcat(x^T, y^T)^T = [1 3; 2 4; 5 7; 6 8]
49+
# so our final conic form produces the desired
50+
# [1, 2, 5, 6, 3, 4, 7, 8]
51+
return conic_form!(context, transpose(reduce(hcat, transpose.(x.children))))
52+
end
53+
54+
Base.vcat(args::AbstractExpr...) = VcatAtom(args...)
55+
56+
function Base.vcat(args::Union{AbstractExpr,Value}...)
57+
if all(Base.Fix2(isa, Value), args)
58+
return Base.cat(args..., dims = Val(1))
59+
end
60+
return VcatAtom(args...)
61+
end

test/test_atoms.jl

Lines changed: 62 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -379,22 +379,14 @@ function test_HcatAtom()
379379
x = Variable()
380380
return hcat(x, x)
381381
end
382-
_test_atom(target) do context
383-
x = Variable()
384-
return vcat(x, x)
385-
end
386382
target = """
387383
variables: x1, x2
388384
minobjective: [1.0 * x1, 1.0 * x2, 2.0]
389385
"""
390386
_test_atom(target) do context
391-
x = Variable(2)
387+
x = Variable(1, 2)
392388
y = constant(2)
393-
return vcat(x, y)
394-
end
395-
_test_atom(target) do context
396-
x = Variable(2)
397-
return vcat(x, 2)
389+
return hcat(x, y)
398390
end
399391
_test_atom(target) do context
400392
x = Variable(1, 2)
@@ -406,12 +398,6 @@ function test_HcatAtom()
406398
),
407399
hcat(Variable(2), constant(2)),
408400
)
409-
@test_throws(
410-
DimensionMismatch(
411-
"[HcatAtom] cannot stack expressions of incompatible size. Got 2 expected 1.",
412-
),
413-
vcat(Variable(2, 1), Variable(1, 2)),
414-
)
415401
return
416402
end
417403

@@ -731,6 +717,66 @@ function test_SumAtom()
731717
return
732718
end
733719

720+
### affine/VcatAtom
721+
722+
function test_VcatAtom()
723+
target = """
724+
variables: x
725+
minobjective: [1.0 * x, 1.0 * x]
726+
"""
727+
_test_atom(target) do context
728+
x = Variable()
729+
return vcat(x, x)
730+
end
731+
target = """
732+
variables: x1, x2
733+
minobjective: [1.0 * x1, 1.0 * x2, 2.0]
734+
"""
735+
_test_atom(target) do context
736+
x = Variable(2)
737+
y = constant(2)
738+
return vcat(x, y)
739+
end
740+
_test_atom(target) do context
741+
x = Variable(2)
742+
return vcat(x, 2)
743+
end
744+
target = """
745+
variables: x1, x2
746+
minobjective: [1.0 * x1, 2.0, 1.0 * x2, 3.0]
747+
"""
748+
_test_atom(target) do context
749+
x = Variable(1, 2)
750+
y = constant([2 3])
751+
return vcat(x, y)
752+
end
753+
target = """
754+
variables: x1, x2, x3
755+
minobjective: [2.0, 1.0 * x1, 2.0, 3.0, 1.0 * x2, 3.0, 4.0, 1.0 * x3, 4.0]
756+
"""
757+
_test_atom(target) do context
758+
x = Variable(1, 3)
759+
y = constant([2 3 4])
760+
return vcat(y, x, y)
761+
end
762+
target = """
763+
variables: x1, x2, x3, x4
764+
minobjective: [x1, x2, 2.0, x3, x4, 3.0]
765+
"""
766+
_test_atom(target) do context
767+
x = Variable(2, 2)
768+
y = constant([2 3])
769+
return vcat(x, y)
770+
end
771+
@test_throws(
772+
DimensionMismatch(
773+
"[VcatAtom] cannot stack expressions of incompatible size. Got 2 expected 1.",
774+
),
775+
vcat(Variable(2, 1), Variable(1, 2)),
776+
)
777+
return
778+
end
779+
734780
### exp_+_sdp_cone/LogDetAtom
735781

736782
function test_LogDetAtom()

0 commit comments

Comments
 (0)