Skip to content

Commit 95fbf0d

Browse files
committed
Make sure to lower the first op in a collection first. Fixes #280.
1 parent 59adada commit 95fbf0d

File tree

2 files changed

+29
-10
lines changed

2 files changed

+29
-10
lines changed

src/codegen/lower_load.jl

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -332,16 +332,21 @@ function _lower_load!(
332332
q::Expr, ls::LoopSet, op::Operation, td::UnrollArgs, mask::Bool, inds_calc_by_ptr_offset::Vector{Bool} = indices_calculated_by_pointer_offsets(ls, op.ref)
333333
)
334334
if rejectinterleave(op)
335-
lower_load_no_optranslation!(q, ls, op, td, mask, inds_calc_by_ptr_offset)
335+
return lower_load_no_optranslation!(q, ls, op, td, mask, inds_calc_by_ptr_offset)
336336
else
337337
omop = offsetloadcollection(ls)
338338
@unpack opids, opidcollectionmap, batchedcollections, batchedcollectionmap = omop
339339
batchid, opind = batchedcollectionmap[identifier(op)]
340-
if opind == 1
341-
collectionid, copind = opidcollectionmap[identifier(op)]
342-
opidmap = opids[collectionid]
343-
idsformap = batchedcollections[batchid]
344-
lower_load_collection!(q, ls, opidmap, idsformap, td, mask, inds_calc_by_ptr_offset)
340+
for (bid, oid) ∈ batchedcollectionmap # this relies on `for op ∈ ops` in codegen/operation_evaluation_order.jl
341+
if bid == batchid
342+
if oid == opind
343+
collectionid, copind = opidcollectionmap[identifier(op)]
344+
opidmap = opids[collectionid]
345+
idsformap = batchedcollections[batchid]
346+
lower_load_collection!(q, ls, opidmap, idsformap, td, mask, inds_calc_by_ptr_offset)
347+
end
348+
return nothing
349+
end
345350
end
346351
end
347352
return nothing

test/copy.jl

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -151,20 +151,25 @@ using LoopVectorization, OffsetArrays, Test
151151
H
152152
end
153153

154-
function issue!(output, input, idx)
154+
function issue279!(output, input, idx)
155155
@turbo for j in axes(output, 2), i in axes(output, 1)
156156
output[i, j, idx] = input[1, 1, i, j, idx]
157157
end
158158
output
159159
end
160160

161-
function issue_plain!(output, input, idx)
161+
function issue279_plain!(output, input, idx)
162162
for j in axes(output, 2), i in axes(output, 1)
163163
output[i, j, idx] = input[1, 1, i, j, idx]
164164
end
165165
output
166166
end
167-
167+
function issue280!(dest, src)
168+
@turbo for i in indices((dest, src), (2, 2))
169+
dest[1, i] = src[2, i]
170+
dest[2, i] = src[1, i]
171+
end
172+
end
168173

169174
for T ∈ (Float32, Float64, Int32, Int64)
170175
@show T, @__LINE__
@@ -250,7 +255,16 @@ using LoopVectorization, OffsetArrays, Test
250255

251256
input = rand(R, 2, 2, 5, 5, 1); output = Array{T}(undef, size(input)[3:end]...); output_plain = similar(output);
252257

253-
@test issue!(output, input, 1) ≈ issue_plain!(output_plain, input, 1)
258+
@test issue279!(output, input, 1) ≈ issue279_plain!(output_plain, input, 1)
254259

260+
src = rand(R, 2, 17); dest = similar(src);
261+
issue280!(dest, src)
262+
@test dest ≈ vcat(view(src,2,:)',view(src,1,:)')
263+
if VERSION ≥ v"1.6"
264+
src2 = reinterpret(reshape,R,Vector{Tuple{T,T}}(undef, 17)); src2 .= src;
265+
dest2 = reinterpret(reshape,R,Vector{Tuple{T,T}}(undef, 17));
266+
issue280!(dest2, src2)
267+
@test dest2 ≈ vcat(view(src,2,:)',view(src,1,:)')
268+
end
255269
end
256270
end

0 commit comments

Comments
 (0)