Skip to content

Commit a2e40aa

Browse files
committed
Fix crash from #134.
1 parent 1238699 commit a2e40aa

File tree

2 files changed

+65
-26
lines changed

2 files changed

+65
-26
lines changed

src/lower_store.jl

Lines changed: 29 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -3,37 +3,43 @@ using VectorizationBase: vnoaliasstore!
33

44
@inline vstoreadditivereduce!(args...) = vnoaliasstore!(args...)
55
@inline vstoremultiplicativevereduce!(args...) = vnoaliasstore!(args...)
6-
@inline function vstoreadditivereduce!(ptr::VectorizationBase.AbstractStridedPointer, v::VectorizationBase.SVec, i::NTuple{N,<:Integer}) where {N}
6+
@inline function vstoreadditivereduce!(ptr::VectorizationBase.AbstractStridedPointer, v::VectorizationBase.SVec, i::Tuple{Vararg{Union{Integer,Static}}})
77
vnoaliasstore!(ptr, SIMDPirates.vsum(v), i)
88
end
9-
@inline function vstoreadditivereduce!(ptr::VectorizationBase.AbstractStridedPointer, v::VectorizationBase.SVec, i::NTuple{N,<:Integer}, m::VectorizationBase.Mask) where {N}
9+
@inline function vstoreadditivereduce!(ptr::VectorizationBase.AbstractStridedPointer, v::VectorizationBase.SVec, i::Tuple{Vararg{Union{Integer,Static}}}, m::VectorizationBase.Mask)
1010
vnoaliasstore!(ptr, SIMDPirates.vsum(v), i, m)
1111
end
12-
@inline function vstoremultiplicativevereduce!(ptr::VectorizationBase.AbstractStridedPointer, v::VectorizationBase.SVec, i::NTuple{N,<:Integer}) where {N}
12+
@inline function vstoremultiplicativevereduce!(ptr::VectorizationBase.AbstractStridedPointer, v::VectorizationBase.SVec, i::Tuple{Vararg{Union{Integer,Static}}})
1313
vnoaliasstore!(ptr, SIMDPirates.vprod(v), i)
1414
end
15-
@inline function vstoremultiplicativevereduce!(ptr::VectorizationBase.AbstractStridedPointer, v::VectorizationBase.SVec, i::NTuple{N,<:Integer}, m::VectorizationBase.Mask) where {N}
15+
@inline function vstoremultiplicativevereduce!(ptr::VectorizationBase.AbstractStridedPointer, v::VectorizationBase.SVec, i::Tuple{Vararg{Union{Integer,Static}}}, m::VectorizationBase.Mask)
1616
vnoaliasstore!(ptr, SIMDPirates.vprod(v), i, m)
1717
end
1818

