@@ -142,6 +142,7 @@ function lower_store_collection!(
142
142
end
143
143
nothing
144
144
end
145
+ gf (s:: Symbol , n:: Int ) = Expr (:call , GlobalRef (Core,:getfield ), s, n, false )
145
146
function lower_store! (
146
147
q:: Expr , ls:: LoopSet , op:: Operation , ua:: UnrollArgs , mask:: Bool ,
147
148
reductfunc:: Symbol = storeinstr_preprend (op, ua. vloop. itersymbol), inds_calc_by_ptr_offset = indices_calculated_by_pointer_offsets (ls, op. ref)
@@ -154,7 +155,6 @@ function lower_store!(
154
155
(opind == 1 ) && lower_store_collection! (q, ls, op, ua, mask, inds_calc_by_ptr_offset)
155
156
return
156
157
end
157
-
158
158
falseexpr = Expr (:call , lv (:False ));
159
159
aliasexpr = falseexpr;
160
160
# trueexpr = Expr(:call, lv(:True));
@@ -179,40 +179,48 @@ function lower_store!(
179
179
add_memory_mask! (storeexpr, op, ua, mask, ls)
180
180
push! (storeexpr. args, falseexpr, aliasexpr, falseexpr, rs)
181
181
push! (q. args, storeexpr)
182
- elseif (u₁ > 1 ) & isu₁
182
+ else
183
+ parents_op = parents (op)
184
+ data_u₁ = isu₁ & (u₁ > 1 )
185
+
186
+ indices_u₁ = data_u₁
187
+ if ! data_u₁ & (length (parents_op) > 1 )
188
+ indices_u₁ = first (isunrolled_sym (op, u₁loopsym, u₂loopsym, vloopsym, ls))
189
+ end
190
+ if indices_u₁
183
191
mvard = Symbol (mvar, " ##data##" )
184
192
# isu₁ &&
185
- push! (q. args, Expr (:(= ), mvard, Expr (:call , lv (:data ), mvar)))
193
+ data_u₁ && push! (q. args, Expr (:(= ), mvard, Expr (:call , lv (:data ), mvar)))
186
194
for u ∈ 1 : u₁
187
- mvaru = :(getfield ($ mvard, $ u, false ))
188
- inds = mem_offset_u (op, ua, inds_calc_by_ptr_offset, true , u- 1 , ls)
189
- # @show isu₁unrolled(opp), opp
190
- storeexpr = if isu₁
191
- if reductfunc === Symbol (" " )
192
- Expr (:call , lv (:_vstore! ), vptr (op), mvaru, inds)
193
- else
194
- Expr (:call , lv (:_vstore! ), lv (reductfunc), vptr (op), mvaru, inds)
195
- end
196
- elseif reductfunc === Symbol (" " )
197
- Expr (:call , lv (:_vstore! ), vptr (op), mvar, inds)
195
+ inds = mem_offset_u (op, ua, inds_calc_by_ptr_offset, true , u- 1 , ls)
196
+ # @show isu₁unrolled(opp), opp
197
+ storeexpr = if data_u₁
198
+ if reductfunc === Symbol (" " )
199
+ Expr (:call , lv (:_vstore! ), vptr (op), gf (mvard,u), inds)
198
200
else
199
- Expr (:call , lv (:_vstore! ), lv (reductfunc), vptr (op), mvar , inds)
201
+ Expr (:call , lv (:_vstore! ), lv (reductfunc), vptr (op), mvaru , inds)
200
202
end
201
- domask = mask && (isvectorized (op) & ((u == u₁) | (vloopsym != = u₁loopsym)))
202
- add_memory_mask! (storeexpr, op, ua, domask, ls)# & ((u == u₁) | isvectorized(op)))
203
- push! (storeexpr. args, falseexpr, aliasexpr, falseexpr, rs)
204
- push! (q. args, storeexpr)
203
+ elseif reductfunc === Symbol (" " )
204
+ Expr (:call , lv (:_vstore! ), vptr (op), mvar, inds)
205
+ else
206
+ Expr (:call , lv (:_vstore! ), lv (reductfunc), vptr (op), mvar, inds)
207
+ end
208
+ domask = mask && (isvectorized (op) & ((u == u₁) | (vloopsym != = u₁loopsym)))
209
+ add_memory_mask! (storeexpr, op, ua, domask, ls)# & ((u == u₁) | isvectorized(op)))
210
+ push! (storeexpr. args, falseexpr, aliasexpr, falseexpr, rs)
211
+ push! (q. args, storeexpr)
205
212
end
206
- else
213
+ else
207
214
inds = mem_offset_u (op, ua, inds_calc_by_ptr_offset, true , 0 , ls)
208
215
storeexpr = if reductfunc === Symbol (" " )
209
- Expr (:call , lv (:_vstore! ), vptr (op), mvar, inds)
216
+ Expr (:call , lv (:_vstore! ), vptr (op), mvar, inds)
210
217
else
211
- Expr (:call , lv (:_vstore! ), lv (reductfunc), vptr (op), mvar, inds)
218
+ Expr (:call , lv (:_vstore! ), lv (reductfunc), vptr (op), mvar, inds)
212
219
end
213
220
add_memory_mask! (storeexpr, op, ua, mask, ls)
214
221
push! (storeexpr. args, falseexpr, aliasexpr, falseexpr, rs)
215
222
push! (q. args, storeexpr)
223
+ end
216
224
end
217
225
nothing
218
226
end
0 commit comments