Skip to content

Commit 879527a

Browse files
committed
To support CartesianIndexing, we must punt to the generated function even when there is only a single loop (because that single loop could be over a CartesianIndex).
1 parent f4b0da9 commit 879527a

File tree

4 files changed

+21
-15
lines changed

4 files changed

+21
-15
lines changed

src/LoopVectorization.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ using VectorizationBase: REGISTER_SIZE, REGISTER_COUNT, extract_data, num_vector
55
mask, masktable, pick_vector_width_val, valmul, valrem, valmuladd, valadd, valsub, _MM,
66
maybestaticlength, maybestaticsize, staticm1, subsetview, vzero, stridedpointer_for_broadcast,
77
Static, StaticUnitRange, StaticLowerUnitRange, StaticUpperUnitRange, unwrap, maybestaticrange,
8+
AbstractColumnMajorStridedPointer, AbstractRowMajorStridedPointer, AbstractSparseStridedPointer, AbstractStaticStridedPointer,
89
PackedStridedPointer, SparseStridedPointer, RowMajorStridedPointer, StaticStridedPointer, StaticStridedStruct,
910
maybestaticfirst, maybestaticlast
1011
using SIMDPirates: VECTOR_SYMBOLS, evadd, evsub, evmul, evfdiv, vrange, reduced_add, reduced_prod, reduce_to_add, reduce_to_prod,

src/condense_loopset.jl

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -392,11 +392,7 @@ function setup_call(ls::LoopSet, inline = Int8(2), U = zero(Int8), T = zero(Int8
392392
# Creating an anonymous function and calling it also achieves the outlining, while still
393393
# inlining the generated function into the loop preamble.
394394
if inline == Int8(2)
395-
if num_loops(ls) == 1
396-
iszero(U) ? lower(ls) : lower(ls, U, -one(U))
397-
else
398-
setup_call_inline(ls, U, T)
399-
end
395+
setup_call_inline(ls, U, T)
400396
else
401397
setup_call_noinline(ls, U, T)
402398
end

src/reconstruct_loopset.jl

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -105,27 +105,35 @@ function pushvarg′!(ls::LoopSet, ar::ArrayReferenceMeta, i)
105105
reverse!(ar.loopedindex); reverse!(getindices(ar)) # reverse the listed indices here, and transpose it to make it column major
106106
pushpreamble!(ls, Expr(:(=), vptr(ar), Expr(:call, lv(:transpose), extract_varg(i))))
107107
end
108-
function add_mref!(ls::LoopSet, ar::ArrayReferenceMeta, i::Int, ::Type{PackedStridedPointer{T, N}}) where {T, N}
108+
function add_mref!(
109+
ls::LoopSet, ar::ArrayReferenceMeta, i::Int, ::Type{S}
110+
) where {T, N, S <: AbstractColumnMajorStridedPointer{T,N}}
109111
pushvarg!(ls, ar, i)
110112
end
111-
function add_mref!(ls::LoopSet, ar::ArrayReferenceMeta, i::Int, ::Type{RowMajorStridedPointer{T, N}}) where {T, N}
113+
function add_mref!(
114+
ls::LoopSet, ar::ArrayReferenceMeta, i::Int, ::Type{S}
115+
) where {T, N, S <: AbstractRowMajorStridedPointer{T, N}}
112116
pushvarg′!(ls, ar, i)
113117
end
114-
function add_mref!(ls::LoopSet, ar::ArrayReferenceMeta, i::Int, ::Type{OffsetStridedPointer{T,N,P}}) where {T,N,P}
118+
function add_mref!(
119+
ls::LoopSet, ar::ArrayReferenceMeta, i::Int, ::Type{OffsetStridedPointer{T,N,P}}
120+
) where {T,N,P}
115121
add_mref!(ls, ar, i, P)
116122
end
117123

118124
function add_mref!(
119125
ls::LoopSet, ar::ArrayReferenceMeta, i::Int, ::Type{S}
120-
) where {T, X <: Tuple, S <: VectorizationBase.AbstractStaticStridedPointer{T,X}}
126+
) where {T, X <: Tuple, S <: AbstractStaticStridedPointer{T,X}}
121127
if last(X.parameters)::Int == 1
122128
pushvarg′!(ls, ar, i)
123129
else
124130
pushvarg!(ls, ar, i)
125131
first(X.parameters)::Int == 1 || pushfirst!(getindices(ar), Symbol("##DISCONTIGUOUSSUBARRAY##"))
126132
end
127133
end
128-
function add_mref!(ls::LoopSet, ar::ArrayReferenceMeta, i::Int, ::Type{SparseStridedPointer{T, N}}) where {T, N}
134+
function add_mref!(
135+
ls::LoopSet, ar::ArrayReferenceMeta, i::Int, ::Type{S}
136+
) where {T, N, S <: AbstractSparseStridedPointer{T, N}}
129137
pushvarg!(ls, ar, i)
130138
pushfirst!(getindices(ar), Symbol("##DISCONTIGUOUSSUBARRAY##"))
131139
end
@@ -222,6 +230,7 @@ function calcnops(ls::LoopSet, os::OperationStruct)
222230
end
223231
offsets = ls.loopsymbol_offsets
224232
idxs = loopindex(ls, os.loopdeps, 0x04) # FIXME DRY
233+
iszero(length(idxs)) && return 1
225234
Δidxs = map(i->offsets[i+1]-offsets[i], idxs)
226235
nops = first(Δidxs)
227236
@assert all(isequal(nops), Δidxs)

test/ifelsemasks.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -339,15 +339,15 @@ T = Float32
339339
fill!(c2, -999999999); maybewriteoravx!(c2, a, b)
340340
@test c1 ≈ c2
341341

342-
andorassignment!(c1, a, b)
343-
andorassignmentavx!(c2, a, b)
342+
andorassignment!(c1, a, b);
343+
andorassignmentavx!(c2, a, b);
344344
@test c1 ≈ c2
345-
fill!(c2, -999999999); andorassignment_avx!(c2, a, b)
345+
fill!(c2, -999999999); andorassignment_avx!(c2, a, b);
346346
@test c1 ≈ c2
347347

348348
if T <: Union{Float32,Float64}
349349
a .*= 100;
350-
end
350+
end;
351351
b1 = copy(a);
352352
b2 = copy(a);
353353
condstore!(b1)
@@ -384,7 +384,7 @@ T = Float32
384384
t = Bernoulli_logit(bit, a);
385385
@test t ≈ Bernoulli_logitavx(bit, a)
386386
@test t ≈ Bernoulli_logit_avx(bit, a)
387-
a = rand(43)
387+
a = rand(43);
388388
bit = a .> 0.5;
389389
t = Bernoulli_logit(bit, a);
390390
@test t ≈ Bernoulli_logitavx(bit, a)

0 commit comments

Comments
 (0)