Skip to content

Commit 7d983ec

Browse files
committed
Fix scattercopy test
1 parent 2fdb50f commit 7d983ec

File tree

3 files changed

+49
-25
lines changed

3 files changed

+49
-25
lines changed

src/LoopVectorization.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ using VectorizationBase: register_size, register_count, cache_linesize, cache_si
2121
contract_and, collapse_and,
2222
contract_or, collapse_or,
2323
num_threads, num_cores,
24-
max_mask
24+
max_mask#,zero_mask
2525

2626

2727
using IfElse: ifelse

src/codegen/lower_store.jl

Lines changed: 30 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ function lower_store_collection!(
142142
end
143143
nothing
144144
end
145+
# Build the call expression `Core.getfield(s, n, false)`: field number `n` of the
# value bound to symbol `s`. The trailing `false` is forwarded verbatim as the
# third argument of `getfield` (its boundscheck flag).
function gf(s::Symbol, n::Int)
    getfield_ref = GlobalRef(Core, :getfield)
    return Expr(:call, getfield_ref, s, n, false)
end
145146
function lower_store!(
146147
q::Expr, ls::LoopSet, op::Operation, ua::UnrollArgs, mask::Bool,
147148
reductfunc::Symbol = storeinstr_preprend(op, ua.vloop.itersymbol), inds_calc_by_ptr_offset = indices_calculated_by_pointer_offsets(ls, op.ref)
@@ -154,7 +155,6 @@ function lower_store!(
154155
(opind == 1) && lower_store_collection!(q, ls, op, ua, mask, inds_calc_by_ptr_offset)
155156
return
156157
end
157-
158158
falseexpr = Expr(:call, lv(:False));
159159
aliasexpr = falseexpr;
160160
# trueexpr = Expr(:call, lv(:True));
@@ -179,40 +179,48 @@ function lower_store!(
179179
add_memory_mask!(storeexpr, op, ua, mask, ls)
180180
push!(storeexpr.args, falseexpr, aliasexpr, falseexpr, rs)
181181
push!(q.args, storeexpr)
182-
elseif (u₁ > 1) & isu₁
182+
else
183+
parents_op = parents(op)
184+
data_u₁ = isu₁ & (u₁ > 1)
185+
186+
indices_u₁ = data_u₁
187+
if !data_u₁ & (length(parents_op) > 1)
188+
indices_u₁ = first(isunrolled_sym(op, u₁loopsym, u₂loopsym, vloopsym, ls))
189+
end
190+
if indices_u₁
183191
mvard = Symbol(mvar, "##data##")
184192
# isu₁ &&
185-
push!(q.args, Expr(:(=), mvard, Expr(:call, lv(:data), mvar)))
193+
data_u₁ && push!(q.args, Expr(:(=), mvard, Expr(:call, lv(:data), mvar)))
186194
for u ∈ 1:u₁
187-
mvaru = :(getfield($mvard, $u, false))
188-
inds = mem_offset_u(op, ua, inds_calc_by_ptr_offset, true, u-1, ls)
189-
# @show isu₁unrolled(opp), opp
190-
storeexpr = if isu₁
191-
if reductfunc === Symbol("")
192-
Expr(:call, lv(:_vstore!), vptr(op), mvaru, inds)
193-
else
194-
Expr(:call, lv(:_vstore!), lv(reductfunc), vptr(op), mvaru, inds)
195-
end
196-
elseif reductfunc === Symbol("")
197-
Expr(:call, lv(:_vstore!), vptr(op), mvar, inds)
195+
inds = mem_offset_u(op, ua, inds_calc_by_ptr_offset, true, u-1, ls)
196+
# @show isu₁unrolled(opp), opp
197+
storeexpr = if data_u₁
198+
if reductfunc === Symbol("")
199+
Expr(:call, lv(:_vstore!), vptr(op), gf(mvard,u), inds)
198200
else
199-
Expr(:call, lv(:_vstore!), lv(reductfunc), vptr(op), mvar, inds)
201+
Expr(:call, lv(:_vstore!), lv(reductfunc), vptr(op), mvaru, inds)
200202
end
201-
domask = mask && (isvectorized(op) & ((u == u₁) | (vloopsym !== u₁loopsym)))
202-
add_memory_mask!(storeexpr, op, ua, domask, ls)# & ((u == u₁) | isvectorized(op)))
203-
push!(storeexpr.args, falseexpr, aliasexpr, falseexpr, rs)
204-
push!(q.args, storeexpr)
203+
elseif reductfunc === Symbol("")
204+
Expr(:call, lv(:_vstore!), vptr(op), mvar, inds)
205+
else
206+
Expr(:call, lv(:_vstore!), lv(reductfunc), vptr(op), mvar, inds)
207+
end
208+
domask = mask && (isvectorized(op) & ((u == u₁) | (vloopsym !== u₁loopsym)))
209+
add_memory_mask!(storeexpr, op, ua, domask, ls)# & ((u == u₁) | isvectorized(op)))
210+
push!(storeexpr.args, falseexpr, aliasexpr, falseexpr, rs)
211+
push!(q.args, storeexpr)
205212
end
206-
else
213+
else
207214
inds = mem_offset_u(op, ua, inds_calc_by_ptr_offset, true, 0, ls)
208215
storeexpr = if reductfunc === Symbol("")
209-
Expr(:call, lv(:_vstore!), vptr(op), mvar, inds)
216+
Expr(:call, lv(:_vstore!), vptr(op), mvar, inds)
210217
else
211-
Expr(:call, lv(:_vstore!), lv(reductfunc), vptr(op), mvar, inds)
218+
Expr(:call, lv(:_vstore!), lv(reductfunc), vptr(op), mvar, inds)
212219
end
213220
add_memory_mask!(storeexpr, op, ua, mask, ls)
214221
push!(storeexpr.args, falseexpr, aliasexpr, falseexpr, rs)
215222
push!(q.args, storeexpr)
223+
end
216224
end
217225
nothing
218226
end

test/copy.jl

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,19 @@ using LoopVectorization, OffsetArrays, Test
138138
end
139139
m
140140
end
141-
141+
# `@avx` scatter copy: for every row index stored in `j`, write the entries of `a`
# across that row of `H` (`H[j[i], k] = a[k]`), then return `H`.
# Counterpart of `scattercopy!`, which runs the same loop nest without `@avx`.
function scattercopyavx!(H, a, j)
    @avx for i ∈ eachindex(j), k ∈ eachindex(a)
        # Read from the argument `a`; the previous body referenced `A`, which is
        # not a parameter of this function.
        H[j[i], k] = a[k]
    end
    H
end
147+
# Reference scatter copy: for every row index stored in `j`, write the entries of
# `a` across that row of `H` (`H[j[i], k] = a[k]`), then return `H`.
# Bounds checks are disabled with `@inbounds`, so callers must pass row indices in
# `j` and column indices of `a` that are valid for `H`.
function scattercopy!(H, a, j)
    @inbounds for i ∈ eachindex(j), k ∈ eachindex(a)
        # Read from the argument `a`; the previous body referenced `A`, which is
        # not a parameter of this function.
        H[j[i], k] = a[k]
    end
    H
end
153+
142154
for T ∈ (Float32, Float64, Int32, Int64)
143155
@show T, @__LINE__
144156
R = T <: Integer ? (-T(100):T(100)) : T
@@ -215,6 +227,10 @@ using LoopVectorization, OffsetArrays, Test
215227
@test copy3!(y, x) == x
216228
fill!(y,0);
217229
@test copyselfdot!(y, x) ≈ x[1]^2 + x[2]^2
218-
@test view(x, 1:2) == view(y, 1:2)
230+
@test view(x, 1:2) == view(y, 1:2)
231+
232+
H0 = zeros(10,10); H1 = zeros(10,10);
233+
j = [1,2,5,8]; a = rand(10);
234+
@test scattercopyavx!(H0, a, j) == scattercopy!(H1, a, j)
219235
end
220236
end

0 commit comments

Comments
 (0)