Skip to content

Commit 5c398b9

Browse files
committed
Tweak dep chain unrolling factors, make reductions respect initialized int vs float status.
1 parent 0a29447 commit 5c398b9

File tree

3 files changed

+34
-18
lines changed

3 files changed

+34
-18
lines changed

src/codegen/lowering.jl

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -423,11 +423,25 @@ function pointerremcomparison(ls::LoopSet, termind::Int, UFt::Int, n::Int, nisve
423423
end
424424
end
425425

426-
426+
@generated function of_same_size(::Type{T}, ::Type{S}) where {T,S}
427+
sizeof_S = sizeof(S)
428+
sizeof(T) == sizeof_S && return T
429+
Tfloat = T <: Union{Float32,Float64}
430+
if T <: Union{Float32,Float64}
431+
sizeof_S 8 ? Float64 : Float32
432+
elseif T <: Signed
433+
Symbol(:Int, 8sizeof_S)
434+
elseif (T <: Unsigned) | (T === Bool)
435+
Symbol(:UInt, 8sizeof_S)
436+
else
437+
S
438+
end
439+
end
427440
function outer_reduction_zero(op::Operation, u₁u::Bool, Umax::Int, reduct_class::Float64, rs::Expr)
428441
reduct_zero = reduction_zero(reduct_class)
429442
# Tsym = outer_reduct_init_typename(op)
430-
Tsym = ELTYPESYMBOL
443+
# Tsym = ELTYPESYMBOL
444+
Tsym = Expr(:call, lv(:of_same_size), outer_reduct_init_typename(op), ELTYPESYMBOL)
431445
if isvectorized(op)
432446
if Umax == 1 || !u₁u
433447
if reduct_zero === :zero

src/modeling/determinestrategy.jl

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,8 @@ function cost(ls::LoopSet, op::Operation, (u₁,u₂)::Tuple{Symbol,Symbol}, vlo
7070
return 0.0, 0, 0.0
7171
end
7272
elseif iscompute(op) &&
73-
Base.sym_in(instruction(op).instr, (:vadd_nsw, :vsub_nsw, :(+), :(-), :add_fast, :sub_fast)) &&
74-
all(opp -> (isloopvalue(opp)), parents(op))
73+
(Base.sym_in(instruction(op).instr, (:vadd_nsw, :vsub_nsw, :(+), :(-), :add_fast, :sub_fast)) &&
74+
all(opp -> (isloopvalue(opp)), parents(op)))# || (reg_count(ls) == 32) && (instruction(op).instr === :ifelse))
7575
# all(opp -> (isloopvalue(opp) | isconstant(opp)), parents(op))
7676
return 0.0, 0, 0.0
7777
end
@@ -202,14 +202,14 @@ function depchain_cost!(
202202
skip[identifier(op)] = true
203203
# depth first search
204204
for opp parents(op)
205-
skip[identifier(opp)] && continue
205+
skip[identifier(opp)] && continue
206206
rt, sl = depchain_cost!(ls, skip, opp, unrolled, vloopsym, Wshift, size_T, rt, sl)
207207
end
208208
# Basically assuming memory and compute don't conflict, but everything else does
209209
# Ie, ignoring the fact that integer and floating point operations likely don't either
210210
if iscompute(op)
211-
rtᵢ, slᵢ = cost(ls, op, (unrolled,Symbol("")), vloopsym, Wshift, size_T)
212-
rt += rtᵢ; sl += slᵢ
211+
rtᵢ, slᵢ = cost(ls, op, (unrolled,Symbol("")), vloopsym, Wshift, size_T)
212+
rt += rtᵢ; sl += slᵢ
213213
end
214214
rt, sl
215215
end
@@ -357,11 +357,11 @@ function determine_unroll_factor(ls::LoopSet, order::Vector{Symbol}, vloopsym::S
357357
else
358358
return determine_unroll_factor(ls, order, vloopsym, num_reductions)
359359
end
360-
elseif iszero(num_reductions)
360+
elseif iszero(num_reductions) # handle `BitArray` loops w/out reductions
361361
return 8 ÷ ls.vector_width, vloopsym
362-
else
362+
else # handle `BitArray` loops with reductions
363363
rttemp, ltemp = determine_unroll_factor(ls, order, vloopsym, vloopsym)
364-
UF = min(8, VectorizationBase.nextpow2(max(1, round(Int, ltemp / (rttemp * num_reductions) ) )))
364+
UF = min(8, VectorizationBase.nextpow2(max(1, round(Int, ltemp / (rttemp) ) )))
365365
UFfactor = 8 ÷ ls.vector_width
366366
cld(UF, UFfactor)*UFfactor, vloopsym
367367
end
@@ -383,9 +383,11 @@ function determine_unroll_factor(ls::LoopSet, order::Vector{Symbol}, vloopsym::S
383383
end
384384
end
385385
# min(8, roundpow2(max(1, round(Int, latency / (rt * num_reductions) ) ))), best_unrolled
386-
UF = VectorizationBase.nextpow2(round(Int, clamp(latency / (rt * num_reductions), 1.0, 8.0)))
387-
if UF == 1 && num_reductions > 1
388-
UF = VectorizationBase.nextpow2(round(Int, clamp(latency / (rt * cld(num_reductions, 2)), 1.0, 8.0)))
386+
lrtratio = latency / rt
387+
if lrtratio 7.0
388+
UF = 8
389+
else
390+
UF = VectorizationBase.nextpow2(round(Int, clamp(lrtratio, 1.0, 4.0)))
389391
end
390392
if best_unrolled === vloopsym
391393
UF = demote_unroll_factor(ls, UF, vloopsym)

test/dot.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -291,11 +291,11 @@ using Test
291291
@test dot33(a,b) @view(a[1:33])' * @view(b[1:33])
292292

293293
if T <: Union{Float32,Float64}
294-
πest = T(mcpi(a, b))
295-
@test πest == mcpiavx(a, b)
296-
@test πest == mcpiavx_u4(a, b)
297-
@test πest == mcpi_avx(a, b)
298-
@test πest == mcpi_avx_u4(a, b)
294+
πest = mcpi(a, b)
295+
@test πest mcpiavx(a, b)
296+
@test πest mcpiavx_u4(a, b)
297+
@test πest mcpi_avx(a, b)
298+
@test πest mcpi_avx_u4(a, b)
299299
end
300300

301301
if !(!Bool(LoopVectorization.VectorizationBase.has_feature(Val(:x86_64_avx2))) && T === Int32)

0 commit comments

Comments
 (0)