Skip to content

Commit 728a6b2

Browse files
committed
Fix mask bug when iteration doesn't start at 1, resolves #52.
1 parent c69c210 commit 728a6b2

File tree

5 files changed

+48
-13
lines changed

5 files changed

+48
-13
lines changed

src/add_loads.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,5 +77,5 @@ struct LoopValue end
7777
@inline SIMDPirates.vload(::LoopValue, i::Tuple{_MM{W}}, ::Unsigned) where {W} = SVec(SIMDPirates.vrangeincr(Val{W}(), @inbounds(i[1].i), Val{1}()))
7878
@inline VectorizationBase.load(::LoopValue, i::Integer) = i + one(i)
7979
@inline VectorizationBase.load(::LoopValue, i::Tuple{I}) where {I<:Integer} = @inbounds(i[1]) + one(I)
80-
@inline Base.eltype(::LoopValue) = Int8
80+
@inline Base.eltype(::LoopValue) = Int
8181

src/condense_loopset.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -380,7 +380,11 @@ function setup_call(ls::LoopSet, inline = Int8(2), U = zero(Int8), T = zero(Int8
380380
# Creating an anonymous function and calling it also achieves the outlining, while still
381381
# inlining the generated function into the loop preamble.
382382
if inline == Int8(2)
383-
setup_call_inline(ls, U, T)
383+
if num_loops(ls) == 1
384+
iszero(U) ? lower(ls) : lower(ls, U, -one(U))
385+
else
386+
setup_call_inline(ls, U, T)
387+
end
384388
else
385389
setup_call_noinline(ls, U, T)
386390
end

src/constructors.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,8 @@ macro avx(arg, q)
134134
@assert q.head === :for
135135
@assert arg.head === :(=)
136136
inline, U, T = check_macro_kwarg(arg)
137-
esc(setup_call(LoopSet(q, __module__), inline, U, T))
137+
ls = LoopSet(q, __module__)
138+
esc(setup_call(ls, inline, U, T))
138139
end
139140
macro avx(arg1, arg2, q)
140141
@assert q.head === :for

src/lowering.jl

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -226,13 +226,21 @@ function determine_eltype(ls::LoopSet)
226226
end
227227
promote_q
228228
end
229-
function determine_width(ls::LoopSet, typeT::Symbol, unrolled::Symbol)
230-
unrolledloop = getloop(ls, unrolled)
231-
if isstaticloop(unrolledloop)
232-
Expr(:call, lv(:pick_vector_width_val), Expr(:call, Expr(:curly, :Val, length(unrolledloop))), typeT)
233-
else
234-
Expr(:call, lv(:pick_vector_width_val), typeT)
229+
function determine_width(ls::LoopSet, vectorized::Symbol)
230+
vloop = getloop(ls, vectorized)
231+
vwidth_q = Expr(:call, lv(:pick_vector_width_val))
232+
if isstaticloop(vloop)
233+
push!(vwidth_q.args, Expr(:call, Expr(:curly, :Val, length(vloop))))
235234
end
235+
push!(vwidth_q.args, ls.T)
236+
# if length(ls.includedactualarrays) < 2
237+
# push!(vwidth_q.args, ls.T)
238+
# else
239+
# for array ∈ ls.includedactualarrays
240+
# push!(vwidth_q.args, Expr(:call, :eltype, array))
241+
# end
242+
# end
243+
vwidth_q
236244
end
237245
function lower_unrolled!(
238246
q::Expr, ls::LoopSet, vectorized::Symbol, U::Int, T::Int, W::Symbol, typeT::Symbol, unrolledloop::Loop
@@ -329,7 +337,7 @@ end
329337
function definemask(loop::Loop, W::Symbol, allon::Bool)
330338
if isstaticloop(loop)
331339
maskexpr(W, length(loop), allon)
332-
elseif loop.starthint == 0
340+
elseif loop.startexact && loop.starthint == 0
333341
maskexpr(W, loop.stopsym, allon)
334342
else
335343
lexpr = if loop.startexact
@@ -349,7 +357,7 @@ using SIMDPirates: sizeequivalentfloat, sizeequivalentint
349357
function setup_preamble!(ls::LoopSet, W::Symbol, typeT::Symbol, vectorized::Symbol, unrolled::Symbol, tiled::Symbol, U::Int)
350358
# println("Setup preamble")
351359
length(ls.includedarrays) == 0 || push!(ls.preamble.args, Expr(:(=), typeT, determine_eltype(ls)))
352-
push!(ls.preamble.args, Expr(:(=), W, determine_width(ls, typeT, unrolled)))
360+
push!(ls.preamble.args, Expr(:(=), W, determine_width(ls, vectorized)))
353361
lower_licm_constants!(ls)
354362
pushpreamble!(ls, definemask(getloop(ls, vectorized), W, U > 1 && unrolled === vectorized))
355363
for op ∈ operations(ls)

test/miscellaneous.jl

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -415,7 +415,21 @@
415415
end
416416
accu
417417
end
418-
418+
function test_for_with_different_index!(c, a, b, start_sample, num_samples)
419+
@inbounds for i = start_sample:num_samples + start_sample - 1
420+
c[i] = b[i] * a[i]
421+
end
422+
end
423+
function test_for_with_different_indexavx!(c, a, b, start_sample, num_samples)
424+
@avx for i = start_sample:num_samples + start_sample - 1
425+
c[i] = b[i] * a[i]
426+
end
427+
end
428+
function test_for_with_different_index_avx!(c, a, b, start_sample, num_samples)
429+
@_avx for i = start_sample:num_samples + start_sample - 1
430+
c[i] = b[i] * a[i]
431+
end
432+
end
419433

420434
for T ∈ (Float32, Float64)
421435
@show T, @__LINE__
@@ -467,7 +481,15 @@
467481
fill!(y2, NaN); clenshawavx!(y2,x,c)
468482
@test y1 ≈ y2
469483

470-
484+
C = randn(T, 199, 498);
485+
start_sample = 29; num_samples = 800;
486+
test_for_with_different_index!(B1, A, C, start_sample, num_samples)
487+
test_for_with_different_indexavx!(B2, A, C, start_sample, num_samples)
488+
r = start_sample:start_sample+num_samples - 1
489+
@test view(vec(B1), r) == view(vec(B2), r)
490+
fill!(B2, NaN); test_for_with_different_index_avx!(B2, A, C, start_sample, num_samples)
491+
@test view(vec(B1), r) == view(vec(B2), r)
492+
471493
ni, nj, nk = (127, 113, 13)
472494
x = rand(T, ni, nj, nk);
473495
q1 = similar(x);

0 commit comments

Comments
 (0)