@@ -26,10 +26,10 @@ function unitstride(ls::LoopSet, op::Operation, s::Symbol)
26
26
fi = first (inds)
27
27
if fi === Symbol (" ##DISCONTIGUOUSSUBARRAY##" )
28
28
return false
29
- elseif ! first (li)
30
- # We must check if this
31
- parent = findparent (ls, fi)
32
- indexappearences (parent, s) > 1 && return false
29
+ # elseif !first(li)
30
+ # # We must check if this
31
+ # parent = findparent(ls, fi)
32
+ # indexappearences(parent, s) > 1 && return false
33
33
end
34
34
for i ∈ 2 : length (inds)
35
35
if li[i]
@@ -217,8 +217,7 @@ function unroll_no_reductions(ls, order, vectorized)
217
217
elseif isstore (op)
218
218
store_rt += first (cost (ls, op, vectorized, Wshift, size_T))
219
219
end
220
- end
221
- # @show compute_rt, load_rt, store_rt
220
+ end
222
221
# heuristic guess
223
222
# roundpow2(min(4, round(Int, (compute_rt + load_rt + 1) / compute_rt)))
224
223
memory_rt = load_rt + store_rt
@@ -374,7 +373,7 @@ function solve_unroll(X, R, u₁L, u₂L, u₁step, u₂step)
374
373
u₂low = max (u₂step, floor (Int, u₂float)) # must be at least 1
375
374
u₁high = solve_unroll_constT (R, u₂low) + u₁step
376
375
u₂high = solve_unroll_constU (R, u₁low) + u₂step
377
- maxunroll = REGISTER_COUNT == 32 ? 10 : 6
376
+ maxunroll = REGISTER_COUNT == 32 ? (((X₂ > 0 ) & (X₃ > 0 )) ? 10 : 8 ) : 6
378
377
u₁low = min (u₁low, maxunroll)
379
378
u₂low = min (u₂low, maxunroll)
380
379
u₁high = min (u₁high, maxunroll)
@@ -534,18 +533,19 @@ function stride_penalty(ls::LoopSet, op::Operation, order::Vector{Symbol}, loopf
534
533
penalty
535
534
end
536
535
function stride_penalty (ls:: LoopSet , order:: Vector{Symbol} )
537
- stridepenalty = 0.0
536
+ stridepenaltydict = Dict {Symbol,Vector{Float64}} ()
538
537
loopfreqs = Vector {Int} (undef, length (order))
539
538
loopfreqs[1 ] = 1
540
539
for i ∈ 2 : length (order)
541
540
loopfreqs[i] = loopfreqs[i- 1 ] * length (getloop (ls, order[i]))
542
541
end
543
542
for op ∈ operations (ls)
544
543
if accesses_memory (op)
545
- stridepenalty += stride_penalty (ls, op, order, loopfreqs)
544
+ v = get! (() -> Float64[], stridepenaltydict, op. ref. ref. array)
545
+ push! (v, stride_penalty (ls, op, order, loopfreqs))
546
546
end
547
547
end
548
- stridepenalty # * 1e-9
548
+ sum (maximum, values (stridepenaltydict)) # * prod(length, ls.loops) / 1024^length(order)
549
549
end
550
550
function isoptranslation (ls:: LoopSet , op:: Operation , unrollsyms:: UnrollSymbols )
551
551
@unpack u₁loopsym, u₂loopsym, vectorized = unrollsyms
@@ -627,6 +627,7 @@ function load_elimination_cost_factor!(
627
627
@unpack u₁loopsym, u₂loopsym, vectorized = unrollsyms
628
628
if ! iszero (first (isoptranslation (ls, op, unrollsyms)))
629
629
rt, lat, rp = cost (ls, op, vectorized, Wshift, size_T)
630
+ rto = rt
630
631
rt *= iters
631
632
# rt *= factor1; rp *= factor2;
632
633
choose_to_inline[] = true
@@ -645,8 +646,8 @@ function load_elimination_cost_factor!(
645
646
# end
646
647
# # (0.25, REGISTER_COUNT == 32 ? 1.2 : 1.0)
647
648
# (0.25, 1.0)
648
- cost_vec[1 ] + = 0.1 rt
649
- reg_pressure[1 ] += 0.51 rp
649
+ cost_vec[1 ] - = 0.1 prod (length, ls . loops)
650
+ reg_pressure[1 ] += 0.25 rp
650
651
cost_vec[2 ] += rt
651
652
reg_pressure[2 ] += rp
652
653
cost_vec[3 ] += rt
@@ -658,13 +659,13 @@ function load_elimination_cost_factor!(
658
659
false
659
660
end
660
661
end
661
- # function loadintostore(ls::LoopSet, op::Operation)
662
- # isload(op) || return false # leads to bad behavior more than it helps
663
- # for opp ∈ operations(ls)
664
- # isstore(opp) && opp.ref == op.ref && return true
665
- # end
666
- # false
667
- # end
662
+ function loadintostore (ls:: LoopSet , op:: Operation )
663
+ isload (op) || return false # leads to bad behavior more than it helps
664
+ for opp ∈ operations (ls)
665
+ isstore (opp) && opp. ref == op. ref && return true
666
+ end
667
+ false
668
+ end
668
669
function store_load_deps! (deps:: Vector{Symbol} , op:: Operation , compref = op. ref)
669
670
for opp ∈ parents (op)
670
671
foreach (ld -> ((ld ∈ deps) || push! (deps, ld)), loopdependencies (opp))
@@ -799,7 +800,6 @@ function evaluate_cost_tile(
799
800
all (ld -> ld ∈ nested_loop_syms, loopdependencies (op)) || continue
800
801
rd = reduceddependencies (op)
801
802
if hasintersection (rd, @view (nested_loop_syms[1 : end - length (rd)]))
802
- # @show rd, op itersym, nested_loop_syms @view(nested_loop_syms[1:end-length(rd)])
803
803
return 0 ,0 ,Inf ,false
804
804
end
805
805
if isstore (op)
@@ -860,7 +860,7 @@ function evaluate_cost_tile(
860
860
costpenalty = (sum (reg_pressure) > REGISTER_COUNT) ? 2 : 1
861
861
u₁v = vectorized === u₁loopsym; u₂v = vectorized === u₂loopsym
862
862
round_uᵢ = prefetch_good_idea ? (u₁v ? 1 : (u₂v ? 2 : 0 )) : 0
863
- if irreducible_storecosts / sum (cost_vec) ≥ 0.25
863
+ if ( irreducible_storecosts / sum (cost_vec) ≥ 0.25 ) && ! any (op -> loadintostore (ls, op), operations (ls))
864
864
u₁, u₂ = (1 , 1 )
865
865
ucost = unroll_cost (cost_vec, 1 , 1 , length (getloop (ls, u₁loopsym)), length (getloop (ls, u₂loopsym)))
866
866
else
0 commit comments