19-
function storeinstr(op::Operation)
19+
function storeinstr(op::Operation, vectorized::Symbol)
20+
# defaultstoreop = :vstore!
21+
defaultstoreop = :vnoaliasstore!
22+
vectorized ∉ reduceddependencies(op) && return lv(defaultstoreop)
23+
vectorized ∈ loopdependencies(op) && return lv(defaultstoreop)
24+
# vectorized is not a loopdep, but is a reduced dep
2025
opp = first(parents(op))
21-
if instruction(opp).instr === :identity
22-
opp = first(parents(opp))
26+
while vectorized ∉ loopdependencies(opp)
27+
oppold = opp
28+
for oppp ∈ parents(opp)
29+
if vectorized ∈ reduceddependencies(oppp)
30+
@assert opp !== oppp "More than one parent is a reduction over the vectorized variable."
31+
opp = oppp
32+
end
33+
end
34+
@assert opp !== oppold "Failed to find any parents "
2335
end
24-
defaultstoreop = :vnoaliasstore!
25-
# defaultstoreop = :vstore!
26-
instr = if iszero(length(reduceddependencies(opp)))
36+
instr_class = reduction_instruction_class(instruction(opp))
37+
instr = if instr_class === ADDITIVE_IN_REDUCTIONS
38+
:vstoreadditivereduce!
39+
elseif instr_class === MULTIPLICATIVE_IN_REDUCTIONS
40+
:vstoremultiplicativevereduce!
41+
else #FIXME
2742
defaultstoreop
28-
else
29-
instr_class = reduction_instruction_class(instruction(opp))
30-
if instr_class === ADDITIVE_IN_REDUCTIONS
31-
:vstoreadditivereduce!
32-
elseif instr_class === MULTIPLICATIVE_IN_REDUCTIONS
33-
:vstoremultiplicativevereduce!
34-
else #FIXME
35-
defaultstoreop
36-
end
3743
end
3844
lv(instr)
3945
end
@@ -117,7 +123,7 @@ function lower_conditionalstore_scalar!(
117123
varname = varassignname(mvar, u, opu₁)
118124
condvarname = varassignname(condvar, u, condu₁)
119125
td = UnrollArgs(ua, u)
120-
push!(q.args, Expr(:&&, condvarname, Expr(:call, storeinstr(op), ptr, varname, mem_offset_u(op, td, inds_calc_by_ptr_offset))))
126+
push!(q.args, Expr(:&&, condvarname, Expr(:call, storeinstr(op, vectorized), ptr, varname, mem_offset_u(op, td, inds_calc_by_ptr_offset))))
121127
end
122128
nothing
123129
end
@@ -145,7 +151,7 @@ function lower_conditionalstore_vectorized!(
145151
td = UnrollArgs(ua, u)
146152
name, mo = name_memoffset(mvar, op, td, opu₁, inds_calc_by_ptr_offset)
147153
condvarname = varassignname(condvar, u, condu₁)
148-
instrcall = Expr(:call, storeinstr(op), ptr, name, mo)
154+
instrcall = Expr(:call, storeinstr(op, vectorized), ptr, name, mo)
149155
if mask !== nothing && (vecnotunrolled || u == U - 1)
150156
push!(instrcall.args, Expr(:call, :&, condvarname, mask))
151157
else
@@ -166,7 +172,7 @@ function lower_store_scalar!(
166172
for u ∈ 0:u₁-1
167173
varname = varassignname(mvar, u, opu₁)
168174
td = UnrollArgs(ua, u)
169-
push!(q.args, Expr(:call, storeinstr(op), ptr, varname, mem_offset_u(op, td, inds_calc_by_ptr_offset)))
175+
push!(q.args, Expr(:call, storeinstr(op, vectorized), ptr, varname, mem_offset_u(op, td, inds_calc_by_ptr_offset)))
170176
end
171177
nothing
172178
end
@@ -191,7 +197,7 @@ function lower_store_vectorized!(
191197
for u ∈ umin:U-1
192198
td = UnrollArgs(ua, u)
193199
name, mo = name_memoffset(mvar, op, td, opu₁, inds_calc_by_ptr_offset)
194-
instrcall = Expr(:call, storeinstr(op), ptr, name, mo)
200+
instrcall = Expr(:call, storeinstr(op, vectorized), ptr, name, mo)
195201
if mask !== nothing && (vecnotunrolled || u == U - 1)
196202
push!(instrcall.args, mask)
197203
end

test/miscellaneous.jl

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -692,6 +692,35 @@ function splitintonoloop_reference(U = randn(2,2), E1 = randn(2))
692692
end
693693
U, E1
694694
end
695+
function findreducedparentfornonvecstoreavx!(U::AbstractMatrix{T}, E1::AbstractVector{T}) where T
696+
n,k = size(U)
697+
_s = zero(T)
698+
a = 1.0
699+
@avx for j = 1:k
700+
for i = 1:n
701+
t = tanh(a * U[i,j])
702+
U[i,j] = t
703+
_s += a * (1 - t^2)
704+
end
705+
E1[j] = _s / n
706+
end
707+
U,E1
708+
end
709+
function findreducedparentfornonvecstore!(U::AbstractMatrix{T}, E1::AbstractVector{T}) where T
710+
n,k = size(U)
711+
_s = zero(T)
712+
a = 1.0
713+
for j = 1:k
714+
for i = 1:n
715+
t = tanh(a * U[i,j])
716+
U[i,j] = t
717+
_s += a * (1 - t^2)
718+
end
719+
E1[j] = _s / n
720+
end
721+
U,E1
722+
end
723+
695724

696725

697726

@@ -898,11 +927,15 @@ end
898927
R .+= randn.(T); Rc = copy(R);
899928
@test maxavx!(R, Q, true) == max.(vec(maximum(Q, dims=(2,3))), Rc)
900929

901-
U1 = randn(5,7); E1 = randn(7);
902-
U2, E2 = splitintonoloop_reference(copy(U1), copy(E1));
903-
splitintonoloop(U1, E1);
930+
U0 = randn(5,7); E0 = randn(7);
931+
U1, E1 = splitintonoloop_reference(copy(U0), copy(E0));
932+
U2, E2 = splitintonoloop(copy(U0), copy(E0));
904933
@test U1 ≈ U2
905934
@test E1 ≈ E2
935+
U3, E3 = findreducedparentfornonvecstoreavx!(copy(U0), copy(E0));
936+
findreducedparentfornonvecstore!(U0, E0);
937+
@test U0 ≈ U3
938+
@test E0 ≈ E3
906939
end
907940
for T ∈ [Int16, Int32, Int64]
908941
n = 8sizeof(T) - 1

0 commit comments

Comments
 (0)