Skip to content

Commit 5fea8c3

Browse files
authored
Improve performance of IndexAtom a little (#615)
1 parent 317e94c commit 5fea8c3

File tree

3 files changed

+55
-27
lines changed

3 files changed

+55
-27
lines changed

src/atoms/IndexAtom.jl

Lines changed: 33 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -42,33 +42,41 @@ function evaluate(x::IndexAtom)
4242
return output(result)
4343
end
4444

45-
function new_conic_form!(context::Context{T}, x::IndexAtom) where {T}
46-
obj = conic_form!(context, only(AbstractTrees.children(x)))
47-
m = length(x)
48-
n = length(x.children[1])
45+
function _index(tape::SparseTape{T}, keep_rows::Vector{Int}) where {T}
46+
A = tape.operation.matrix
47+
indexed = A[keep_rows, :]
48+
af = SparseAffineOperation{T}(indexed, tape.operation.vector[keep_rows])
49+
return SparseTape{T}(af, tape.variables)
50+
end
51+
52+
_index(tape::Vector, keep_rows::Vector{Int}) = tape[keep_rows]
53+
54+
function _index_real(
55+
obj_size::Tuple,
56+
obj_tape::Union{SparseTape,SPARSE_VECTOR},
57+
x::IndexAtom,
58+
)
4959
if x.inds === nothing
50-
sz = length(x.cols) * length(x.rows)
51-
J = Vector{Int}(undef, sz)
52-
k = 1
53-
num_rows = x.children[1].size[1]
54-
for c in x.cols
55-
for r in x.rows
56-
J[k] = num_rows * (convert(Int, c) - 1) + convert(Int, r)
57-
k += 1
58-
end
59-
end
60-
index_matrix = create_sparse(T, collect(1:sz), J, one(T), m, n)
61-
else
62-
index_matrix = create_sparse(
63-
T,
64-
collect(1:length(x.inds)),
65-
collect(x.inds),
66-
one(T),
67-
m,
68-
n,
69-
)
60+
linear_indices = LinearIndices(CartesianIndices(obj_size))
61+
return _index(obj_tape, vec(linear_indices[x.rows, x.cols]))
62+
end
63+
return _index(obj_tape, vec(collect(x.inds)))
64+
end
65+
66+
function new_conic_form!(context::Context{T}, x::IndexAtom) where {T}
67+
input = x.children[1]
68+
if !iscomplex(x) # real case
69+
input_tape = conic_form!(context, input)
70+
return _index_real(size(input), input_tape, x)
71+
end
72+
input_tape = conic_form!(context, input)
73+
re = _index_real(size(input), real(input_tape), x)
74+
im = _index_real(size(input), imag(input_tape), x)
75+
if re isa SPARSE_VECTOR
76+
@assert im isa SPARSE_VECTOR
77+
return ComplexStructOfVec(re, im)
7078
end
71-
return operate(add_operation, T, sign(x), index_matrix, obj)
79+
return ComplexTape(re, im)
7280
end
7381

7482
function Base.getindex(

src/problem_depot/problems/affine.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,26 @@ end
260260
rtol,
261261
::Type{T},
262262
) where {T,test}
263+
x = ComplexVariable(2)
264+
fix!(x, [1, 2] + im * [1, 2])
265+
t = Variable()
266+
p = minimize(t + real(x[1]), t >= 0; numeric_type = T)
267+
handle_problem!(p)
268+
if test
269+
@test p.optval 1 atol = atol rtol = rtol
270+
end
271+
272+
x = Variable(4, 2)
273+
y = [1:4 5:8]
274+
add_constraint!(x, x == y)
275+
p = minimize(dot(x[[4, 3], 2], [7, 13]); numeric_type = T)
276+
handle_problem!(p)
277+
if test
278+
# we would get 153 if we weren't respecting the index ordering
279+
@test dot(y[[4, 3], 2], [7, 13]) == 147
280+
@test p.optval 147 atol = atol rtol = rtol
281+
end
282+
263283
x = Variable(2)
264284
p = minimize(x[1] + x[2], [x >= 1]; numeric_type = T)
265285

src/problems.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -172,13 +172,13 @@ function Context(p::Problem{T}, optimizer_factory) where {T}
172172
if p.head == :satisfy
173173
MOI.set(context.model, MOI.ObjectiveSense(), MOI.FEASIBILITY_SENSE)
174174
else
175-
obj = _to_scalar_moi(T, cfp)
176-
MOI.set(context.model, MOI.ObjectiveFunction{typeof(obj)}(), obj)
177175
MOI.set(
178176
context.model,
179177
MOI.ObjectiveSense(),
180178
p.head == :maximize ? MOI.MAX_SENSE : MOI.MIN_SENSE,
181179
)
180+
obj = _to_scalar_moi(T, cfp)
181+
MOI.set(context.model, MOI.ObjectiveFunction{typeof(obj)}(), obj)
182182
end
183183
return context
184184
end

0 commit comments

Comments
 (0)