Skip to content

Commit bb17c4c

Browse files
authored
Special-case getindex on VcatAtom (#625)
1 parent 5fea8c3 commit bb17c4c

File tree

3 files changed

+107
-7
lines changed

3 files changed

+107
-7
lines changed

src/atoms/IndexAtom.jl

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -114,10 +114,7 @@ function Base.getindex(x::AbstractExpr, I::AbstractVector{Bool})
114114
end
115115

116116
# All rows and columns
117-
function Base.getindex(x::AbstractExpr, ::Colon, ::Colon)
118-
rows, cols = size(x)
119-
return getindex(x, 1:rows, 1:cols)
120-
end
117+
Base.getindex(x::AbstractExpr, ::Colon, ::Colon) = x
121118

122119
# All rows for this column(s)
123120
function Base.getindex(x::AbstractExpr, ::Colon, col)

src/atoms/VcatAtom.jl

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,3 +59,71 @@ function Base.vcat(args::Union{AbstractExpr,Value}...)
5959
end
6060
return VcatAtom(args...)
6161
end
62+
63+
function Base.getindex(
64+
x::VcatAtom,
65+
rows::AbstractVector{<:Real},
66+
cols::AbstractVector{<:Real},
67+
)
68+
idx = 0
69+
rows = collect(rows) # make a mutable copy
70+
keep_children = ()
71+
for c in x.children
72+
# here are the row indices into `x` that point to `c`
73+
I = idx .+ (1:size(c, 1))
74+
if issubset(rows, I)
75+
# if all the row indices we want are in this one child, we can early exit
76+
if rows == I && cols == 1:size(c, 2)
77+
return c
78+
else
79+
return c[rows.-idx, cols]
80+
end
81+
elseif !isdisjoint(rows, I)
82+
# we have some but not all rows in this child, so keep it
83+
keep_children = (keep_children..., c)
84+
idx += size(c, 1)
85+
else # we can drop this child!
86+
# let's update `rows` to account for the removal
87+
l = last(I)
88+
for i in eachindex(rows)
89+
if rows[i] >= l
90+
rows[i] -= length(I)
91+
end
92+
end
93+
end
94+
end
95+
# If we are here, the indices span multiple children.
96+
# We can't necessarily index each separately, since they may be out of order.
97+
# So we will defer to an `IndexAtom` on the remaining children
98+
remaining = VcatAtom(keep_children...)
99+
return IndexAtom(remaining, rows, cols)
100+
end
101+
102+
# linear indexing: very similar to row-indexing above, but with linear indices
103+
function Base.getindex(x::VcatAtom, inds::AbstractVector{<:Real})
104+
idx = 0
105+
inds = collect(inds)
106+
keep_children = ()
107+
for c in x.children
108+
I = idx .+ (1:length(c))
109+
if issubset(inds, I)
110+
if inds == I
111+
return c
112+
else
113+
return c[inds.-idx]
114+
end
115+
elseif !isdisjoint(inds, I)
116+
keep_children = (keep_children..., c)
117+
idx += length(c)
118+
else
119+
l = last(I)
120+
for i in eachindex(inds)
121+
if inds[i] >= l
122+
inds[i] -= length(I)
123+
end
124+
end
125+
end
126+
end
127+
remaining = VcatAtom(keep_children...)
128+
return IndexAtom(remaining, inds)
129+
end

test/test_atoms.jl

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -477,9 +477,8 @@ function test_IndexAtom()
477477
_test_atom(target) do context
478478
return Variable(2)[:, 1]
479479
end
480-
_test_atom(target) do context
481-
return Variable(2)[:, :]
482-
end
480+
x = Variable(2)
481+
@test x[:, :] === x
483482
target = """
484483
variables: x1, x2, x3
485484
minobjective: [1.0 * x1, 1.0 * x3]
@@ -814,6 +813,42 @@ function test_VcatAtom()
814813
return
815814
end
816815

816+
function test_VcatAtom_getindex()
817+
x = Variable()
818+
sq = square(x)
819+
for v in [vcat(x, sq, -sq), vcat(vcat(x, sq), -sq)]
820+
@test isequal(v[1], x)
821+
@test isequal(v[2], sq)
822+
@test isequal(v[:, 1], v[1:3, 1])
823+
@test isequal(v[1:3, :], v[1:3, 1:1])
824+
@test v[2:-1:1].children[1] isa Convex.VcatAtom
825+
@test v[2:-1:1].children[1].children == (x, sq)
826+
@test vexity(v[1]) isa Convex.AffineVexity
827+
@test vexity(v[2]) isa Convex.ConvexVexity
828+
@test vexity(v[3]) isa Convex.ConcaveVexity
829+
@test vexity(v[1:2]) isa Convex.ConvexVexity
830+
@test vexity(v[:, 1]) isa Convex.NotDcp
831+
end
832+
x = Variable(2, 2)
833+
sq = square(x)
834+
for v in [vcat(x, sq, -sq), vcat(vcat(x, sq), -sq)]
835+
@test isequal(v[:, :], v)
836+
@test isequal(v[1:2, 1:2], x)
837+
@test v[1, 1:2] isa Convex.IndexAtom
838+
@test v[1:3, 1:2] isa Convex.IndexAtom
839+
@test v[1:3, 1:2].children[1] isa Convex.VcatAtom
840+
@test isequal(v[3:4, 1:2], sq)
841+
@test v[3:-1:1, 1].children[1] isa Convex.VcatAtom
842+
@test v[4:-1:1, :].children[1].children == (x, sq)
843+
@test vexity(v[1]) isa Convex.AffineVexity
844+
@test vexity(v[3:4, 1:2]) isa Convex.ConvexVexity
845+
@test vexity(v[5:6, 1]) isa Convex.ConcaveVexity
846+
@test vexity(v[5:6, 1:2]) isa Convex.ConcaveVexity
847+
@test vexity(v[1:3, 1]) isa Convex.ConvexVexity
848+
end
849+
return
850+
end
851+
817852
### exp_+_sdp_cone/LogDetAtom
818853

819854
function test_LogDetAtom()

0 commit comments

Comments
 (0)