Skip to content

Commit 8657d6b

Browse files
authored
Merge pull request #84 from timholy/teh/multiindices
WIP: support mixed Int/CartesianIndex indexing
2 parents ca31ba1 + a311e1a commit 8657d6b

File tree

7 files changed

+218
-74
lines changed

7 files changed

+218
-74
lines changed

src/add_ifelse.jl

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
## Currently, if/else will create its own local scope
33
## Assignments will not register in the loop's main scope
44
## although stores and return values will.
5-
5+
negateop!(ls::LoopSet, condop::Operation, elementbytes::Int) = add_compute!(ls, gensym(:negated_mask), :~, [condop], elementbytes)
66

77
function add_if!(ls::LoopSet, LHS::Symbol, RHS::Expr, elementbytes::Int, position::Int, mpref::Union{Nothing,ArrayReferenceMetaPosition} = nothing)
88
# for now, just simple 1-liners
@@ -14,18 +14,24 @@ function add_if!(ls::LoopSet, LHS::Symbol, RHS::Expr, elementbytes::Int, positio
1414
add_operation!(ls, gensym(:mask), condition, mpref, elementbytes, position)
1515
end
1616
iftrue = RHS.args[2]
17-
trueop = if iftrue isa Expr
18-
(iftrue isa Expr && iftrue.head !== :call) && throw("Only calls or constant expressions are currently supported in if/else blocks.")
19-
add_operation!(ls, Symbol(:iftrue), iftrue, elementbytes, position)
17+
if iftrue isa Expr
18+
trueop = add_operation!(ls, Symbol(:iftrue), iftrue, elementbytes, position)
19+
if iftrue.head === :ref && all(ld -> ld loopdependencies(trueop), loopdependencies(condop))
20+
trueop.instruction = Instruction(:conditionalload)
21+
push!(parents(trueop), condop)
22+
end
2023
else
21-
getop(ls, iftrue, elementbytes)
24+
trueop = getop(ls, iftrue, elementbytes)
2225
end
2326
iffalse = RHS.args[3]
24-
falseop = if iffalse isa Expr
25-
(iffalse isa Expr && iffalse.head !== :call) && throw("Only calls or constant expressions are currently supported in if/else blocks.")
26-
add_operation!(ls, Symbol(:iffalse), iffalse, elementbytes, position)
27+
if iffalse isa Expr
28+
falseop = add_operation!(ls, Symbol(:iffalse), iffalse, elementbytes, position)
29+
if iffalse.head === :ref && all(ld -> ld loopdependencies(falseop), loopdependencies(condop))
30+
falseop.instruction = Instruction(:conditionalload)
31+
push!(parents(falseop), negateop!(ls, condop, elementbytes))
32+
end
2733
else
28-
getop(ls, iffalse, elementbytes)
34+
falseop = getop(ls, iffalse, elementbytes)
2935
end
3036
add_compute!(ls, LHS, :vifelse, [condop, trueop, falseop], elementbytes)
3137
end
@@ -67,7 +73,7 @@ function add_andblock!(ls::LoopSet, ex::Expr, elementbytes::Int, position::Int)
6773
end
6874

6975
function add_orblock!(ls::LoopSet, condop::Operation, LHS, rhsop::Operation, elementbytes::Int, position::Int)
70-
negatedcondop = add_compute!(ls, gensym(:negated_mask), :~, [condop], elementbytes)
76+
negatedcondop = negateop!(ls, condop, elementbytes)
7177
if LHS isa Symbol
7278
altop = getop(ls, LHS, elementbytes)
7379
# return add_compute!(ls, LHS, :vifelse, [condop, altop, rhsop], elementbytes)

src/condense_loopset.jl

Lines changed: 30 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -113,23 +113,23 @@ function OperationStruct!(varnames::Vector{Symbol}, ls::LoopSet, op::Operation)
113113
end
114114
## turn a LoopSet into a type object which can be used to reconstruct the LoopSet.
115115

116+
function loop_boundary(loop::Loop)
117+
startexact = loop.startexact
118+
stopexact = loop.stopexact
119+
if startexact & stopexact
120+
Expr(:call, Expr(:curly, lv(:StaticUnitRange), loop.starthint, loop.stophint))
121+
elseif startexact
122+
Expr(:call, Expr(:curly, lv(:StaticLowerUnitRange), loop.starthint), loop.stopsym)
123+
elseif stopexact
124+
Expr(:call, Expr(:curly, lv(:StaticUpperUnitRange), loop.stophint), loop.startsym)
125+
else
126+
Expr(:call, :(:), loop.startsym, loop.stopsym)
127+
end
128+
end
116129

117130
function loop_boundaries(ls::LoopSet)
118131
lbd = Expr(:tuple)
119-
for loop ls.loops
120-
startexact = loop.startexact
121-
stopexact = loop.stopexact
122-
lexpr = if startexact & stopexact
123-
Expr(:call, Expr(:curly, lv(:StaticUnitRange), loop.starthint, loop.stophint))
124-
elseif startexact
125-
Expr(:call, Expr(:curly, lv(:StaticLowerUnitRange), loop.starthint), loop.stopsym)
126-
elseif stopexact
127-
Expr(:call, Expr(:curly, lv(:StaticUpperUnitRange), loop.stophint), loop.startsym)
128-
else
129-
Expr(:call, :(:), loop.startsym, loop.stopsym)
130-
end
131-
push!(lbd.args, lexpr)
132-
end
132+
foreach(loop -> push!(lbd.args, loop_boundary(loop)), ls.loops)
133133
lbd
134134
end
135135

@@ -204,6 +204,7 @@ end
204204
::Val{UT}, ::Type{OPS}, ::Type{ARF}, ::Type{AM}, ::Type{LPSYM}, lb::LB,
205205
::Val{AR}, ::Val{D}, ::Val{IND}, subsetvals, arraydescript, vargs::Vararg{<:Any,N}
206206
) where {UT, OPS, ARF, AM, LPSYM, LB, N, AR, D, IND}
207+
1 + 1
207208
num_vptrs = length(ARF.parameters)::Int
208209
vptrs = [gensym(:vptr) for _ 1:num_vptrs]
209210
call = Expr(:call, lv(:_avx_!), Val{UT}(), OPS, ARF, AM, LPSYM, :lb)
@@ -279,14 +280,23 @@ function generate_call(ls::LoopSet, IUT, debug::Bool = false)
279280
add_external_functions!(q, ls)
280281
q
281282
end
282-
283+
concat_vals() = Val{()}()
284+
# @generated concat_vals(::Val{N}) where {N} = Val{(N,)}()
285+
# @generated concat_vals(::Val{M}, ::Val{N}) where {M, N} = Val{(M,N)}()
286+
@generated function concat_vals(args...)
287+
tup = Expr(:tuple)
288+
for n in eachindex(args)
289+
push!(tup.args, args[n].parameters[1])
290+
end
291+
Expr(:call, Expr(:curly, :Val, tup))
292+
end
283293
function setup_call_noinline(ls::LoopSet, U = zero(Int8), T = zero(Int8))
284294
call = generate_call(ls, (false,U,T))
285295
hasouterreductions = length(ls.outer_reductions) > 0
286296
q = Expr(:block)
287297
vptrarrays = Expr(:tuple)
288298
vptrsubsetvals = Expr(:tuple)
289-
vptrsubsetdims = Expr(:tuple)
299+
vptrsubsetdims = Expr(:call, lv(:concat_vals))
290300
vptrindices = Expr(:tuple)
291301
stridedpointerLHS = Symbol[]
292302
loopvalueLHS = Symbol[]
@@ -304,7 +314,7 @@ function setup_call_noinline(ls::LoopSet, U = zero(Int8), T = zero(Int8))
304314
@assert array loopvalueLHS
305315
push!(vptrarrays.args, -1)
306316
end
307-
push!(vptrsubsetdims.args, nothing)
317+
push!(vptrsubsetdims.args, Expr(:call, Expr(:curly, :Val, nothing)))
308318
vp = first(ex.args)::Symbol
309319
push!(stridedpointerLHS, vp)
310320
push!(vptrindices.args, findfirst(a -> vptr(a) == vp, ls.refs_aliasing_syms))
@@ -316,7 +326,7 @@ function setup_call_noinline(ls::LoopSet, U = zero(Int8), T = zero(Int8))
316326
@assert vptrarrayid isa Int
317327
end
318328
push!(vptrarrays.args, vptrarrayid::Int)
319-
push!(vptrsubsetdims.args, ex.args[2].args[3].args[1].args[2])
329+
push!(vptrsubsetdims.args, ex.args[2].args[3])#.args[1].args[2])
320330
push!(vptrsubsetvals.args, ex.args[2].args[4])
321331
vp = first(ex.args)::Symbol
322332
push!(stridedpointerLHS, vp)
@@ -327,7 +337,8 @@ function setup_call_noinline(ls::LoopSet, U = zero(Int8), T = zero(Int8))
327337
push!(q.args, ex)
328338
end
329339
insert!(call.args, 8, Expr(:call, Expr(:curly, :Val, vptrarrays)))
330-
insert!(call.args, 9, Expr(:call, Expr(:curly, :Val, vptrsubsetdims)))
340+
# insert!(call.args, 9, Expr(:call, Expr(:curly, :Val, vptrsubsetdims)))
341+
insert!(call.args, 9, vptrsubsetdims)
331342
insert!(call.args, 10, Expr(:call, Expr(:curly, :Val, vptrindices)))
332343
insert!(call.args, 11, vptrsubsetvals)
333344
if hasouterreductions

