@@ -102,8 +102,9 @@ function add_prefetches!(q::Expr, ls::LoopSet, op::Operation, td::UnrollArgs, pr
102
102
end
103
103
nothing
104
104
end
105
+ broadcastedname (mvar) = Symbol (mvar, " ##broadcasted##" )
105
106
function pushbroadcast! (q:: Expr , mvar:: Symbol )
106
- push! (q. args, Expr (:(= ), Symbol (mvar, " ##broadcasted## " ), Expr (:call , lv (:vbroadcast ), VECTORWIDTHSYMBOL, mvar)))
107
+ push! (q. args, Expr (:(= ), broadcastedname (mvar), Expr (:call , lv (:vbroadcast ), VECTORWIDTHSYMBOL, mvar)))
107
108
end
108
109
function lower_load_no_optranslation! (
109
110
q:: Expr , ls:: LoopSet , op:: Operation , td:: UnrollArgs , mask:: Bool , inds_calc_by_ptr_offset:: Vector{Bool}
@@ -119,7 +120,7 @@ function lower_load_no_optranslation!(
119
120
120
121
if all (op. ref. loopedindex)
121
122
inds = unrolledindex (op, td, mask, inds_calc_by_ptr_offset)
122
- loadexpr = Expr (:call , lv (:vload ), vptr (op), inds)
123
+ loadexpr = Expr (:call , lv (:_vload ), vptr (op), inds)
123
124
add_memory_mask! (loadexpr, op, td, mask)
124
125
push! (loadexpr. args, falseexpr, rs) # unaligned load
125
126
push! (q. args, Expr (:(= ), mvar, loadexpr))
@@ -128,7 +129,7 @@ function lower_load_no_optranslation!(
128
129
# for u ∈ 1:u₁
129
130
let t = u₁, t = q
130
131
inds = mem_offset_u (op, td, inds_calc_by_ptr_offset, true , u- 1 )
131
- loadexpr = Expr (:call , lv (:vload ), vptr (op), inds)
132
+ loadexpr = Expr (:call , lv (:_vload ), vptr (op), inds)
132
133
add_memory_mask! (loadexpr, op, td, mask & ((u == u₁) | isvectorized (op)))
133
134
push! (loadexpr. args, falseexpr, rs)
134
135
# push!(t.args, loadexpr)
@@ -137,7 +138,7 @@ function lower_load_no_optranslation!(
137
138
# push!(q.args, Expr(:(=), mvar, Expr(:call, lv(:VecUnroll), t)))
138
139
else
139
140
inds = mem_offset_u (op, td, inds_calc_by_ptr_offset, true , 0 )
140
- loadexpr = Expr (:call , lv (:vload ), vptr (op), inds)
141
+ loadexpr = Expr (:call , lv (:_vload ), vptr (op), inds)
141
142
add_memory_mask! (loadexpr, op, td, mask)
142
143
push! (loadexpr. args, falseexpr, rs)
143
144
push! (q. args, Expr (:(= ), mvar, loadexpr))
@@ -183,82 +184,95 @@ function indisvectorized(ls::LoopSet, ind::Symbol)
183
184
end
184
185
@inline firstunroll (vu:: VecUnroll ) = getfield (getfield (vu,:data ),1 ,false )
185
186
@inline firstunroll (x) = x
187
+ @inline lastunroll (vu:: VecUnroll ) = last (getfield (vu,:data ))
188
+ @inline lastunroll (x) = x
186
189
function lower_load_for_optranslation! (
187
190
q:: Expr , op:: Operation , ls:: LoopSet , td:: UnrollArgs , mask:: Bool , translationind:: Int
188
191
)
189
- @unpack u₁, u₁loopsym, u₂loopsym, vloopsym, u₂max, suffix = td
192
+ @unpack u₁loop, u₂loop, vloop, u₁, u₂max, suffix = td
193
+
194
+ # @unpack u₁, u₁loopsym, u₂loopsym, vloopsym, u₂max, suffix = td
190
195
iszero (suffix) || return
191
196
197
+ total_unroll = u₁ + u₂max - 1
198
+
199
+
200
+ mref = op. ref
201
+ inds_by_ptroff = indices_calculated_by_pointer_offsets (ls, mref)
192
202
# initial offset pointer
193
- gespinds = mem_offset (op, UnrollArgs (td, u₁), indices_calculated_by_pointer_offsets (ls, op. ref), false )
203
+
204
+ # Unroll directions can be + or -
205
+ # we want to start at minimum position.
206
+ step₁ = gethint (step (u₁loop))
207
+ step₂ = gethint (step (u₂loop))
208
+
209
+
210
+ # abs of steps are equal
211
+ #
212
+ equal_steps = step₁ == step₂
213
+ _td = UnrollArgs (u₁loop, u₂loop, vloop, total_unroll, u₂max, Core. ifelse (equal_steps, 0 , u₂max - 1 ))
214
+ gespinds = mem_offset (op, _td, inds_by_ptroff, false )
194
215
ptr = vptr (op)
195
216
gptr = Symbol (ptr, " ##GESPED##" )
196
217
for i ∈ eachindex (gespinds. args)
197
218
if i == translationind
198
- gespinds. args[i] = Expr (:call , lv (:firstunroll ), gespinds. args[i])
219
+ gespinds. args[i] = Expr (:call , lv (Core . ifelse (equal_steps, :firstunroll , :lastunroll ) ), gespinds. args[i])
199
220
else
200
221
gespinds. args[i] = Expr (:call , lv (:data ), gespinds. args[i])
201
222
end
202
223
end
203
224
push! (q. args, Expr (:(= ), gptr, Expr (:call , lv (:gesp ), ptr, gespinds)))
204
225
205
- shouldbroadcast = ( ! isvectorized (op)) && any (isvectorized, children (op) )
226
+ fill! (inds_by_ptroff, true )
206
227
207
- inds = Expr (:tuple )
208
- indices = getindicesonly (op)
209
- for (i,ind) ∈ enumerate (indices)
210
- if i == translationind # vectorized ind cannot be the translation ind
211
- push! (inds. args, Expr (:call , Expr (:curly , lv (:Static ), 0 )))
212
- elseif (ind === vloopsym) || indisvectorized (ls, ind)
213
- push! (inds. args, _MMind (Expr (:call , lv (:Zero ))))
214
- else
215
- push! (inds. args, Expr (:call , lv (:Zero )))
216
- end
217
- end
218
- variable_name0 = variable_name (op, 0 )
219
- varbase = Symbol (" ##var##" , variable_name0)
220
- loadcall = Expr (:call , lv (:vload ), gptr, copy (inds))
221
- falseexpr = Expr (:call , lv (:False )); rs = staticexpr (reg_size (ls));
222
- mask && push! (loadcall. args, MASKSYMBOL)
223
- push! (loadcall. args, falseexpr, rs)
228
+ @unpack ref, loopedindex = mref
229
+ indices = getindicesonly (ref)
230
+ old_translation_index = indices[translationind]
231
+ indices[translationind] = u₁loop. itersymbol
232
+ # getindicesonly returns a view of `getindices`
233
+ dummyref = ArrayReference (ref. array, getindices (ref), zero (getoffsets (ref)), getstrides (ref))
234
+ loopedindex[translationind] = true
235
+ dummymref = ArrayReferenceMeta (dummyref, loopedindex, gptr)
224
236
225
- varbase0 = Symbol (varbase, 0 )
226
- t = Expr (:tuple , varbase0)
227
- push! (q. args, Expr (:(= ), varbase0, loadcall))
228
- for u ∈ 1 : u₁- 1
229
- inds. args[translationind] = Expr (:call , Expr (:curly , lv (:Static ), u))
230
- loadcall = Expr (:call , lv (:vload ), gptr, copy (inds))
231
- mask && push! (loadcall. args, MASKSYMBOL)
232
- push! (loadcall. args, falseexpr, rs)
233
- varbaseu = Symbol (varbase, u)
234
- push! (q. args, Expr (:(= ), varbaseu, loadcall))
235
- push! (t. args, varbaseu)
237
+ _td = UnrollArgs (u₁loop, u₂loop, vloop, total_unroll, u₂max, - 1 )
238
+ op. ref = dummymref
239
+ _lower_load! (q, ls, op, _td, mask)
240
+ # set old values
241
+ op. ref = mref
242
+ loopedindex[translationind] = false
243
+ indices[translationind] = old_translation_index
244
+
245
+ shouldbroadcast = (! isvectorized (op)) && any (isvectorized, children (op))
246
+
247
+ # now we need to assign the `Vec`s from the `VecUnroll` to the correct name.
248
+ variable_name_u = Symbol (variable_name (op, - 1 ), ' _' , total_unroll)
249
+ variable_name_data = Symbol (variable_name_u, " ##data##" )
250
+ push! (q. args, :($ variable_name_data = getfield ($ variable_name_u, 1 )))
251
+ if shouldbroadcast
252
+ broadcasted_data = broadcastedname (variable_name_data)
253
+ push! (q. args, :($ broadcasted_data = getfield ($ (broadcastedname (variable_name_u)), 1 )))
236
254
end
237
- vecunroll_name = Symbol (variable_name0, ' _' , u₁)
238
- push! (q. args, Expr (:(= ), vecunroll_name, Expr (:call , lv (:VecUnroll ), t)))
239
- shouldbroadcast && pushbroadcast! (q, vecunroll_name)
240
- # this takes care of u₂ == 0
241
- offset = u₁
242
- for u₂ ∈ 1 : u₂max- 1
255
+ for u₂ ∈ 0 : u₂max- 1
256
+ variable_name_u₂ = Symbol (variable_name (op, u₂), ' _' , u₁)
243
257
t = Expr (:tuple )
244
- varold = varbase
245
- varbase = variable_name (op, u₂)
246
- for u ∈ 0 : u₁- 2
247
- varbaseu = Symbol (varbase, u)
248
- push! (q. args, Expr (:(= ), varbaseu, Symbol (varold, u + 1 )))
249
- push! (t. args, varbaseu)
258
+ if shouldbroadcast
259
+ tb = Expr (:tuple )
260
+ end
261
+ for u ∈ 1 : u₁
262
+ uu = if equal_steps
263
+ u + u₂
264
+ else
265
+ u - u₂ + u₂max - 1
266
+ end
267
+ push! (t. args, :(getfield ($ variable_name_data, $ uu)))
268
+ if shouldbroadcast
269
+ push! (tb. args, :(getfield ($ broadcasted_data, $ uu)))
270
+ end
271
+ end
272
+ push! (q. args, :($ variable_name_u₂ = VecUnroll ($ t)))
273
+ if shouldbroadcast
274
+ push! (q. args, :($ (broadcastedname (variable_name_u₂)) = VecUnroll ($ tb)))
250
275
end
251
- inds. args[translationind] = Expr (:call , Expr (:curly , lv (:Static ), offset))
252
- loadcall = Expr (:call , lv (:vload ), gptr, copy (inds))
253
- mask && push! (loadcall. args, MASKSYMBOL)
254
- push! (loadcall. args, falseexpr, rs)
255
- varload = Symbol (varbase, u₁ - 1 )
256
- push! (q. args, Expr (:(= ), varload, loadcall))
257
- push! (t. args, varload)
258
- offset += 1
259
- vecunroll_name = Symbol (variable_name (op, u₂), ' _' , u₁)
260
- push! (q. args, Expr (:(= ), vecunroll_name, Expr (:call , lv (:VecUnroll ), t)))
261
- shouldbroadcast && pushbroadcast! (q, vecunroll_name)
262
276
end
263
277
nothing
264
278
end
@@ -268,40 +282,34 @@ end
268
282
function lower_load! (
269
283
q:: Expr , op:: Operation , ls:: LoopSet , td:: UnrollArgs , mask:: Bool
270
284
)
271
- @unpack u₁, u₁loopsym, u₂loopsym, vloopsym, suffix = td
285
+ @unpack u₁, u₂max, u ₁loopsym, u₂loopsym, vloopsym, suffix = td
272
286
if (suffix != - 1 ) && ls. loadelimination[]
273
- istr, ispl = isoptranslation (ls, op, UnrollSymbols (u₁loopsym, u₂loopsym, vloopsym))
287
+ if (u₁ > 1 ) & (u₂max > 1 )
288
+ istr, ispl = isoptranslation (ls, op, UnrollSymbols (u₁loopsym, u₂loopsym, vloopsym))
289
+ else
290
+ istr, ispl = 0 , false
291
+ end
274
292
if ! iszero (istr) & ispl
275
293
return lower_load_for_optranslation! (q, op, ls, td, mask, istr)
276
- elseif suffix > 0
277
- if u₂loopsym != = vloopsym
278
- mno, id = maxnegativeoffset (ls, op, u₂loopsym)
279
- if - suffix < mno < 0 # already checked that `suffix != -1` above
280
- varnew = variable_name (op, suffix)
281
- varold = variable_name (operations (ls)[id], suffix + mno)
282
- opold = operations (ls)[id]
283
- u = isu₁unrolled (op) ? u₁ : 1
284
- push! (q. args, Expr (:(= ), Symbol (varnew, ' _' , u), Symbol (varold, ' _' , u)))
285
- # if isu₁unrolled(op)
286
- # for u ∈ 0:u₁-1
287
- # push!(q.args, Expr(:(=), Symbol(varnew, u), Symbol(varold, u)))
288
- # end
289
- # else
290
-
291
- # end
292
- return
293
- end
294
+ elseif (suffix > 0 ) && (u₂loopsym != = vloopsym)
295
+ mno, id = maxnegativeoffset (ls, op, u₂loopsym)
296
+ if - suffix < mno < 0 # already checked that `suffix != -1` above
297
+ varnew = variable_name (op, suffix)
298
+ varold = variable_name (operations (ls)[id], suffix + mno)
299
+ opold = operations (ls)[id]
300
+ u = isu₁unrolled (op) ? u₁ : 1
301
+ push! (q. args, Expr (:(= ), Symbol (varnew, ' _' , u), Symbol (varold, ' _' , u)))
302
+ return
294
303
end
295
304
end
296
305
end
297
306
_lower_load! (q, ls, op, td, mask)
298
307
end
299
308
function _lower_load! (
300
- q:: Expr , ls:: LoopSet , op:: Operation , td:: UnrollArgs , mask:: Bool
309
+ q:: Expr , ls:: LoopSet , op:: Operation , td:: UnrollArgs , mask:: Bool , inds_calc_by_ptr_offset :: Vector{Bool} = indices_calculated_by_pointer_offsets (ls, op . ref)
301
310
)
302
311
omop = offsetloadcollection (ls)
303
312
batchid, opind = omop. batchedcollectionmap[identifier (op)]
304
- inds_calc_by_ptr_offset = indices_calculated_by_pointer_offsets (ls, op. ref)
305
313
# @show batchid == 0 (!isvectorized(op)) rejectinterleave(op, td.vloop, idsformap)
306
314
if batchid == 0 || (! isvectorized (op)) || (rejectinterleave (op, td. vloop, omop. batchedcollections[batchid]))
307
315
lower_load_no_optranslation! (q, ls, op, td, mask, inds_calc_by_ptr_offset)
@@ -349,7 +357,7 @@ function lower_load_collection!(
349
357
end
350
358
uinds = Expr (:call , unrollcurl₂, inds)
351
359
vp = vptr (op)
352
- loadexpr = Expr (:call , lv (:vload ), vp, uinds)
360
+ loadexpr = Expr (:call , lv (:_vload ), vp, uinds)
353
361
# not using `add_memory_mask!(storeexpr, op, ua, mask)` because we checked `isconditionalmemop` earlier in `lower_load_collection!`
354
362
(mask && isvectorized (op)) && push! (loadexpr. args, MASKSYMBOL)
355
363
push! (loadexpr. args, falseexpr, rs)
0 commit comments