@@ -283,112 +283,113 @@ function unroll_no_reductions(ls, order, vloopsym)
283
283
# (iszero(rt) ? 4 : max(1, VectorizationBase.nextpow2( min( 4, round(Int, 8 / rt) ) ))), unrolled
284
284
end
285
285
function determine_unroll_factor (
286
- ls:: LoopSet , order:: Vector{Symbol} , unrolled:: Symbol , vloopsym:: Symbol
286
+ ls:: LoopSet , order:: Vector{Symbol} , unrolled:: Symbol , vloopsym:: Symbol
287
287
)
288
- cacheunrolled! (ls, unrolled, Symbol (" " ), vloopsym)
289
- size_T = biggest_type_size (ls)
290
- W, Wshift = lsvecwidthshift (ls, vloopsym, size_T)
288
+ cacheunrolled! (ls, unrolled, Symbol (" " ), vloopsym)
289
+ size_T = biggest_type_size (ls)
290
+ W, Wshift = lsvecwidthshift (ls, vloopsym, size_T)
291
291
292
- # So if num_reductions > 0, we set the unroll factor to be high enough so that the CPU can be kept busy
293
- # if there are, U = max(1, round(Int, max(latency) * throughput / num_reductions)) = max(1, round(Int, latency / (recip_throughput * num_reductions)))
294
- # We also make sure register pressure is not too high.
295
- latency = 1.0
296
- # compute_recip_throughput_u = 0.0
297
- compute_recip_throughput = 0.0
298
- visited_nodes = fill (false , length (operations (ls)))
299
- load_recip_throughput = 0.0
300
- store_recip_throughput = 0.0
301
- for op ∈ operations (ls)
302
- if isreduction (op)
303
- rt, sl = depchain_cost! (ls, visited_nodes, op, unrolled, vloopsym, Wshift, size_T)
304
- if isouterreduction (ls, op) ≠ - 1 || unrolled ∉ reduceddependencies (op)
305
- latency = max (sl, latency)
306
- end
307
- # if unrolled ∈ loopdependencies(op)
308
- # compute_recip_throughput_u += rt
309
- # else
310
- compute_recip_throughput += rt
311
- # end
312
- elseif isload (op)
313
- load_recip_throughput += first (cost (ls, op, (unrolled,Symbol (" " )), vloopsym, Wshift, size_T))
314
- elseif isstore (op)
315
- store_recip_throughput += first (cost (ls, op, (unrolled,Symbol (" " )), vloopsym, Wshift, size_T))
316
- end
292
+ # So if num_reductions > 0, we set the unroll factor to be high enough so that the CPU can be kept busy
293
+ # if there are, U = max(1, round(Int, max(latency) * throughput / num_reductions)) = max(1, round(Int, latency / (recip_throughput * num_reductions)))
294
+ # We also make sure register pressure is not too high.
295
+ latency = 1.0
296
+ # compute_recip_throughput_u = 0.0
297
+ compute_recip_throughput = 0.0
298
+ visited_nodes = fill (false , length (operations (ls)))
299
+ load_recip_throughput = 0.0
300
+ store_recip_throughput = 0.0
301
+ for op ∈ operations (ls)
302
+ if isreduction (op)
303
+ rt, sl = depchain_cost! (ls, visited_nodes, op, unrolled, vloopsym, Wshift, size_T)
304
+ if isouterreduction (ls, op) ≠ - 1 || unrolled ∉ reduceddependencies (op)
305
+ latency = max (sl, latency)
306
+ end
307
+ # if unrolled ∈ loopdependencies(op)
308
+ # compute_recip_throughput_u += rt
309
+ # else
310
+ compute_recip_throughput += rt
311
+ # end
312
+ elseif isload (op)
313
+ load_recip_throughput += first (cost (ls, op, (unrolled,Symbol (" " )), vloopsym, Wshift, size_T))
314
+ elseif isstore (op)
315
+ store_recip_throughput += first (cost (ls, op, (unrolled,Symbol (" " )), vloopsym, Wshift, size_T))
317
316
end
318
- recip_throughput = max (
319
- compute_recip_throughput,
320
- load_recip_throughput,
321
- store_recip_throughput
322
- )
323
- recip_throughput, latency
317
+ end
318
+ recip_throughput = max (
319
+ compute_recip_throughput,
320
+ load_recip_throughput,
321
+ store_recip_throughput
322
+ )
323
+ # @show latency, recip_throughput
324
+ recip_throughput, latency
324
325
end
325
326
function count_reductions (ls:: LoopSet )
326
- num_reductions = 0
327
- for op ∈ operations (ls)
328
- if isreduction (op) & iscompute (op) && parentsnotreduction (op)
329
- num_reductions += 1
330
- end
327
+ num_reductions = 0
328
+ for op ∈ operations (ls)
329
+ if isreduction (op) & iscompute (op) && parentsnotreduction (op)
330
+ num_reductions += 1
331
331
end
332
- num_reductions
332
+ end
333
+ num_reductions
333
334
end
334
335
335
336
demote_unroll_factor (ls:: LoopSet , UF, loop:: Symbol ) = demote_unroll_factor (ls, UF, getloop (ls, loop))
336
337
function demote_unroll_factor (ls:: LoopSet , UF, loop:: Loop )
337
- W = ls. vector_width
338
- if ! iszero (W) && isstaticloop (loop)
339
- UFW = maybedemotesize (UF* W, length (loop))
340
- UF = cld (UFW, W)
341
- end
342
- UF
338
+ W = ls. vector_width
339
+ if ! iszero (W) && isstaticloop (loop)
340
+ UFW = maybedemotesize (UF* W, length (loop))
341
+ UF = cld (UFW, W)
342
+ end
343
+ UF
343
344
end
344
345
345
346
function determine_unroll_factor (ls:: LoopSet , order:: Vector{Symbol} , vloopsym:: Symbol )
346
- num_reductions = count_reductions (ls)
347
- # The strategy is to use an unroll factor of 1, unless there appears to be loop carried dependencies (ie, num_reductions > 0)
348
- # The assumption here is that unrolling provides no real benefit, unless it is needed to enable OOO execution by breaking up these dependency chains
349
- loopindexesbit = ls. loopindexesbit
350
- if iszero (length (loopindexesbit)) || ((! loopindexesbit[getloopid (ls, vloopsym)]))
351
- if iszero (num_reductions)
352
- return unroll_no_reductions (ls, order, vloopsym)
353
- else
354
- return determine_unroll_factor (ls, order, vloopsym, num_reductions)
355
- end
356
- elseif iszero (num_reductions) # handle `BitArray` loops w/out reductions
357
- return 8 ÷ ls. vector_width, vloopsym
358
- else # handle `BitArray` loops with reductions
359
- rttemp, ltemp = determine_unroll_factor (ls, order, vloopsym, vloopsym)
360
- UF = min (8 , VectorizationBase. nextpow2 (max (1 , round (Int, ltemp / (rttemp) ) )))
361
- UFfactor = 8 ÷ ls. vector_width
362
- cld (UF, UFfactor)* UFfactor, vloopsym
363
- end
347
+ num_reductions = count_reductions (ls)
348
+ # The strategy is to use an unroll factor of 1, unless there appears to be loop carried dependencies (ie, num_reductions > 0)
349
+ # The assumption here is that unrolling provides no real benefit, unless it is needed to enable OOO execution by breaking up these dependency chains
350
+ loopindexesbit = ls. loopindexesbit
351
+ if iszero (length (loopindexesbit)) || ((! loopindexesbit[getloopid (ls, vloopsym)]))
352
+ if iszero (num_reductions)
353
+ return unroll_no_reductions (ls, order, vloopsym)
354
+ else
355
+ return determine_unroll_factor (ls, order, vloopsym, num_reductions)
356
+ end
357
+ elseif iszero (num_reductions) # handle `BitArray` loops w/out reductions
358
+ return 8 ÷ ls. vector_width, vloopsym
359
+ else # handle `BitArray` loops with reductions
360
+ rttemp, ltemp = determine_unroll_factor (ls, order, vloopsym, vloopsym)
361
+ UF = min (8 , VectorizationBase. nextpow2 (max (1 , round (Int, ltemp / (rttemp) ) )))
362
+ UFfactor = 8 ÷ ls. vector_width
363
+ cld (UF, UFfactor)* UFfactor, vloopsym
364
+ end
364
365
end
365
366
# function scale_unrolled()
366
367
# end
367
368
function determine_unroll_factor (ls:: LoopSet , order:: Vector{Symbol} , vloopsym:: Symbol , num_reductions:: Int )
368
- innermost_loop = last (order)
369
- rt = Inf ; rtcomp = Inf ; latency = Inf ; best_unrolled = Symbol (" " )
370
- for unrolled ∈ order
371
- reject_reorder (ls, unrolled, false ) && continue
372
- rttemp, ltemp = determine_unroll_factor (ls, order, unrolled, vloopsym)
373
- rtcomptemp = rttemp + (0.01 * ((vloopsym === unrolled) + (unrolled === innermost_loop) - latency))
374
- if rtcomptemp < rtcomp
375
- rt = rttemp
376
- rtcomp = rtcomptemp
377
- latency = ltemp
378
- best_unrolled = unrolled
379
- end
380
- end
381
- # min(8, roundpow2(max(1, round(Int, latency / (rt * num_reductions) ) ))), best_unrolled
382
- lrtratio = latency / rt
383
- if lrtratio ≥ 7.0
384
- UF = 8
385
- else
386
- UF = VectorizationBase. nextpow2 (round (Int, clamp (lrtratio, 1.0 , 4.0 )))
387
- end
388
- if best_unrolled === vloopsym
389
- UF = demote_unroll_factor (ls, UF, vloopsym)
369
+ innermost_loop = last (order)
370
+ rt = Inf ; rtcomp = Inf ; latency = Inf ; best_unrolled = Symbol (" " )
371
+ for unrolled ∈ order
372
+ reject_reorder (ls, unrolled, false ) && continue
373
+ rttemp, ltemp = determine_unroll_factor (ls, order, unrolled, vloopsym)
374
+ rtcomptemp = rttemp + (0.01 * ((vloopsym === unrolled) + (unrolled === innermost_loop) - latency))
375
+ if rtcomptemp < rtcomp
376
+ rt = rttemp
377
+ rtcomp = rtcomptemp
378
+ latency = ltemp
379
+ best_unrolled = unrolled
390
380
end
391
- UF, best_unrolled
381
+ end
382
+ # min(8, roundpow2(max(1, round(Int, latency / (rt * num_reductions) ) ))), best_unrolled
383
+ lrtratio = latency / rt
384
+ if lrtratio ≥ 7.0
385
+ UF = 8
386
+ else
387
+ UF = VectorizationBase. nextpow2 (round (Int, clamp (lrtratio, 1.0 , 4.0 ), RoundUp))
388
+ end
389
+ if best_unrolled === vloopsym
390
+ UF = demote_unroll_factor (ls, UF, vloopsym)
391
+ end
392
+ UF, best_unrolled
392
393
end
393
394
394
395
@inline function unroll_cost (X, u₁, u₂, u₁L, u₂L)
0 commit comments