src/lower_load.jl

Lines changed: 48 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,63 @@
11
function lower_load_scalar!(
2-
q::Expr, op::Operation, vectorized::Symbol, W::Symbol, unrolled::Symbol, tiled::Symbol, U::Int,
3-
suffix::Union{Nothing,Int}, mask::Union{Nothing,Symbol,Unsigned} = nothing, umin::Int = 0
2+
q::Expr, op::Operation, vectorized::Symbol, W::Symbol, unrolled::Symbol,
3+
tiled::Symbol, U::Int, suffix::Union{Nothing,Int}, umin::Int = 0
44
)
55
loopdeps = loopdependencies(op)
66
@assert vectorized loopdeps
77
var = variable_name(op, suffix)
88
ptr = refname(op)
99
isunrolled = unrolled loopdeps
1010
U = isunrolled ? U : 1
11-
for u umin:U-1
12-
varname = varassignname(var, u, isunrolled)
13-
td = UnrollArgs(u, unrolled, tiled, suffix)
14-
push!(q.args, Expr(:(=), varname, Expr(:call, lv(:vload), ptr, mem_offset_u(op, td))))
11+
if instruction(op).instr !== :conditionalload
12+
for u umin:U-1
13+
varname = varassignname(var, u, isunrolled)
14+
td = UnrollArgs(u, unrolled, tiled, suffix)
15+
push!(q.args, Expr(:(=), varname, Expr(:call, lv(:vload), ptr, mem_offset_u(op, td))))
16+
end
17+
else
18+
condop = last(parents(op))
19+
condvar = variable_name(condop, suffix)
20+
condunrolled = any(isequal(unrolled), loopdependencies(condop))
21+
for u umin:U-1
22+
condsym = condunrolled ? Symbol(condvar, u) : condvar
23+
varname = varassignname(var, u, isunrolled)
24+
td = UnrollArgs(u, unrolled, tiled, suffix)
25+
load = Expr(:call, lv(:vload), ptr, mem_offset_u(op, td))
26+
cload = Expr(:if, condsym, load, Expr(:call, :zero, Expr(:call, :eltype, ptr)))
27+
push!(q.args, Expr(:(=), varname, cload))
28+
end
1529
end
1630
nothing
1731
end
18-
function pushvectorload!(q::Expr, op::Operation, var::Symbol, td::UnrollArgs, U::Int, W::Symbol, mask, vecnotunrolled::Bool)
19-
@unpack u, unrolled = td
32+
function pushvectorload!(
33+
q::Expr, op::Operation, var::Symbol, td::UnrollArgs, U::Int, W::Symbol, vectorized::Symbol, mask
34+
)
35+
@unpack u, unrolled, suffix = td
2036
ptr = refname(op)
37+
vecnotunrolled = vectorized !== unrolled
2138
name, mo = name_memoffset(var, op, td, W, vecnotunrolled)
2239
instrcall = Expr(:call, lv(:vload), ptr, mo)
23-
if mask !== nothing && (vecnotunrolled || u == U - 1)
40+
41+
iscondstore = instruction(op).instr === :conditionalload
42+
maskend = mask !== nothing && (vecnotunrolled || u == U - 1)
43+
if iscondstore
44+
condop = last(parents(op))
45+
# @show condop
46+
condsym = variable_name(condop, suffix)
47+
condsym = any(isequal(unrolled), loopdependencies(condop)) ? Symbol(condsym, u) : condsym
48+
if vectorized loopdependencies(condop)
49+
if maskend
50+
push!(instrcall.args, Expr(:call, :&, condsym, mask))
51+
else
52+
push!(instrcall.args, condsym)
53+
end
54+
else
55+
if maskend
56+
push!(instrcall.args, mask)
57+
end
58+
instrcall = Expr(:if, condsym, instrcall, Expr(:call, lv(:vzero), W, Expr(:call, :eltype, ptr)))
59+
end
60+
elseif maskend
2461
push!(instrcall.args, mask)
2562
end
2663
push!(q.args, Expr(:(=), name, instrcall))
@@ -40,10 +77,9 @@ function lower_load_vectorized!(
4077
end
4178
# Urange = unrolled ∈ loopdeps ? 0:U-1 : 0
4279
var = variable_name(op, suffix)
43-
vecnotunrolled = vectorized !== unrolled
4480
for u umin:U-1
4581
td = UnrollArgs(u, unrolled, tiled, suffix)
46-
pushvectorload!(q, op, var, td, U, W, mask, vecnotunrolled)
82+
pushvectorload!(q, op, var, td, U, W, vectorized, mask)
4783
end
4884
nothing
4985
end
@@ -73,6 +109,6 @@ function lower_load!(
73109
if vectorized loopdependencies(op)
74110
lower_load_vectorized!(q, op, vectorized, W, unrolled, tiled, U, suffix, mask, umin)
75111
else
76-
lower_load_scalar!(q, op, vectorized, W, unrolled, tiled, U, suffix, mask, umin)
112+
lower_load_scalar!(q, op, vectorized, W, unrolled, tiled, U, suffix, umin)
77113
end
78114
end

src/lower_store.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ end
4141
# const STOREOP = :vstore!
4242
variable_name(op::Operation, ::Nothing) = mangledvar(op)
4343
variable_name(op::Operation, suffix) = Symbol(mangledvar(op), suffix, :_)
44+
# variable_name(op::Operation, suffix, u::Int) = (n = variable_name(op, suffix); u < 0 ? n : Symbol(n, u))
4445
function reduce_range!(q::Expr, toreduct::Symbol, instr::Instruction, Uh::Int, Uh2::Int)
4546
for u 0:Uh-1
4647
tru = Symbol(toreduct, u)

src/memory_ops_common.jl

Lines changed: 37 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -29,17 +29,46 @@ function add_vptr!(ls::LoopSet, array::Symbol, vptrarray::Symbol = vptr(array),
2929
end
3030
nothing
3131
end
32-
function subset_vptr!(ls::LoopSet, vptr::Symbol, indnum::Int, ind::Union{Symbol,Int})
32+
33+
@inline valsum() = Val{0}()
34+
@inline valsum(::Val{M}) where {M} = Val{M}()
35+
@generated valsum(::Val{M}, ::Val{N}) where {M,N} = Val{M+N}()
36+
@inline valsum(::Val{M}, ::Val{N}, ::Val{K}, args...) where {M,N,K} = valsum(valsum(Val{M}(), Val{N}()), Val{K}(), args...)
37+
@inline valdims(::Any) = Val{1}()
38+
@inline valdims(::CartesianIndices{N}) where {N} = Val{N}()
39+
40+
function append_loop_valdims!(valcall::Expr, loop::Loop)
41+
if isstaticloop(loop)
42+
push!(valcall.args, :(Val{1}()))
43+
else
44+
push!(valcall.args, Expr(:call, lv(:valdims), loop_boundary(loop)))
45+
end
46+
nothing
47+
end
48+
function subset_vptr!(ls::LoopSet, vptr::Symbol, indnum::Int, ind, previndices, loopindex)
3349
subsetvptr = Symbol(vptr, "_subset_$(indnum)_with_$(ind)##")
34-
inde = ind isa Symbol ? Expr(:call, :-, ind, 1) : ind - 1
35-
pushpreamble!(ls, Expr(:(=), subsetvptr, Expr(:call, lv(:subsetview), vptr, Expr(:call, Expr(:curly, :Val, indnum)), inde)))
50+
valcall = Expr(:call, Expr(:curly, :Val, 1))
51+
if indnum > 1
52+
valcall = Expr(:call, lv(:valsum), valcall)
53+
for i 1:indnum-1
54+
if loopindex[i]
55+
append_loop_valdims!(valcall, getloop(ls, previndices[i]))
56+
else
57+
for loopdep loopdependencies(ls.opdict[previndices[i]])
58+
append_loop_valdims!(valcall, getloop(ls, loopdep))
59+
end
60+
end
61+
end
62+
end
63+
# @show valcall
64+
indm1 = ind isa Integer ? ind - 1 : Expr(:call, :-, ind, 1)
65+
pushpreamble!(ls, Expr(:(=), subsetvptr, Expr(:call, lv(:subsetview), vptr, valcall, indm1)))
3666
subsetvptr
3767
end
3868
const DISCONTIGUOUS = Symbol("##DISCONTIGUOUSSUBARRAY##")
3969
function array_reference_meta!(ls::LoopSet, array::Symbol, rawindices, elementbytes::Int, var::Union{Nothing,Symbol} = nothing)
4070
vptrarray = vptr(array)
4171
add_vptr!(ls, array, vptrarray) # now, subset
42-
4372
indices = Symbol[]
4473
loopedindex = Bool[]
4574
parents = Operation[]
@@ -49,7 +78,7 @@ function array_reference_meta!(ls::LoopSet, array::Symbol, rawindices, elementby
4978
ninds = 1
5079
for ind rawindices
5180
if ind isa Integer # subset
52-
vptrarray = subset_vptr!(ls, vptrarray, ninds, ind)
81+
vptrarray = subset_vptr!(ls, vptrarray, ninds, ind, indices, loopedindex)
5382
length(indices) == 0 && push!(indices, DISCONTIGUOUS)
5483
elseif ind isa Expr
5584
#FIXME: position (in loopnest) wont be length(ls.loopsymbols) in general
@@ -66,12 +95,11 @@ function array_reference_meta!(ls::LoopSet, array::Symbol, rawindices, elementby
6695
else
6796
indop = get(ls.opdict, ind, nothing)
6897
if indop !== nothing && !isconstant(indop)
69-
pushparent!(parents, loopdependencies, reduceddeps, parent) # FIXME where does `parent` come from?
70-
# var = get(ls.opdict, ind, nothing)
71-
push!(indices, name(parent)); ninds += 1
98+
pushparent!(parents, loopdependencies, reduceddeps, indop)
99+
push!(indices, name(indop)); ninds += 1
72100
push!(loopedindex, false)
73101
else
74-
vptrarray = subset_vptr!(ls, vptrarray, ninds, ind)
102+
vptrarray = subset_vptr!(ls, vptrarray, ninds, ind, indices, loopedindex)
75103
length(indices) == 0 && push!(indices, DISCONTIGUOUS)
76104
end
77105
end

0 commit comments

Comments
 (0)