Skip to content

Commit 77bd400

Browse files
committed
allow unroll curly in some more instances
1 parent 2b01184 commit 77bd400

File tree

5 files changed

+73
-36
lines changed

5 files changed

+73
-36
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "LoopVectorization"
22
uuid = "bdcacae8-1622-11e9-2a5c-532679323890"
33
authors = ["Chris Elrod <[email protected]>"]
4-
version = "0.12.78"
4+
version = "0.12.79"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"

src/codegen/lower_load.jl

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ function lower_load_no_optranslation!(
146146
u = ifelse(opu₁, u₁, 1)
147147
mvar = Symbol(variable_name(op, Core.ifelse(opu₂, suffix,-1)), '_', u)
148148
falseexpr = Expr(:call, lv(:False)); rs = staticexpr(reg_size(ls))
149-
if (all(op.ref.loopedindex) && !rejectcurly(op)) && vectorization_profitable(op)
149+
if (!rejectcurly(op)) && vectorization_profitable(op)
150150
inds = unrolledindex(op, td, mask, inds_calc_by_ptr_offset, ls)
151151
loadexpr = Expr(:call, lv(:_vload), sptr(op), inds)
152152
add_memory_mask!(loadexpr, op, td, mask, ls, 0)
@@ -222,7 +222,7 @@ function lower_load_for_optranslation!(
222222
ip = GlobalRef(VectorizationBase, :increment_ptr)
223223
vpo = vptr_offset(gptr)
224224
push!(q.args, Expr(:(=), vpo, Expr(:call, ip, ptr, vptr_offset(ptr), gespinds)))
225-
push!(q.args, Expr(:(=), gptr, ptr))#Expr(:call, GlobalRef(VectorizationBase, :reconstruct_ptr),
225+
push!(q.args, Expr(:(=), gptr, ptr))#Expr(:call, GlobalRef(VectorizationBase, :reconstruct_ptr),
226226
fill!(inds_by_ptroff, true)
227227
@unpack ref, loopedindex = mref
228228
indices = copy(getindices(ref))
@@ -367,9 +367,27 @@ function rejectcurly(ls::LoopSet, op::Operation, u₁loopsym::Symbol, vloopsym::
367367
end
368368
else
369369
opp = findop(parents(op), ind)
370+
(isu₁unrolled(opp) || isu₂unrolled(opp)) && return true
371+
length(parents(opp)) == 2 || return true
372+
if instruction(opp).instr === :(+) || instruction(opp).instr === :add_fast
373+
isadd = true
374+
elseif instruction(opp).instr === :(-) || instruction(opp).instr === :sub_fast
375+
isadd = false
376+
else
377+
return true
378+
end
379+
opp1 = parents(opp)[1]
380+
opp2 = parents(opp)[2]
370381
if isvectorized(opp)
371382
AV && return true
372383
AV = true
384+
if isvectorized(opp1)
385+
isvectorized(opp2) && return true
386+
isloopvalue(opp1) || return true
387+
else# opp2 vectorized
388+
isadd || return true
389+
isloopvalue(opp2) || return true
390+
end
373391
end
374392
if (u₁loopsym === CONSTANTZEROINDEX) ? (CONSTANTZEROINDEX ∈ loopdependencies(opp)) : (isu₁unrolled(opp))
375393
AU && return true
@@ -393,7 +411,7 @@ function rejectinterleave(ls::LoopSet, op::Operation, vloop::Loop, idsformap::Su
393411
end
394412
end
395413
end
396-
vloopsym = vloop.itersymbol;
414+
vloopsym = vloop.itersymbol;
397415
(first(getindices(op)) === vloopsym) && (length(idsformap) ≠ first(getstrides(op)) * gethint(strd))
398416
end
399417
# function lower_load_collection_manual_u₁unroll!(
@@ -501,12 +519,11 @@ function lower_load_collection!(
501519
u = Core.ifelse(opu₁, u₁, 1)
502520
for (i,(opid,o)) ∈ enumerate(idsformap)
503521
extractedv = Expr(:call, gf, collectionname, i, false)
504-
522+
505523
_op = ops[opidmap[opid]]
506524
mvar = Symbol(variable_name(_op, Core.ifelse(opu₂, suffix, -1)), '_', u)
507525
push!(q.args, Expr(:(=), mvar, extractedv))
508526
end
509527
# unpack_collection!(q, ls, opidmap, idsformap, ua, loadexpr, collectionname, op, true)
510528
end
511529
end
512-

src/modeling/determinestrategy.jl

Lines changed: 34 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,10 @@ function unitstride(ls::LoopSet, op::Operation, s::Symbol)
4343
end
4444
true
4545
end
46-
46+
function cannot_shuffle(op::Operation, u₁::Symbol, u₂::Symbol, contigind::Symbol, indices) # assumes isvectorized and !unitstride
47+
!((!rejectcurly(op) && (((contigind === CONSTANTZEROINDEX) && ((length(indices) > 1) && (indices[2] === u₁) || (indices[2] === u₂))) ||
48+
((u₁ === contigind) | (u₂ === contigind)))))
49+
end
4750
function cost(ls::LoopSet, op::Operation, (u₁,u₂)::Tuple{Symbol,Symbol}, vloopsym::Symbol, Wshift::Int, size_T::Int = op.elementbytes)
4851
isconstant(op) && return 0.0, 0, 1.0#Float64(length(loopdependencies(op)) > 0)
4952
isloopvalue(op) && return 0.0, 0, 0.0
@@ -67,22 +70,36 @@ function cost(ls::LoopSet, op::Operation, (u₁,u₂)::Tuple{Symbol,Symbol}, vlo
6770
indices = getindices(op)
6871
contigind = first(indices)
6972
shifter = max(2,Wshift)
70-
if rejectinterleave(op)
71-
offset = 0.0 # gather/scatter, alignment doesn't matter
72-
else
73-
shifter -= 1
74-
offset = 0.5reg_size(ls) / cache_lnsze(ls)
75-
if shifter > 1 &&
76-
(!rejectcurly(op) && (((contigind === CONSTANTZEROINDEX) && ((length(indices) > 1) && (indices[2] === u₁) || (indices[2] === u₂))) ||
77-
((u₁ === contigind) | (u₂ === contigind))))
78-
79-
shifter -= 1
80-
offset = 0.5reg_size(ls) / cache_lnsze(ls)
73+
# rejectinterleave false means omop
74+
# cannot shuffle false means reject curly
75+
# either false means shuffle
76+
dont_shuffle = rejectinterleave(op) && cannot_shuffle(op, u₁, u₂, contigind, indices)
77+
if dont_shuffle
78+
# offset = 0.0 # gather/scatter, alignment doesn't matter
79+
r = 1 << shifter
80+
srt = srt*r# + offset
81+
sl *= r
82+
else#if rejectinterleave(op) # means omop
83+
if isload(op) & (length(loopdependencies(op)) > 1)# vmov(a/u)pd
84+
srt += 0.5reg_size(ls) / cache_lnsze(ls)
8185
end
86+
# srt += 0.3shifter # shifter == number of shuffles
87+
# sl += 0.3shifter
88+
srt += shifter # shifter == number of shuffles
89+
sl += shifter
90+
# shifter -= 1
91+
# offset = 0.5reg_size(ls) / cache_lnsze(ls)
92+
# r = 1 << shifter
93+
# srt = srt*r + offset
94+
# sl *= r
95+
# if shifter > 1 && (!(cannot_shuffle(op, u₁, u₂, contigind, indices)))
96+
# shifter -= 1
97+
# offset = 0.5reg_size(ls) / cache_lnsze(ls)
98+
# end
99+
# else
82100
end
83-
r = 1 << shifter
84-
srt = srt*r + offset
85-
sl *= r
101+
# @show srt, sl
102+
# @show shifter, offset, dont_shuffle
86103
elseif isload(op) & (length(loopdependencies(op)) > 1)# vmov(a/u)pd
87104
# penalize vectorized loads with more than 1 loopdep
88105
# heuristic; more than 1 loopdep means that many loads will not be aligned
@@ -94,7 +111,7 @@ function cost(ls::LoopSet, op::Operation, (u₁,u₂)::Tuple{Symbol,Symbol}, vlo
94111
srt += 0.5reg_size(ls) / cache_lnsze(ls)
95112
# srt += 0.25reg_size(ls) / cache_lnsze(ls)
96113
end
97-
elseif isstore(op) # broadcast or reductionstore; if store we want to penalize reduction
114+
elseif isstore(op)# && isvectorized(first(parents(op))) # broadcast or reductionstore; if store we want to penalize reduction
98115
srt *= 3
99116
sl *= 3
100117
end
@@ -828,6 +845,7 @@ function load_elimination_cost_factor!(
828845
# cost_vec[1] -= 0.5625 * iters
829846
# cost_vec[1] -= 0.5625 * iters / 2
830847
# @show rto, 0.8rt, op
848+
# reg_pressure[1] += 0.25rp
831849
reg_pressure[1] += 0.25rp
832850
cost_vec[2] += rt
833851
reg_pressure[2] += rp

src/modeling/graphs.jl

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -570,21 +570,24 @@ function fill_children!(ls::LoopSet)
570570
end
571571
end
572572
end
573+
function rejectinterleave!(ls::LoopSet, op::Operation, u₁loop::Symbol, u₂loop::Symbol, vloopsym::Symbol, vloop::Loop)
574+
setunrolled!(ls, op, u₁loop, u₂loop, vloopsym)
575+
if accesses_memory(op)
576+
rc = rejectcurly(ls, op, u₁loop, vloopsym)
577+
op.rejectcurly = rc
578+
if rc
579+
op.rejectinterleave = true
580+
else
581+
omop = ls.omop
582+
batchid, opind = omop.batchedcollectionmap[identifier(op)]
583+
op.rejectinterleave = ((batchid == 0) || (!isvectorized(op))) || rejectinterleave(ls, op, vloop, omop.batchedcollections[batchid])
584+
end
585+
end
586+
end
573587
function cacheunrolled!(ls::LoopSet, u₁loop::Symbol, u₂loop::Symbol, vloopsym::Symbol)
574588
vloop = getloop(ls, vloopsym)
575589
for op ∈ operations(ls)
576-
setunrolled!(ls, op, u₁loop, u₂loop, vloopsym)
577-
if accesses_memory(op)
578-
rc = rejectcurly(ls, op, u₁loop, vloopsym)
579-
op.rejectcurly = rc
580-
if rc
581-
op.rejectinterleave = true
582-
else
583-
omop = ls.omop
584-
batchid, opind = omop.batchedcollectionmap[identifier(op)]
585-
op.rejectinterleave = ((batchid == 0) || (!isvectorized(op))) || rejectinterleave(ls, op, vloop, omop.batchedcollections[batchid])
586-
end
587-
end
590+
rejectinterleave!(ls, op, u₁loop, u₂loop, vloopsym, vloop)
588591
end
589592
end
590593
function setunrolled!(ls::LoopSet, op::Operation, u₁loopsym::Symbol, u₂loopsym::Symbol, vectorized::Symbol)
@@ -1411,7 +1414,6 @@ function fill_offset_memop_collection!(ls::LoopSet)
14111414
end
14121415
for (collectionid,opidc) ∈ enumerate(opids)
14131416
length(opidc) > 1 || continue
1414-
14151417
# we check if we can turn the offsets into an unroll
14161418
# we have up to `length(opidc)` loads to do, so we allocate that many "base" vectors
14171419
# then we iterate through them, adding them to collections as appropriate

src/reconstruct_loopset.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -711,7 +711,7 @@ Execute an `@turbo` block. The block's code is represented via the arguments:
711711
- `vargs...` holds the encoded pointers of all the arrays (see `VectorizationBase`'s various pointer types).
712712
"""
713713
@generated function _turbo_!(
714-
::Val{var"#UNROLL#"}, ::Val{var"#OPS#"}, ::Val{var"#ARF#"}, ::Val{var"#AM#"}, ::Val{var"#LPSYM#"}, ::Val{Tuple{var"#LB#",var"#V#"}}, var"#flattened#var#arguments#"::Vararg{Any,var"#num#vargs#"}
714+
::Val{var"#UNROLL#"}, ::Val{var"#OPS#"}, ::Val{var"#ARF#"}, ::Val{var"#AM#"}, ::Val{var"#LPSYM#"}, ::Val{Tuple{var"#LB#",var"#V#"}}, var"#flattened#var#arguments#"::Vararg{Any,var"#num#vargs#"}
715715
) where {var"#UNROLL#", var"#OPS#", var"#ARF#", var"#AM#", var"#LPSYM#", var"#LB#", var"#V#", var"#num#vargs#"}
716716
# 1 + 1 # Irrelevant line you can comment out/in to force recompilation...
717717
ls = _turbo_loopset(var"#OPS#", var"#ARF#", var"#AM#", var"#LPSYM#", var"#LB#".parameters, var"#V#".parameters, var"#UNROLL#")

0 commit comments

Comments
 (0)