Skip to content

Commit 5063d0e

Browse files
committed
WIP: support CartesianIndex
1 parent 681f828 commit 5063d0e

10 files changed

+88
-49
lines changed

src/add_loads.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ function add_load!(
2424
ls::LoopSet, var::Symbol, mpref::ArrayReferenceMetaPosition, elementbytes::Int
2525
)
2626
length(mpref.loopdependencies) == 0 && return add_constant!(ls, var, mpref, elementbytes)
27-
ref = mpref.mref
2827
op = Operation( ls, var, elementbytes, :getindex, memload, mpref )
2928
add_load!(ls, op, true, false)
3029
end

src/condense_loopset.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,11 @@
44
Base.:|(u::Unsigned, it::IndexType) = u | UInt8(it)
55
Base.:(==)(u::Unsigned, it::IndexType) = (u % UInt8) == UInt8(it)
66

7+
"""
8+
`ArrayRefStruct` stores a representation of an array-reference expression such as `A[i,j]`.
9+
It supports array-references with up to 8 indexes, where the data for each consecutive index is packed into corresponding 8-bit fields
10+
of `index_types` (storing the enum `IndexType`), `indices` (the `id` for each index symbol), and `offsets` (currently unused).
11+
"""
712
struct ArrayRefStruct{array,ptr}
813
index_types::UInt64
914
indices::UInt64
@@ -396,4 +401,3 @@ function setup_call(ls::LoopSet, inline = Int8(2), U = zero(Int8), T = zero(Int8
396401
setup_call_noinline(ls, U, T)
397402
end
398403
end
399-

src/determinestrategy.jl

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ function findparent(ls::LoopSet, s::Symbol)#opdict isn't filled when reconstruct
2121
end
2222
function unitstride(ls::LoopSet, op::Operation, s::Symbol)
2323
inds = getindices(op)
24-
li = op.ref.loopedindex
24+
li, lookup = op.ref.loopedindex, op.ref.indexlookup
2525
# The first index is allowed to be indexed by `s`
2626
fi = first(inds)
2727
if fi === Symbol("##DISCONTIGUOUSSUBARRAY##")
@@ -32,7 +32,7 @@ function unitstride(ls::LoopSet, op::Operation, s::Symbol)
3232
indexappearences(parent, s) > 1 && return false
3333
end
3434
for i 2:length(inds)
35-
if li[i]
35+
if li[lookup[i]]
3636
s === inds[i] && return false
3737
else
3838
parent = findparent(ls, inds[i])
@@ -348,7 +348,7 @@ function maybedemotesize(T::Int, N::Int, U::Int, Uloop::Loop, maxTbase::Int)
348348
end
349349
function solve_tilesize(
350350
ls::LoopSet, unrolled::Symbol, tiled::Symbol,
351-
cost_vec::AbstractVector{Float64},
351+
cost_vec::AbstractVector{Float64},
352352
reg_pressure::AbstractVector{Int},
353353
W::Int, vectorized::Symbol
354354
)
@@ -440,7 +440,7 @@ function evaluate_cost_tile(
440440
# Need to check if fusion is possible
441441
size_T = biggest_type_size(ls)
442442
W, Wshift = VectorizationBase.pick_vector_width_shift(length(ls, vectorized), size_T)::Tuple{Int,Int}
443-
# costs =
443+
# costs =
444444
# cost_mat[1] / ( unrolled * tiled)
445445
# cost_mat[2] / ( tiled)
446446
# cost_mat[3] / ( unrolled)
@@ -574,7 +574,7 @@ function choose_unroll_order(ls::LoopSet, lowest_cost::Float64 = Inf)
574574
iter = iterate(lo, state)
575575
iter === nothing && return best_order, best_vec, lowest_cost
576576
new_order, state = iter
577-
end
577+
end
578578
end
579579
function choose_tile(ls::LoopSet)
580580
lo = LoopOrders(ls)
@@ -632,4 +632,3 @@ function register_pressure(ls::LoopSet)
632632
tU * tT * rp[1] + tU * rp[2] + rp[3] + rp[4]
633633
end
634634
end
635-

src/graphs.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,7 @@ Base.@propagate_inbounds Base.getindex(lo::LoopOrder, i...) = lo.oporder[LinearI
174174
# O(N) search is faster at small sizes
175175
struct LoopSet
176176
loopsymbols::Vector{Symbol}
177+
loopsymbol_offsets::Vector{Int} # symbol loopsymbols[i] corresponds to loops[lso[i]+1:lso[i+1]] (CartesianIndex handling)
177178
loops::Vector{Loop}
178179
opdict::Dict{Symbol,Operation}
179180
operations::Vector{Operation} # Split them to make it easier to iterate over just a subset
@@ -281,7 +282,7 @@ includesarray(ls::LoopSet, array::Symbol) = array ∈ ls.includedarrays
281282

282283
function LoopSet(mod::Symbol)# = :LoopVectorization)
283284
LoopSet(
284-
Symbol[], Loop[],
285+
Symbol[], [0], Loop[],
285286
Dict{Symbol,Operation}(),
286287
Operation[],
287288
Int[],

src/lower_load.jl

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ function pushvectorload!(q::Expr, op::Operation, var::Symbol, td::UnrollArgs, U:
88
end
99
push!(q.args, Expr(:(=), name, instrcall))
1010
end
11-
function lower_load_scalar!(
11+
function lower_load_scalar!(
1212
q::Expr, op::Operation, vectorized::Symbol, W::Symbol, unrolled::Symbol, tiled::Symbol, U::Int,
1313
suffix::Union{Nothing,Int}, mask::Union{Nothing,Symbol,Unsigned} = nothing
1414
)
@@ -60,6 +60,3 @@ function lower_load!(
6060
lower_load_scalar!(q, op, vectorized, W, unrolled, tiled, U, suffix, mask)
6161
end
6262
end
63-
64-
65-

src/lower_memory_common.jl

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,12 @@ function mem_offset(op::Operation, td::UnrollArgs)
2727
# @assert accesses_memory(op) "Computing memory offset only makes sense for operations that access memory."
2828
ret = Expr(:tuple)
2929
indices = getindices(op)
30-
loopedindex = op.ref.loopedindex
30+
loopedindex, indexlookup = op.ref.loopedindex, op.ref.indexlookup
3131
start = (first(indices) === Symbol("##DISCONTIGUOUSSUBARRAY##")) + 1
3232
for (n,ind) enumerate(@view(indices[start:end]))
3333
if ind isa Int
3434
push!(ret.args, ind)
35-
elseif loopedindex[n]
35+
elseif loopedindex[indexlookup[n]]
3636
push!(ret.args, ind)
3737
else
3838
push!(ret.args, symbolind(ind, op, td))
@@ -46,7 +46,7 @@ function mem_offset_u(op::Operation, td::UnrollArgs)
4646
incr = u
4747
ret = Expr(:tuple)
4848
indices = getindices(op)
49-
loopedindex = op.ref.loopedindex
49+
loopedindex, indexlookup = op.ref.loopedindex, op.ref.indexlookup
5050
if incr == 0
5151
return mem_offset(op, td)
5252
# append_inds!(ret, indices, loopedindex)
@@ -57,7 +57,7 @@ function mem_offset_u(op::Operation, td::UnrollArgs)
5757
push!(ret.args, ind)
5858
elseif ind === unrolled
5959
push!(ret.args, Expr(:call, :+, ind, incr))
60-
elseif loopedindex[n]
60+
elseif loopedindex[indexlookup[n]]
6161
push!(ret.args, ind)
6262
else
6363
push!(ret.args, symbolind(ind, op, td))
@@ -117,4 +117,3 @@ function name_memoffset(var::Symbol, op::Operation, td::UnrollArgs, W::Symbol, v
117117
end
118118
name, mo
119119
end
120-

src/memory_ops_common.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ function array_reference_meta!(ls::LoopSet, array::Symbol, rawindices, elementby
6666
else
6767
indop = get(ls.opdict, ind, nothing)
6868
if indop !== nothing && !isconstant(indop)
69-
pushparent!(parents, loopdependencies, reduceddeps, parent)
69+
pushparent!(parents, loopdependencies, reduceddeps, parent) # FIXME where does `parent` come from?
7070
# var = get(ls.opdict, ind, nothing)
7171
push!(indices, name(parent)); ninds += 1
7272
push!(loopedindex, false)

src/operations.jl

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,11 @@ struct ArrayReferenceMeta
2020
ref::ArrayReference
2121
loopedindex::Vector{Bool}
2222
ptr::Symbol
23+
indexlookup::Vector{Int}
2324
end
24-
function ArrayReferenceMeta(ref::ArrayReference, loopedindex, ptr = vptr(ref))
25+
function ArrayReferenceMeta(ref::ArrayReference, loopedindex, ptr = vptr(ref), indexlookup = [i for i in 1:length(loopedindex)])
2526
ArrayReferenceMeta(
26-
ref, loopedindex, ptr
27+
ref, loopedindex, ptr, indexlookup
2728
)
2829
end
2930
# function Base.hash(x::ArrayReference, h::UInt)
@@ -174,7 +175,7 @@ These names will be further processed if op is tiled and/or unrolled.
174175
if tiled ∈ loopdependencies(op) # `suffix` is tilenumber
175176
mvar = Symbol(op, suffix, :_)
176177
end
177-
if unrolled ∈ loopdependencies(op) # `u` is unroll number
178+
if unrolled ∈ loopdependencies(op) # `u` is unroll number
178179
mvar = Symbol(op, u)
179180
end
180181
```
@@ -240,6 +241,3 @@ getindices(op::Operation) = op.ref.ref.indices
240241
# # access stride info?
241242
# op.numerical_metadata[symposition(op,sym)]
242243
# end
243-
244-
245-

src/reconstruct_loopset.jl

Lines changed: 46 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,33 @@ function Loop(ls, l, sym::Symbol, ::Type{StaticUnitRange{L,U}}) where {L,U}
2525
Loop(sym, L, U, Symbol(""), Symbol(""), true, true)::Loop
2626
end
2727

28+
function Loop(ls::LoopSet, l::Int, k::Int, sym::Symbol, ::Type{<:CartesianIndices{N}}) where N
29+
str = String(sym)*'#'*string(k)*'#'
30+
start = gensym(str*"_loopstart"); stop = gensym(str*"_loopstop")
31+
axisexpr = Expr(:ref, Expr(:., Expr(:ref, :lb, l), QuoteNode(:indices)), k)
32+
pushpreamble!(ls, Expr(:(=), start, Expr(:macrocall, Symbol("@inbounds"), LineNumberNode(@__LINE__, Symbol(@__FILE__)), Expr(:(.), axisexpr, QuoteNode(:start)))))
33+
pushpreamble!(ls, Expr(:(=), stop, Expr(:macrocall, Symbol("@inbounds"), LineNumberNode(@__LINE__, Symbol(@__FILE__)), Expr(:(.), axisexpr, QuoteNode(:stop)))))
34+
Loop(Symbol(str), 0, 1024, start, stop, false, false)::Loop
35+
end
36+
2837
function add_loops!(ls::LoopSet, LPSYM, LB)
2938
n = max(length(LPSYM), length(LB))
3039
for i = 1:n
31-
add_loop!(ls, Loop(ls, i, LPSYM[i], LB[i])::Loop)
40+
sym, l = LPSYM[i], LB[i]
41+
if l<:CartesianIndices
42+
add_loops!(ls, i, sym, l)
43+
else
44+
add_loop!(ls, Loop(ls, i, sym, l)::Loop)
45+
push!(ls.loopsymbol_offsets, ls.loopsymbol_offsets[end]+1)
46+
end
3247
end
3348
end
49+
function add_loops!(ls, i, sym, l::Type{<:CartesianIndices{N}}) where N
50+
for k = N:-1:1
51+
add_loop!(ls, Loop(ls, i, k, sym, l)::Loop)
52+
end
53+
push!(ls.loopsymbol_offsets, ls.loopsymbol_offsets[end]+N)
54+
end
3455

3556
function ArrayReferenceMeta(
3657
ls::LoopSet, @nospecialize(ar::ArrayRefStruct), arraysymbolinds::Vector{Symbol}, opsymbols::Vector{Symbol}
@@ -39,21 +60,28 @@ function ArrayReferenceMeta(
3960
indices = ar.indices
4061
offsets = ar.offsets
4162
ni = filled_8byte_chunks(index_types)
42-
index_vec = Vector{Symbol}(undef, ni)
63+
index_vec = Symbol[]
4364
offset_vec = Vector{Int8}(undef, ni)
4465
loopedindex = fill(false, ni)
66+
indexlookup = Int[]
4567
while index_types != zero(UInt64)
4668
ind = indices % UInt8
47-
symind = if index_types == LoopIndex
69+
if index_types == LoopIndex
70+
for inda in ls.loopsymbol_offsets[ind]+1:ls.loopsymbol_offsets[ind+1]
71+
pushfirst!(index_vec, ls.loopsymbols[inda])
72+
pushfirst!(indexlookup, ni)
73+
end
4874
loopedindex[ni] = true
49-
ls.loopsymbols[ind]
50-
elseif index_types == ComputedIndex
51-
opsymbols[ind]
5275
else
53-
@assert index_types == SymbolicIndex
54-
arraysymbolinds[ind]
76+
symind = if index_types == ComputedIndex
77+
opsymbols[ind]
78+
else
79+
@assert index_types == SymbolicIndex
80+
arraysymbolinds[ind]
81+
end
82+
pushfirst!(index_vec, symind)
83+
pushfirst!(indexlookup, ni)
5584
end
56-
index_vec[ni] = symind
5785
offset_vec[ni] = offsets % Int8
5886
index_types >>>= 8
5987
indices >>>= 8
@@ -62,7 +90,7 @@ function ArrayReferenceMeta(
6290
end
6391
ArrayReferenceMeta(
6492
ArrayReference(array(ar), index_vec, offset_vec),
65-
loopedindex, ptr(ar)
93+
loopedindex, ptr(ar), indexlookup
6694
)
6795
end
6896

@@ -134,14 +162,16 @@ function process_metadata!(ls::LoopSet, AM, num_arrays::Int)::Vector{Symbol}
134162
arraysymbolinds
135163
end
136164
function parents_symvec(ls::LoopSet, u::Unsigned)
137-
i = filled_4byte_chunks(u)
138-
loops = Vector{Symbol}(undef, i)
165+
loops = Symbol[]
166+
offsets = ls.loopsymbol_offsets
139167
while u != zero(u)
140-
loops[i] = getloopsym(ls, ( u % UInt8 ) & 0x0f )
141-
i -= 1
168+
idx = ( u % UInt8 ) & 0x0f
169+
for j = offsets[idx]+1:offsets[idx+1]
170+
push!(loops, getloopsym(ls, j))
171+
end
142172
u >>= 4
143173
end
144-
loops
174+
return reverse!(loops)
145175
end
146176
loopdependencies(ls::LoopSet, os::OperationStruct) = parents_symvec(ls, os.loopdeps)
147177
reduceddependencies(ls::LoopSet, os::OperationStruct) = parents_symvec(ls, os.reduceddeps)
@@ -227,7 +257,7 @@ function avx_loopset(instr, ops, arf, AM, LPSYM, LB, vargs)
227257
num_arrays = length(arf)
228258
elementbytes = sizeofeltypes(vargs, num_arrays)
229259
add_loops!(ls, LPSYM, LB)
230-
resize!(ls.loop_order, length(LB))
260+
resize!(ls.loop_order, ls.loopsymbol_offsets[end])
231261
arraysymbolinds = process_metadata!(ls, AM, length(arf))
232262
opsymbols = [gensym(:op) for _ eachindex(ops)]
233263
mrefs = create_mrefs!(ls, arf, arraysymbolinds, opsymbols, vargs)

test/offsetarrays.jl

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ using Test
8383
end
8484

8585

86-
86+
8787
struct SizedOffsetMatrix{T,LR,UR,LC,RC} <: AbstractMatrix{T}
8888
data::Matrix{T}
8989
end
@@ -141,25 +141,37 @@ using Test
141141
# lsuq = LoopVectorization.LoopSet(macroexpand(Base, uq));
142142
# LoopVectorization.choose_order(lsuq)
143143

144-
144+
145+
function avxgeneric!(out, A, kern, R=CartesianIndices(out), z=zero(eltype(out)))
146+
Rk = CartesianIndices(kern)
147+
@avx for I in R
148+
tmp = z
149+
for J in Rk
150+
tmp += A[I+J]*kern[J]
151+
end
152+
out[I] = tmp
153+
end
154+
out
155+
end
156+
145157
for T (Float32, Float64)
146158
@show T, @__LINE__
147159
A = rand(T, 100, 100);
148160
kern = OffsetArray(rand(T, 3, 3), -1:1, -1:1);
149161
skern = SizedOffsetMatrix{T,-1,1,-1,1}(parent(kern));
150162
out1 = OffsetArray(similar(A, size(A).-2), 1, 1); # stay away from the edges of A
151-
out2 = similar(out1); out3 = similar(out1);
163+
out2 = similar(out1); out3 = similar(out1); out4 = similar(out1)
152164

153165
old2d!(out1, A, kern);
154166
avx2d!(out2, A, kern);
155167
@test out1 out2
156-
168+
157169
avx2douter!(out3, A, kern);
158170
@test out1 out3
159171

160172
fill!(out2, NaN); avx2d!(out2, A, skern);
161173
@test out1 out2
162-
174+
163175
fill!(out3, NaN); avx2douter!(out3, A, skern);
164176
@test out1 out3
165177

@@ -168,12 +180,12 @@ using Test
168180

169181
fill!(out3, NaN); avx2dunrolled2x2!(out3, A, skern);
170182
@test out1 out3
171-
183+
172184
fill!(out3, NaN); avx2dunrolled3x3!(out3, A, skern);
173185
@test out1 out3
174186

187+
@test_broken avxgeneric!(out4, A, kern) out1
175188
end
176189

177-
178-
end
179190

191+
end

0 commit comments

Comments
 (0)