@@ -9,18 +9,14 @@ mutable struct IndexAtom <: AbstractExpr
9
9
rows:: Union{AbstractArray,Nothing}
10
10
cols:: Union{AbstractArray,Nothing}
11
11
inds:: Union{AbstractArray,Nothing}
12
+ end
12
13
13
- function IndexAtom (
14
- x:: AbstractExpr ,
15
- rows:: AbstractArray ,
16
- cols:: AbstractArray ,
17
- )
18
- return new ((x,), (length (rows), length (cols)), rows, cols, nothing )
19
- end
14
+ function IndexAtom (x:: AbstractExpr , rows:: AbstractArray , cols:: AbstractArray )
15
+ return IndexAtom ((x,), (length (rows), length (cols)), rows, cols, nothing )
16
+ end
20
17
21
- function IndexAtom (x:: AbstractExpr , inds:: AbstractArray )
22
- return new ((x,), (length (inds), 1 ), nothing , nothing , inds)
23
- end
18
+ function IndexAtom (x:: AbstractExpr , inds:: AbstractArray )
19
+ return IndexAtom ((x,), (length (inds), 1 ), nothing , nothing , inds)
24
20
end
25
21
26
22
head (io:: IO , :: IndexAtom ) = print (io, " index" )
@@ -33,7 +29,11 @@ curvature(::IndexAtom) = ConstVexity()
33
29
34
30
function evaluate (x:: IndexAtom )
35
31
result = if x. inds === nothing
36
- getindex (evaluate (x. children[1 ]), x. rows, x. cols)
32
+ # reshape to ensure we are respecting that a scalar row index
33
+ # creates a column vector. We can't just check `length(x.rows)`
34
+ # since that doesn't distinguish between `i` and `i:i`. But
35
+ # we had that info when we set the size, so we will use it now.
36
+ reshape (getindex (evaluate (x. children[1 ]), x. rows, x. cols), x. size)
37
37
else
38
38
getindex (evaluate (x. children[1 ]), x. inds)
39
39
end
@@ -98,7 +98,11 @@ function Base.getindex(x::AbstractExpr, row::Real, col::Real)
98
98
end
99
99
100
100
function Base. getindex (x:: AbstractExpr , row:: Real , cols:: AbstractVector{<:Real} )
101
- return getindex (x, row: row, cols)
101
+ # In this case, we must construct a column vector
102
+ # https://github.com/jump-dev/Convex.jl/issues/509
103
+ # Here we construct `getindex(x, row:row, cols)`
104
+ # except with the size set to be that of a column vector
105
+ return IndexAtom ((x,), (length (cols), 1 ), row: row, cols, nothing )
102
106
end
103
107
104
108
function Base. getindex (x:: AbstractExpr , rows:: AbstractVector{<:Real} , col:: Real )
0 commit comments