Skip to content

Commit cbe5500

Browse files
authored
Fix indexing AbstractExor with BitVector and BitMatrix (#710)
1 parent b5923f9 commit cbe5500

File tree

3 files changed

+33
-24
lines changed

3 files changed

+33
-24
lines changed

src/atoms/IndexAtom.jl

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -109,20 +109,22 @@ function Base.getindex(x::AbstractExpr, rows::AbstractVector{<:Real}, col::Real)
109109
return getindex(x, rows, col:col)
110110
end
111111

112-
function Base.getindex(
113-
x::AbstractExpr,
114-
I::Union{AbstractMatrix{Bool},<:BitMatrix},
115-
)
112+
function Base.getindex(x::AbstractExpr, I::AbstractMatrix{Bool})
116113
return [xi for (xi, ii) in zip(x, I) if ii]
117114
end
118115

119-
function Base.getindex(
120-
x::AbstractExpr,
121-
I::Union{<:AbstractVector{Bool},<:BitVector},
122-
)
116+
function Base.getindex(x::AbstractExpr, I::AbstractVector{Bool})
123117
return [xi for (xi, ii) in zip(x, I) if ii]
124118
end
125119

120+
function Base.getindex(x::AbstractExpr, ind::BitVector)
121+
return getindex(x, findall(ind))
122+
end
123+
124+
function Base.getindex(x::AbstractExpr, ind::BitMatrix)
125+
return getindex(x, LinearIndices(ind)[ind])
126+
end
127+
126128
# All rows and columns
127129
Base.getindex(x::AbstractExpr, ::Colon, ::Colon) = x
128130

src/atoms/VcatAtom.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,10 @@ function Base.getindex(x::VcatAtom, inds::AbstractVector{<:Real})
119119
return IndexAtom(remaining, inds)
120120
end
121121

122+
function Base.getindex(x::VcatAtom, inds::BitVector)
123+
return getindex(x, findall(inds))
124+
end
125+
122126
function Base.getindex(x::VcatAtom, inds::AbstractVector{Bool})
123-
return getindex(x, first.(filter!(last, collect(enumerate(inds)))))
127+
return getindex(x, convert(BitVector, inds))
124128
end

test/test_atoms.jl

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -535,22 +535,25 @@ function test_IndexAtom()
535535
Convex.set_value!(x, [1 3; 2 4])
536536
@test Convex.evaluate.(z) == [1, 2, 4]
537537
# Base.getindex(x::AbstractExpr, I::BitVector)
538-
y = BitVector([true, false, true])
539-
x = Variable(3)
540-
z = x[y]
541-
@test string(z) == string([x[1], x[3]])
542-
@test z isa Vector{Convex.IndexAtom}
543-
@test length(z) == 2
544-
Convex.set_value!(x, [1, 2, 3])
545-
@test Convex.evaluate.(z) == [1, 3]
538+
target = """
539+
variables: x1, x2, x3
540+
minobjective: [1.0 * x1, 1.0 * x3]
541+
"""
542+
_test_atom(target) do context
543+
x = Variable(3)
544+
y = BitVector([true, false, true])
545+
return x[y]
546+
end
546547
# Base.getindex(x::AbstractExpr, I::BitMatrix)
547-
y = BitMatrix([true false; true true])
548-
x = Variable(2, 2)
549-
z = x[y]
550-
@test z isa Vector{Convex.IndexAtom}
551-
@test length(z) == 3
552-
Convex.set_value!(x, [1 3; 2 4])
553-
@test Convex.evaluate.(z) == [1, 2, 4]
548+
target = """
549+
variables: x1, x2, x3, x4
550+
minobjective: [1.0 * x1, 1.0 * x2, 1.0 * x4]
551+
"""
552+
_test_atom(target) do context
553+
x = Variable(2, 2)
554+
y = BitMatrix([true false; true true])
555+
return x[y]
556+
end
554557
return
555558
end
556559

0 commit comments

Comments
 (0)