@@ -274,27 +274,27 @@ function solve_unroll_iter(X, R, u₁L, u₂L, u₁range, u₂range)
274
274
u₁best, u₂best, bestcost
275
275
end
276
276
277
"""
    solve_unroll(X, R, u₁L, u₂L, u₁step, u₂step)

Choose unroll factors for the two unrolled loops and return `(u₁, u₂, cost)`.

`X` is indexed at `1:4` and `R` at `1:5`. `u₁L` and `u₂L` are forwarded
unchanged to the cost evaluation. `u₁step` and `u₂step` are the strides of the
candidate grids: only multiples of the respective step are handed to
`solve_unroll_iter`.

When `R[5]` is nonzero the quadratic estimate below is skipped and a plain grid
search over `step:step:10` is used. When the discriminant is negative the
sentinel `(-1, -1, Inf)` is returned. When the continuous estimate is not
finite, `(4, 4, unroll_cost(X, 4, 4, u₁L, u₂L))` is returned.
"""
function solve_unroll(X, R, u₁L, u₂L, u₁step, u₂step)
    X₁, X₂, X₃, X₄ = X[1], X[2], X[3], X[4]
    R₁, R₂, R₃, R₄, R₅ = R[1], R[2], R[3], R[4], R[5]
    # The closed-form estimate below does not involve R₅, so fall back to a
    # grid search over multiples of the steps whenever it is nonzero.
    iszero(R₅) || return solve_unroll_iter(X, R, u₁L, u₂L, u₁step:u₁step:10, u₂step:u₂step:10)
    RR = REGISTER_COUNT - R₃ - R₄
    a = R₂^2 * X₃ - R₁ * X₄ * R₂ - R₁ * X₂ * RR
    # NOTE(review): the first two terms cancel exactly, leaving b == -2X₃*RR*R₂.
    # Kept as written; confirm against the derivation whether that is intended.
    b = R₁ * X₄ * RR - R₁ * X₄ * RR - 2X₃ * RR * R₂
    c = X₃ * RR^2
    discriminant = b^2 - 4a * c
    discriminant < 0 && return -1, -1, Inf
    u₁float = max(float(u₁step), (sqrt(discriminant) + b) / (-2a)) # must be at least u₁step
    u₂float = (RR - u₁float * R₂) / (u₁float * R₁)
    if !(isfinite(u₂float) && isfinite(u₁float))
        return 4, 4, unroll_cost(X, 4, 4, u₁L, u₂L)
        # return itertilesize(X, u₁L, u₂L)
    end
    u₁floor = floor(Int, u₁float)
    u₂floor = max(u₂step, floor(Int, u₂float)) # must be at least u₂step
    u₁high = solve_unroll_constT(R, u₂floor) + u₁step
    u₂high = solve_unroll_constU(R, u₁floor) + u₂step
    # Round the lower bounds down to a multiple of the step so that every
    # candidate in the ranges below is itself a multiple of the step, matching
    # the `step:step:10` grids used in the fallback above. Both floors are
    # already >= their step, so the rounded values stay >= the step.
    u₁low = (u₁floor ÷ u₁step) * u₁step
    u₂low = (u₂floor ÷ u₂step) * u₂step
    return solve_unroll_iter(X, R, u₁L, u₂L, u₁low:u₁step:u₁high, u₂low:u₂step:u₂high)
