@@ -3,37 +3,43 @@ using VectorizationBase: vnoaliasstore!
3
3
4
4
@inline vstoreadditivereduce! (args... ) = vnoaliasstore! (args... )
5
5
@inline vstoremultiplicativevereduce! (args... ) = vnoaliasstore! (args... )
6
- @inline function vstoreadditivereduce! (ptr:: VectorizationBase.AbstractStridedPointer , v:: VectorizationBase.SVec , i:: NTuple{N,<: Integer} ) where {N}
6
+ @inline function vstoreadditivereduce! (ptr:: VectorizationBase.AbstractStridedPointer , v:: VectorizationBase.SVec , i:: Tuple{Vararg{Union{ Integer,Static}}} )
7
7
vnoaliasstore! (ptr, SIMDPirates. vsum (v), i)
8
8
end
9
- @inline function vstoreadditivereduce! (ptr:: VectorizationBase.AbstractStridedPointer , v:: VectorizationBase.SVec , i:: NTuple{N,<: Integer} , m:: VectorizationBase.Mask ) where {N}
9
+ @inline function vstoreadditivereduce! (ptr:: VectorizationBase.AbstractStridedPointer , v:: VectorizationBase.SVec , i:: Tuple{Vararg{Union{ Integer,Static}}} , m:: VectorizationBase.Mask )
10
10
vnoaliasstore! (ptr, SIMDPirates. vsum (v), i, m)
11
11
end
12
- @inline function vstoremultiplicativevereduce! (ptr:: VectorizationBase.AbstractStridedPointer , v:: VectorizationBase.SVec , i:: NTuple{N,<: Integer} ) where {N}
12
+ @inline function vstoremultiplicativevereduce! (ptr:: VectorizationBase.AbstractStridedPointer , v:: VectorizationBase.SVec , i:: Tuple{Vararg{Union{ Integer,Static}}} )
13
13
vnoaliasstore! (ptr, SIMDPirates. vprod (v), i)
14
14
end
15
- @inline function vstoremultiplicativevereduce! (ptr:: VectorizationBase.AbstractStridedPointer , v:: VectorizationBase.SVec , i:: NTuple{N,<: Integer} , m:: VectorizationBase.Mask ) where {N}
15
+ @inline function vstoremultiplicativevereduce! (ptr:: VectorizationBase.AbstractStridedPointer , v:: VectorizationBase.SVec , i:: Tuple{Vararg{Union{ Integer,Static}}} , m:: VectorizationBase.Mask )
16
16
vnoaliasstore! (ptr, SIMDPirates. vprod (v), i, m)
17
17
end
18
18
19
- function storeinstr (op:: Operation )
19
+ function storeinstr (op:: Operation , vectorized:: Symbol )
20
+ # defaultstoreop = :vstore!
21
+ defaultstoreop = :vnoaliasstore!
22
+ vectorized ∉ reduceddependencies (op) && return lv (defaultstoreop)
23
+ vectorized ∈ loopdependencies (op) && return lv (defaultstoreop)
24
+ # vectorized is not a loopdep, but is a reduced dep
20
25
opp = first (parents (op))
21
- if instruction (opp). instr === :identity
22
- opp = first (parents (opp))
26
+ while vectorized ∉ loopdependencies (opp)
27
+ oppold = opp
28
+ for oppp ∈ parents (opp)
29
+ if vectorized ∈ reduceddependencies (oppp)
30
+ @assert opp != = oppp " More than one parent is a reduction over the vectorized variable."
31
+ opp = oppp
32
+ end
33
+ end
34
+ @assert opp != = oppold " Failed to find any parents "
23
35
end
24
- defaultstoreop = :vnoaliasstore!
25
- # defaultstoreop = :vstore!
26
- instr = if iszero (length (reduceddependencies (opp)))
36
+ instr_class = reduction_instruction_class (instruction (opp))
37
+ instr = if instr_class === ADDITIVE_IN_REDUCTIONS
38
+ :vstoreadditivereduce!
39
+ elseif instr_class === MULTIPLICATIVE_IN_REDUCTIONS
40
+ :vstoremultiplicativevereduce!
41
+ else # FIXME
27
42
defaultstoreop
28
- else
29
- instr_class = reduction_instruction_class (instruction (opp))
30
- if instr_class === ADDITIVE_IN_REDUCTIONS
31
- :vstoreadditivereduce!
32
- elseif instr_class === MULTIPLICATIVE_IN_REDUCTIONS
33
- :vstoremultiplicativevereduce!
34
- else # FIXME
35
- defaultstoreop
36
- end
37
43
end
38
44
lv (instr)
39
45
end
@@ -117,7 +123,7 @@ function lower_conditionalstore_scalar!(
117
123
varname = varassignname (mvar, u, opu₁)
118
124
condvarname = varassignname (condvar, u, condu₁)
119
125
td = UnrollArgs (ua, u)
120
- push! (q. args, Expr (:&& , condvarname, Expr (:call , storeinstr (op), ptr, varname, mem_offset_u (op, td, inds_calc_by_ptr_offset))))
126
+ push! (q. args, Expr (:&& , condvarname, Expr (:call , storeinstr (op, vectorized ), ptr, varname, mem_offset_u (op, td, inds_calc_by_ptr_offset))))
121
127
end
122
128
nothing
123
129
end
@@ -145,7 +151,7 @@ function lower_conditionalstore_vectorized!(
145
151
td = UnrollArgs (ua, u)
146
152
name, mo = name_memoffset (mvar, op, td, opu₁, inds_calc_by_ptr_offset)
147
153
condvarname = varassignname (condvar, u, condu₁)
148
- instrcall = Expr (:call , storeinstr (op), ptr, name, mo)
154
+ instrcall = Expr (:call , storeinstr (op, vectorized ), ptr, name, mo)
149
155
if mask != = nothing && (vecnotunrolled || u == U - 1 )
150
156
push! (instrcall. args, Expr (:call , :& , condvarname, mask))
151
157
else
@@ -166,7 +172,7 @@ function lower_store_scalar!(
166
172
for u ∈ 0 : u₁- 1
167
173
varname = varassignname (mvar, u, opu₁)
168
174
td = UnrollArgs (ua, u)
169
- push! (q. args, Expr (:call , storeinstr (op), ptr, varname, mem_offset_u (op, td, inds_calc_by_ptr_offset)))
175
+ push! (q. args, Expr (:call , storeinstr (op, vectorized ), ptr, varname, mem_offset_u (op, td, inds_calc_by_ptr_offset)))
170
176
end
171
177
nothing
172
178
end
@@ -191,7 +197,7 @@ function lower_store_vectorized!(
191
197
for u ∈ umin: U- 1
192
198
td = UnrollArgs (ua, u)
193
199
name, mo = name_memoffset (mvar, op, td, opu₁, inds_calc_by_ptr_offset)
194
- instrcall = Expr (:call , storeinstr (op), ptr, name, mo)
200
+ instrcall = Expr (:call , storeinstr (op, vectorized ), ptr, name, mo)
195
201
if mask != = nothing && (vecnotunrolled || u == U - 1 )
196
202
push! (instrcall. args, mask)
197
203
end
0 commit comments