Skip to content

Commit a2b4bd0

Browse files
committed
Improve support for AbstractVector{<:Bool} and BitVecOrMat, resolves #91. Add an optimization to favor unrolling in direction of static offsets.
1 parent 6b21805 commit a2b4bd0

13 files changed

+301
-49
lines changed

Project.toml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,15 @@ VectorizationBase = "3d5dd08c-fd9d-11e8-17fa-ed2836048c2f"
1313

1414
[compat]
1515
OffsetArrays = "1"
16-
SIMDPirates = "0.7.11"
16+
SIMDPirates = "0.7.12"
1717
SLEEFPirates = "0.4.4"
1818
UnPack = "0"
19-
VectorizationBase = "0.10"
19+
VectorizationBase = "0.10.1"
2020
julia = "1.1"
2121

2222
[extras]
23+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2324
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2425

2526
[targets]
26-
test = ["Test"]
27+
test = ["Random", "Test"]

src/broadcast.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,7 @@ end
244244
q = lower(ls)
245245
push!(q.args, :dest)
246246
pushfirst!(q.args, Expr(:meta,:inline))
247+
# @show q
247248
q
248249
# ls
249250
end

src/condense_loopset.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ function ArrayRefStruct(ls::LoopSet, mref::ArrayReferenceMeta, arraysymbolinds::
3535
index_types <<= 8
3636
indices <<= 8
3737
offsets <<= 8
38-
offsets |= offv[n]
38+
offsets |= (offv[n] % UInt8)
3939
if mref.loopedindex[n]
4040
index_types |= LoopIndex
4141
indices |= getloopid(ls, ind)

src/costs.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ const OPAQUE_INSTRUCTION = InstructionCost(50, 50.0, -1.0, VectorizationBase.REG
100100
# hand, should indicate how many registers we're keeping live for the sake of eventually storing.
101101
const COST = Dict{Instruction,InstructionCost}(
102102
Instruction(:getindex) => InstructionCost(-3.0,0.5,3,1),
103+
Instruction(:conditionalload) => InstructionCost(-3.0,0.5,3,1),
103104
Instruction(:setindex!) => InstructionCost(-3.0,1.0,3,0),
104105
Instruction(:conditionalstore!) => InstructionCost(-3.0,1.0,3,0),
105106
Instruction(:zero) => InstructionCost(1,0.5),

src/determinestrategy.jl

Lines changed: 54 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -458,7 +458,53 @@ function isoptranslation(ls::LoopSet, op::Operation, u1::Symbol, u2::Symbol, vec
458458
end
459459
istranslation, translationplus
460460
end
461-
function convolution_cost_factor(ls::LoopSet, op::Operation, u1::Symbol, u2::Symbol, v::Symbol)
461+
function maxnegativeoffset(ls::LoopSet, op::Operation, u::Symbol)
462+
opmref = op.ref
463+
opref = opmref.ref
464+
mno = typemin(Int)
465+
id = 0
466+
for opp operations(ls)
467+
opp === op && continue
468+
oppmref = opp.ref
469+
oppref = oppmref.ref
470+
sameref(opref, oppref) || continue
471+
opinds = getindicesonly(op)
472+
oppinds = getindicesonly(opp)
473+
opoffs = opref.offsets
474+
oppoffs = oppref.offsets
475+
# oploopi = opmref.loopedindex
476+
# opploopi = oppmref.loopedindex
477+
mnonew = typemin(Int)
478+
for i eachindex(opinds)
479+
if opinds[i] !== oppinds[i]
480+
mnonew = 1
481+
break
482+
end
483+
if opinds[i] === u
484+
mnonew = (opoffs[i] - oppoffs[i])
485+
elseif opoffs[i] != oppoffs[i]
486+
mnonew = 1
487+
break
488+
end
489+
end
490+
if mno < mnonew < 0
491+
mno = mnonew
492+
id = identifier(opp)
493+
end
494+
end
495+
mno, id
496+
end
497+
function maxnegativeoffset(ls::LoopSet, op::Operation, u1::Symbol, u2::Symbol, v::Symbol)
498+
mno = typemin(Int)
499+
if u1 !== v
500+
mno = first(maxnegativeoffset(ls, op, u1))
501+
end
502+
if u2 !== v
503+
mno = max(mno, first(maxnegativeoffset(ls, op, u2)))
504+
end
505+
mno
506+
end
507+
function loadelimination_cost_factor(ls::LoopSet, op::Operation, u1::Symbol, u2::Symbol, v::Symbol)
462508
if first(isoptranslation(ls, op, u1, u2, v))
463509
for loop ls.loops
464510
# If another loop is short, assume that LLVM will unroll it, in which case
@@ -473,7 +519,12 @@ function convolution_cost_factor(ls::LoopSet, op::Operation, u1::Symbol, u2::Sym
473519
end
474520
(0.25, VectorizationBase.REGISTER_COUNT == 32 ? 1.2 : 1.0)
475521
else
476-
(1.0, 1.0)
522+
offset = maxnegativeoffset(ls, op, u1, u2, v)
523+
if -5 < offset < 0
524+
(-0.25offset, 1.0)
525+
else
526+
(1.0, 1.0)
527+
end
477528
end
478529
end
479530
# Just tile outer two loops?
@@ -542,7 +593,7 @@ function evaluate_cost_tile(
542593
rt, lat, rp = cost(ls, op, vectorized, Wshift, size_T)
543594
# @show op rt, lat, rp
544595
if isload(op)
545-
factor1, factor2 = convolution_cost_factor(ls, op, unrolled, tiled, vectorized)
596+
factor1, factor2 = loadelimination_cost_factor(ls, op, unrolled, tiled, vectorized)
546597
rt *= factor1; rp *= factor2;
547598
end
548599
# @show isunrolled, istiled, op rt, lat, rp

src/lower_load.jl

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ function lower_load_scalar!(
1616
end
1717
else
1818
condop = last(parents(op))
19-
condvar = variable_name(condop, suffix)
20-
condunrolled = any(isequal(unrolled), loopdependencies(condop))
19+
condvar = tiled loopdependencies(condop) ? variable_name(condop, suffix) : variable_name(condop, nothing)
20+
condunrolled = unrolled loopdependencies(condop)
2121
for u umin:U-1
2222
condsym = condunrolled ? Symbol(condvar, u) : condvar
2323
varname = varassignname(var, u, isunrolled)
@@ -32,7 +32,7 @@ end
3232
function pushvectorload!(
3333
q::Expr, op::Operation, var::Symbol, td::UnrollArgs, U::Int, W::Symbol, vectorized::Symbol, mask
3434
)
35-
@unpack u, unrolled, suffix = td
35+
@unpack u, unrolled, tiled, suffix = td
3636
ptr = refname(op)
3737
vecnotunrolled = vectorized !== unrolled
3838
name, mo = name_memoffset(var, op, td, W, vecnotunrolled)
@@ -43,8 +43,8 @@ function pushvectorload!(
4343
if iscondstore
4444
condop = last(parents(op))
4545
# @show condop
46-
condsym = variable_name(condop, suffix)
47-
condsym = any(isequal(unrolled), loopdependencies(condop)) ? Symbol(condsym, u) : condsym
46+
condsym = tiled loopdependencies(condop) ? variable_name(condop, suffix) : variable_name(condop, nothing)
47+
condsym = unrolled loopdependencies(condop) ? Symbol(condsym, u) : condsym
4848
if vectorized loopdependencies(condop)
4949
if maskend
5050
push!(instrcall.args, Expr(:call, :&, condsym, mask))
@@ -99,6 +99,23 @@ function lower_load!(
9999
push!(q.args, Expr(:(=), Symbol(varnew, u), Symbol(varold, u + 1)))
100100
end
101101
umin = U - 1
102+
elseif tiled !== vectorized
103+
mno, id = maxnegativeoffset(ls, op, tiled)
104+
if -suffix < mno < 0
105+
varnew = variable_name(op, suffix)
106+
varold = variable_name(operations(ls)[id], suffix + mno)
107+
opold = operations(ls)[id]
108+
if unrolled loopdependencies(op)
109+
for u 0:U-1
110+
push!(q.args, Expr(:(=), Symbol(varnew, u), Symbol(varold, u)))
111+
end
112+
else
113+
push!(q.args, Expr(:(=), varnew, varold))
114+
end
115+
return
116+
else
117+
umin = 0
118+
end
102119
else
103120
umin = 0
104121
end

src/lower_memory_common.jl

Lines changed: 39 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -23,19 +23,23 @@ function symbolind(ind::Symbol, op::Operation, td::UnrollArgs)
2323
end
2424
unrolled loopdependencies(parent) ? Symbol(pvar, u) : pvar
2525
end
26+
27+
addoffset(ex, offset::Integer) = iszero(offset) ? ex : Expr(:call, :+, ex, convert(Int, offset))
28+
2629
function mem_offset(op::Operation, td::UnrollArgs)
2730
# @assert accesses_memory(op) "Computing memory offset only makes sense for operations that access memory."
2831
ret = Expr(:tuple)
29-
indices = getindices(op)
32+
indices = getindicesonly(op)
33+
offsets = getoffsets(op)
3034
loopedindex = op.ref.loopedindex
31-
start = (first(indices) === Symbol("##DISCONTIGUOUSSUBARRAY##")) + 1
32-
for (n,ind) enumerate(@view(indices[start:end]))
33-
if ind isa Int
34-
push!(ret.args, ind)
35+
for (n,ind) enumerate(indices)
36+
offset = offsets[n]
37+
if ind isa Int # impossible
38+
push!(ret.args, ind + offset)
3539
elseif loopedindex[n]
36-
push!(ret.args, ind)
40+
push!(ret.args, addoffset(ind, offset))
3741
else
38-
push!(ret.args, symbolind(ind, op, td))
42+
push!(ret.args, addoffset(symbolind(ind, op, td), offset))
3943
end
4044
end
4145
ret
@@ -45,22 +49,23 @@ function mem_offset_u(op::Operation, td::UnrollArgs)
4549
@unpack unrolled, u = td
4650
incr = u
4751
ret = Expr(:tuple)
48-
indices = getindices(op)
52+
indices = getindicesonly(op)
53+
offsets = getoffsets(op)
4954
loopedindex = op.ref.loopedindex
5055
if incr == 0
5156
return mem_offset(op, td)
5257
# append_inds!(ret, indices, loopedindex)
5358
else
54-
start = (first(indices) === Symbol("##DISCONTIGUOUSSUBARRAY##")) + 1
55-
for (n,ind) enumerate(@view(indices[start:end]))
59+
for (n,ind) enumerate(indices)
60+
offset = offsets[n]
5661
if ind isa Int
5762
push!(ret.args, ind)
5863
elseif ind === unrolled
59-
push!(ret.args, Expr(:call, :+, ind, incr))
64+
push!(ret.args, Expr(:call, :+, ind, incr + offset))
6065
elseif loopedindex[n]
61-
push!(ret.args, ind)
66+
push!(ret.args, addoffset(ind, offset))
6267
else
63-
push!(ret.args, symbolind(ind, op, td))
68+
push!(ret.args, addoffset(symbolind(ind, op, td), offset))
6469
end
6570
end
6671
end
@@ -71,22 +76,35 @@ function mem_offset_u(op::Operation, td::UnrollArgs, mul::Symbol)
7176
@unpack unrolled, u = td
7277
incr = u
7378
ret = Expr(:tuple)
74-
indices = getindices(op)
79+
indices = getindicesonly(op)
80+
offsets = getoffsets(op)
7581
loopedindex = op.ref.loopedindex
7682
if incr == 0
7783
return mem_offset(op, td)
7884
# append_inds!(ret, indices, loopedindex)
7985
else
80-
start = (first(indices) === Symbol("##DISCONTIGUOUSSUBARRAY##")) + 1
81-
for (n,ind) enumerate(@view(indices[start:end]))
82-
if ind isa Int
83-
push!(ret.args, ind)
86+
for (n,ind) enumerate(indices)
87+
offset = offsets[n]
88+
if ind isa Int # impossible
89+
push!(ret.args, ind + offset)
8490
elseif ind === unrolled
85-
push!(ret.args, Expr(:call, :+, ind, Expr(:call, lv(:valmul), mul, incr)))
91+
if isone(incr)
92+
if iszero(offset)
93+
push!(ret.args, Expr(:call, lv(:valadd), mul, ind))
94+
else
95+
push!(ret.args, Expr(:call, :+, ind, Expr(:call, lv(:valadd), mul, convert(Int, offset))))
96+
end
97+
else
98+
if iszero(offset)
99+
push!(ret.args, Expr(:call, lv(:valmuladd), mul, incr, ind))
100+
else
101+
push!(ret.args, Expr(:call, :+, ind, Expr(:call, lv(:valmuladd), mul, incr, convert(Int, offset))))
102+
end
103+
end
86104
elseif loopedindex[n]
87-
push!(ret.args, ind)
105+
push!(ret.args, addoffset(ind, offset))
88106
else
89-
push!(ret.args, symbolind(ind, op, td))
107+
push!(ret.args, addoffset(symbolind(ind, op, td), offset))
90108
end
91109
end
92110
end

src/memory_ops_common.jl

Lines changed: 51 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -65,11 +65,51 @@ function subset_vptr!(ls::LoopSet, vptr::Symbol, indnum::Int, ind, previndices,
6565
pushpreamble!(ls, Expr(:(=), subsetvptr, Expr(:call, lv(:subsetview), vptr, valcall, indm1)))
6666
subsetvptr
6767
end
68+
69+
function addoffset!(ls, indices, offsets, loopedindex, loopdependencies, ind, offset)
70+
if typemin(Int8) offset typemax(Int8)
71+
push!(indices, ind);
72+
push!(offsets, offset % Int8)
73+
push!(loopedindex, true)
74+
push!(loopdependencies, ind)
75+
true
76+
else
77+
false
78+
end
79+
end
80+
81+
function checkforoffset!(
82+
ls::LoopSet, indices::Vector{Symbol}, offsets::Vector{Int8}, loopedindex::Vector{Bool}, loopdependencies::Vector{Symbol}, ind::Expr
83+
)
84+
ind.head === :call || return false
85+
f = first(ind.args)
86+
(((f === :+) || (f === :-)) && (length(ind.args) == 3)) || return false
87+
factor = f === :+ ? 1 : -1
88+
arg1 = ind.args[2]
89+
arg2 = ind.args[3]
90+
if arg1 isa Integer
91+
if arg2 isa Symbol && arg2 ls.loopsymbols
92+
addoffset!(ls, indices, offsets, loopedindex, loopdependencies, arg2, arg1 * factor)
93+
else
94+
false
95+
end
96+
elseif arg2 isa Integer
97+
if arg1 isa Symbol && arg1 ls.loopsymbols
98+
addoffset!(ls, indices, offsets, loopedindex, loopdependencies, arg1, arg2 * factor)
99+
else
100+
false
101+
end
102+
else
103+
false
104+
end
105+
end
106+
68107
const DISCONTIGUOUS = Symbol("##DISCONTIGUOUSSUBARRAY##")
69108
function array_reference_meta!(ls::LoopSet, array::Symbol, rawindices, elementbytes::Int, var::Union{Nothing,Symbol} = nothing)
70109
vptrarray = vptr(array)
71110
add_vptr!(ls, array, vptrarray) # now, subset
72111
indices = Symbol[]
112+
offsets = Int8[]
73113
loopedindex = Bool[]
74114
parents = Operation[]
75115
loopdependencies = Symbol[]
@@ -82,21 +122,26 @@ function array_reference_meta!(ls::LoopSet, array::Symbol, rawindices, elementby
82122
length(indices) == 0 && push!(indices, DISCONTIGUOUS)
83123
elseif ind isa Expr
84124
#FIXME: position (in loopnest) wont be length(ls.loopsymbols) in general
85-
parent = add_operation!(ls, gensym(:indexpr), ind, elementbytes, length(ls.loopsymbols))
86-
pushparent!(parents, loopdependencies, reduceddeps, parent)
87-
# var = get(ls.opdict, ind, nothing)
88-
push!(indices, name(parent)); ninds += 1
89-
push!(loopedindex, false)
125+
if !checkforoffset!(ls, indices, offsets, loopedindex, loopdependencies, ind)
126+
parent = add_operation!(ls, gensym(:indexpr), ind, elementbytes, length(ls.loopsymbols))
127+
pushparent!(parents, loopdependencies, reduceddeps, parent)
128+
push!(indices, name(parent));
129+
push!(offsets, zero(Int8))
130+
push!(loopedindex, false)
131+
end
132+
ninds += 1
90133
elseif ind isa Symbol
91134
if ind loopset
92135
push!(indices, ind); ninds += 1
136+
push!(offsets, zero(Int8))
93137
push!(loopedindex, true)
94138
push!(loopdependencies, ind)
95139
else
96140
indop = get(ls.opdict, ind, nothing)
97141
if indop !== nothing && !isconstant(indop)
98142
pushparent!(parents, loopdependencies, reduceddeps, indop)
99143
push!(indices, name(indop)); ninds += 1
144+
push!(offsets, zero(Int8))
100145
push!(loopedindex, false)
101146
else
102147
vptrarray = subset_vptr!(ls, vptrarray, ninds, ind, indices, loopedindex)
@@ -108,7 +153,7 @@ function array_reference_meta!(ls::LoopSet, array::Symbol, rawindices, elementby
108153
end
109154
end
110155
# (length(parents) != 0 && first(indices) !== Symbol("##DISCONTIGUOUSSUBARRAY##")) && pushfirst!(indices, Symbol("##DISCONTIGUOUSSUBARRAY##"))
111-
mref = ArrayReferenceMeta(ArrayReference( array, indices ), loopedindex, vptrarray)
156+
mref = ArrayReferenceMeta(ArrayReference( array, indices, offsets ), loopedindex, vptrarray)
112157
ArrayReferenceMetaPosition(mref, parents, loopdependencies, reduceddeps, isnothing(var) ? Symbol("") : var )
113158
end
114159
function tryrefconvert(ls::LoopSet, ex::Expr, elementbytes::Int, var::Union{Nothing,Symbol} = nothing)::Tuple{Bool,ArrayReferenceMetaPosition}

0 commit comments

Comments
 (0)