end
299
299
300
300
function solve_unroll_constU (R:: AbstractVector , u₁:: Int )
@@ -308,9 +308,9 @@ function solve_unroll_constT(ls::LoopSet, u₂::Int)
308
308
floor (Int, (REGISTER_COUNT - R[3 ] - R[4 ] - u₂* R[5 ]) / (u₂ * R[1 ] + R[2 ]))
309
309
end
310
310
# Tiling here is about alleviating register pressure for the UxT
311
- function solve_unroll (X, R, u₁max, u₂max, u₁L, u₂L)
311
+ function solve_unroll (X, R, u₁max, u₂max, u₁L, u₂L, u₁step, u₂step )
312
312
# iszero(first(R)) && return -1,-1,Inf #solve_smalltilesize(X, R, u₁max, u₂max)
313
- u₁, u₂, cost = solve_unroll (X, R, u₁L, u₂L)
313
+ u₁, u₂, cost = solve_unroll (X, R, u₁L, u₂L, u₁step, u₂step )
314
314
# u₂ -= u₂ & 1
315
315
# u₁ = min(u₁, u₂)
316
316
u₁_too_large = u₁ > u₁max
@@ -354,12 +354,19 @@ function solve_unroll(
354
354
ls:: LoopSet , u₁loopsym:: Symbol , u₂loopsym:: Symbol ,
355
355
cost_vec:: AbstractVector{Float64} ,
356
356
reg_pressure:: AbstractVector{Float64} ,
357
- W:: Int , vectorized:: Symbol
357
+ W:: Int , vectorized:: Symbol , rounduᵢ :: Int
358
358
)
359
+ (u₁step, u₂step) = if rounduᵢ == 1 # max is to safeguard against some weird arch I've never heard of.
360
+ (max (1 ,VectorizationBase. CACHELINE_SIZE ÷ VectorizationBase. REGISTER_SIZE), 1 )
361
+ elseif rounduᵢ == 2
362
+ (1 , max (1 ,VectorizationBase. CACHELINE_SIZE ÷ VectorizationBase. REGISTER_SIZE))
363
+ else
364
+ (1 , 1 )
365
+ end
359
366
u₁loop = getloop (ls, u₁loopsym)
360
367
u₂loop = getloop (ls, u₂loopsym)
361
368
solve_unroll (
362
- u₁loopsym, u₂loopsym, cost_vec, reg_pressure, W, vectorized, u₁loop, u₂loop
369
+ u₁loopsym, u₂loopsym, cost_vec, reg_pressure, W, vectorized, u₁loop, u₂loop, u₁step, u₂step
363
370
)
364
371
end
365
372
@@ -368,7 +375,8 @@ function solve_unroll(
368
375
cost_vec:: AbstractVector{Float64} ,
369
376
reg_pressure:: AbstractVector{Float64} ,
370
377
W:: Int , vectorized:: Symbol ,
371
- u₁loop:: Loop , u₂loop:: Loop
378
+ u₁loop:: Loop , u₂loop:: Loop ,
379
+ u₁step:: Int , u₂step:: Int
372
380
)
373
381
maxu₂base = maxu₁base = REGISTER_COUNT == 32 ? 10 : 6 # 8
374
382
maxu₂ = maxu₂base# 8
@@ -393,7 +401,7 @@ function solve_unroll(
393
401
u₁L = u₁loopsym === vectorized ? cld (u₁L,W) : u₁L
394
402
maxu₁ = min (4 maxu₁, u₁L)
395
403
end
396
- u₁, u₂, cost = solve_unroll (cost_vec, reg_pressure, maxu₁, maxu₂, length (u₁loop), length (u₂loop))
404
+ u₁, u₂, cost = solve_unroll (cost_vec, reg_pressure, maxu₁, maxu₂, length (u₁loop), length (u₂loop), u₁step, u₂step )
397
405
# heuristic to more evenly divide small numbers of iterations
398
406
if isstaticloop (u₂loop)
399
407
u₂ = maybedemotesize (u₂, length (u₂loop), u₁, u₁loop, maxu₂base)
@@ -637,6 +645,7 @@ function evaluate_cost_tile(
637
645
u₁reached = u₂reached = false
638
646
choose_to_inline = Ref (false )
639
647
copyto! (names (ls), order); reverse! (names (ls))
648
+ prefetch_good_idea = false
640
649
for n ∈ 1 : N
641
650
itersym = order[n]
642
651
if itersym == u₁loopsym
@@ -688,6 +697,7 @@ function evaluate_cost_tile(
688
697
rt, lat, rp = cost (ls, op, vectorized, Wshift, size_T)
689
698
if isload (op) && ! iszero (prefetchisagoodidea (ls, op, UnrollArgs (4 , unrollsyms, 4 , 0 )))
690
699
rt += 0.5 VectorizationBase. REGISTER_SIZE / VectorizationBase. CACHELINE_SIZE
700
+ prefetch_good_idea = true
691
701
end
692
702
# @show isunrolled₁, isunrolled₂, op rt, lat, rp
693
703
rp = opisininnerloop ? rp : zero (rp) # we only care about register pressure within the inner most loop
@@ -710,10 +720,11 @@ function evaluate_cost_tile(
710
720
costpenalty = (sum (reg_pressure) > REGISTER_COUNT) ? 2 : 1
711
721
# @show order, vectorized cost_vec reg_pressure
712
722
# @show solve_unroll(ls, u₁loopsym, u₂loopsym, cost_vec, reg_pressure)
713
- u₁, u₂, ucost = solve_unroll (ls, u₁loopsym, u₂loopsym, cost_vec, reg_pressure, W, vectorized)
723
+ u₁v = vectorized === u₁loopsym; u₂v = vectorized === u₂loopsym
724
+ round_uᵢ = prefetch_good_idea ? (u₁v ? 1 : (u₂v ? 2 : 0 )) : 0
725
+ u₁, u₂, ucost = solve_unroll (ls, u₁loopsym, u₂loopsym, cost_vec, reg_pressure, W, vectorized, round_uᵢ)
714
726
outer_reduct_penalty = length (ls. outer_reductions) * (u₁ + isodd (u₁))
715
727
favor_bigger_u₂ = u₁ - u₂
716
- u₁v = vectorized === u₁loopsym; u₂v = vectorized === u₂loopsym
717
728
favor_smaller_vectorized = u₁v ? ( u₁ - u₂ ) : (u₂v ? ( u₂ - u₁ ) : 0 )
718
729
favor_u₁_vectorized = - 0.2 u₁v
719
730
favoring_heuristics = favor_bigger_u₂ + 0.5 favor_smaller_vectorized + favor_u₁_vectorized
0 commit comments