@@ -70,8 +70,8 @@ function cost(ls::LoopSet, op::Operation, (u₁,u₂)::Tuple{Symbol,Symbol}, vlo
70
70
return 0.0 , 0 , 0.0
71
71
end
72
72
elseif iscompute (op) &&
73
- Base. sym_in (instruction (op). instr, (:vadd_nsw , :vsub_nsw , :(+ ), :(- ), :add_fast , :sub_fast )) &&
74
- all (opp -> (isloopvalue (opp)), parents (op))
73
+ ( Base. sym_in (instruction (op). instr, (:vadd_nsw , :vsub_nsw , :(+ ), :(- ), :add_fast , :sub_fast )) &&
74
+ all (opp -> (isloopvalue (opp)), parents (op))) # || (reg_count(ls) == 32) && (instruction(op).instr === :ifelse))
75
75
# all(opp -> (isloopvalue(opp) | isconstant(opp)), parents(op))
76
76
return 0.0 , 0 , 0.0
77
77
end
@@ -202,14 +202,14 @@ function depchain_cost!(
202
202
skip[identifier (op)] = true
203
203
# depth first search
204
204
for opp ∈ parents (op)
205
- skip[identifier (opp)] && continue
205
+ skip[identifier (opp)] && continue
206
206
rt, sl = depchain_cost! (ls, skip, opp, unrolled, vloopsym, Wshift, size_T, rt, sl)
207
207
end
208
208
# Basically assuming memory and compute don't conflict, but everything else does
209
209
# Ie, ignoring the fact that integer and floating point operations likely don't either
210
210
if iscompute (op)
211
- rtᵢ, slᵢ = cost (ls, op, (unrolled,Symbol (" " )), vloopsym, Wshift, size_T)
212
- rt += rtᵢ; sl += slᵢ
211
+ rtᵢ, slᵢ = cost (ls, op, (unrolled,Symbol (" " )), vloopsym, Wshift, size_T)
212
+ rt += rtᵢ; sl += slᵢ
213
213
end
214
214
rt, sl
215
215
end
@@ -357,11 +357,11 @@ function determine_unroll_factor(ls::LoopSet, order::Vector{Symbol}, vloopsym::S
357
357
else
358
358
return determine_unroll_factor (ls, order, vloopsym, num_reductions)
359
359
end
360
- elseif iszero (num_reductions)
360
+ elseif iszero (num_reductions) # handle `BitArray` loops w/out reductions
361
361
return 8 ÷ ls. vector_width, vloopsym
362
- else
362
+ else # handle `BitArray` loops with reductions
363
363
rttemp, ltemp = determine_unroll_factor (ls, order, vloopsym, vloopsym)
364
- UF = min (8 , VectorizationBase. nextpow2 (max (1 , round (Int, ltemp / (rttemp * num_reductions ) ) )))
364
+ UF = min (8 , VectorizationBase. nextpow2 (max (1 , round (Int, ltemp / (rttemp) ) )))
365
365
UFfactor = 8 ÷ ls. vector_width
366
366
cld (UF, UFfactor)* UFfactor, vloopsym
367
367
end
@@ -383,9 +383,11 @@ function determine_unroll_factor(ls::LoopSet, order::Vector{Symbol}, vloopsym::S
383
383
end
384
384
end
385
385
# min(8, roundpow2(max(1, round(Int, latency / (rt * num_reductions) ) ))), best_unrolled
386
- UF = VectorizationBase. nextpow2 (round (Int, clamp (latency / (rt * num_reductions), 1.0 , 8.0 )))
387
- if UF == 1 && num_reductions > 1
388
- UF = VectorizationBase. nextpow2 (round (Int, clamp (latency / (rt * cld (num_reductions, 2 )), 1.0 , 8.0 )))
386
+ lrtratio = latency / rt
387
+ if lrtratio ≥ 7.0
388
+ UF = 8
389
+ else
390
+ UF = VectorizationBase. nextpow2 (round (Int, clamp (lrtratio, 1.0 , 4.0 )))
389
391
end
390
392
if best_unrolled === vloopsym
391
393
UF = demote_unroll_factor (ls, UF, vloopsym)
0 commit comments