@@ -186,7 +186,7 @@ function unroll_no_reductions(ls, order, unrolled, vectorized, Wshift, size_T)
186
186
# @show compute_rt, load_rt
187
187
# roundpow2(min(4, round(Int, (compute_rt + load_rt + 1) / compute_rt)))
188
188
rt = max (compute_rt, load_rt)
189
- rt == 0.0 && return 4
189
+ iszero (rt) && return 4
190
190
max (1 , roundpow2 ( min ( 4 , round (Int, 16 / rt) ) ))
191
191
end
192
192
function determine_unroll_factor (
@@ -286,9 +286,10 @@ function solve_unroll(X, R, u₁L, u₂L, u₁step, u₂step)
286
286
discriminant < 0 && return - 1 ,- 1 ,Inf
287
287
u₁float = max (float (u₁step), (sqrt (discriminant) + b) / (- 2 a)) # must be at least 1
288
288
u₂float = (RR - u₁float* R₂)/ (u₁float* R₁)
289
- if ! (isfinite (u₂float) && isfinite (u₁float))
290
- return 4 , 4 , unroll_cost (X, 4 , 4 , u₁L, u₂L)
291
- # return itertilesize(X, u₁L, u₂L)
289
+ if ! (isfinite (u₂float) & isfinite (u₁float)) # brute force
290
+ u₁low = u₂low = 1
291
+ u₁high = u₂high = REGISTER_COUNT == 32 ? 10 : 6 # 8
292
+ return solve_unroll_iter (X, R, u₁L, u₂L, u₁low: u₁step: u₁high, u₂low: u₂step: u₂high)
292
293
end
293
294
u₁low = floor (Int, u₁float)
294
295
u₂low = max (u₂step, floor (Int, u₂float)) # must be at least 1
@@ -564,6 +565,13 @@ function load_elimination_cost_factor!(
564
565
false
565
566
end
566
567
end
568
"""
    loadintostore(ls::LoopSet, op::Operation)

Return `true` when `op` is a load and at least one store operation in `ls`
has a `ref` equal to `op.ref`. Return `false` otherwise, including whenever
`op` is not a load.
"""
function loadintostore(ls::LoopSet, op::Operation)
    if !isload(op)
        return false
    end
    # `any` stops at the first matching store, like the explicit early return would.
    return any(other -> isstore(other) && other.ref == op.ref, operations(ls))
end
567
575
function add_constant_offset_load_elmination_cost! (
568
576
X, R, choose_to_inline, ls:: LoopSet , op:: Operation , iters, unrollsyms:: UnrollSymbols , u₁reduces:: Bool , u₂reduces:: Bool , Wshift:: Int , size_T:: Int , opisininnerloop:: Bool
569
577
)
@@ -575,6 +583,9 @@ function add_constant_offset_load_elmination_cost!(
575
583
rt, lat, rp = cost (ls, op, vectorized, Wshift, size_T)
576
584
rt *= iters
577
585
rp = opisininnerloop ? rp : zero (rp)
586
+ # if loadintostore(ls, op) # For now, let's just avoid unrolling in this way...
587
+ # rt = Inf
588
+ # end
578
589
# u_uid is getting eliminated
579
590
# we treat this as the unrolled loop getting eliminated is split into 2 parts:
580
591
# 1 a non-cost-reduced part, with factor udependent_reduction
@@ -700,7 +711,8 @@ function evaluate_cost_tile(
700
711
prefetch_good_idea = true
701
712
end
702
713
# @show isunrolled₁, isunrolled₂, op rt, lat, rp
703
- rp = opisininnerloop ? rp : zero (rp) # we only care about register pressure within the inner most loop
714
+ rp = (opisininnerloop && ! (loadintostore (ls, op))) ? rp : zero (rp) # we only care about register pressure within the inner most loop
715
+ # rp = opisininnerloop ? rp : zero(rp) # we only care about register pressure within the inner most loop
704
716
rt *= iters[id]
705
717
if u₁reduces & u₂reduces
706
718
cost_vec[4 ] += rt
0 commit comments