Skip to content

Commit cf2df70

Browse files
committed
Tests pass locally for d=5 mixed integer-Cartesian indexing.
1 parent f702c4a commit cf2df70

File tree

5 files changed

+98
-46
lines changed

5 files changed

+98
-46
lines changed

.travis.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ notifications:
1212
after_success:
1313
- julia -e 'using Pkg; Pkg.add("Coverage"); using Coverage; Codecov.submit(process_folder())'
1414
jobs:
15+
allow_failures:
16+
- julia: nightly
1517
include:
1618
- stage: "Documentation"
1719
julia: 1.3

src/condense_loopset.jl

Lines changed: 30 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -113,23 +113,23 @@ function OperationStruct!(varnames::Vector{Symbol}, ls::LoopSet, op::Operation)
113113
end
114114
## turn a LoopSet into a type object which can be used to reconstruct the LoopSet.
115115

116+
function loop_boundary(loop::Loop)
117+
startexact = loop.startexact
118+
stopexact = loop.stopexact
119+
if startexact & stopexact
120+
Expr(:call, Expr(:curly, lv(:StaticUnitRange), loop.starthint, loop.stophint))
121+
elseif startexact
122+
Expr(:call, Expr(:curly, lv(:StaticLowerUnitRange), loop.starthint), loop.stopsym)
123+
elseif stopexact
124+
Expr(:call, Expr(:curly, lv(:StaticUpperUnitRange), loop.stophint), loop.startsym)
125+
else
126+
Expr(:call, :(:), loop.startsym, loop.stopsym)
127+
end
128+
end
116129

117130
function loop_boundaries(ls::LoopSet)
118131
lbd = Expr(:tuple)
119-
for loop ls.loops
120-
startexact = loop.startexact
121-
stopexact = loop.stopexact
122-
lexpr = if startexact & stopexact
123-
Expr(:call, Expr(:curly, lv(:StaticUnitRange), loop.starthint, loop.stophint))
124-
elseif startexact
125-
Expr(:call, Expr(:curly, lv(:StaticLowerUnitRange), loop.starthint), loop.stopsym)
126-
elseif stopexact
127-
Expr(:call, Expr(:curly, lv(:StaticUpperUnitRange), loop.stophint), loop.startsym)
128-
else
129-
Expr(:call, :(:), loop.startsym, loop.stopsym)
130-
end
131-
push!(lbd.args, lexpr)
132-
end
132+
foreach(loop -> push!(lbd.args, loop_boundary(loop)), ls.loops)
133133
lbd
134134
end
135135

@@ -204,6 +204,7 @@ end
204204
::Val{UT}, ::Type{OPS}, ::Type{ARF}, ::Type{AM}, ::Type{LPSYM}, lb::LB,
205205
::Val{AR}, ::Val{D}, ::Val{IND}, subsetvals, arraydescript, vargs::Vararg{<:Any,N}
206206
) where {UT, OPS, ARF, AM, LPSYM, LB, N, AR, D, IND}
207+
1 + 1
207208
num_vptrs = length(ARF.parameters)::Int
208209
vptrs = [gensym(:vptr) for _ 1:num_vptrs]
209210
call = Expr(:call, lv(:_avx_!), Val{UT}(), OPS, ARF, AM, LPSYM, :lb)
@@ -279,14 +280,23 @@ function generate_call(ls::LoopSet, IUT, debug::Bool = false)
279280
add_external_functions!(q, ls)
280281
q
281282
end
282-
283+
concat_vals() = Val{()}()
284+
# @generated concat_vals(::Val{N}) where {N} = Val{(N,)}()
285+
# @generated concat_vals(::Val{M}, ::Val{N}) where {M, N} = Val{(M,N)}()
286+
@generated function concat_vals(args...)
287+
tup = Expr(:tuple)
288+
for n in eachindex(args)
289+
push!(tup.args, args[n].parameters[1])
290+
end
291+
Expr(:call, Expr(:curly, :Val, tup))
292+
end
283293
function setup_call_noinline(ls::LoopSet, U = zero(Int8), T = zero(Int8))
284294
call = generate_call(ls, (false,U,T))
285295
hasouterreductions = length(ls.outer_reductions) > 0
286296
q = Expr(:block)
287297
vptrarrays = Expr(:tuple)
288298
vptrsubsetvals = Expr(:tuple)
289-
vptrsubsetdims = Expr(:tuple)
299+
vptrsubsetdims = Expr(:call, lv(:concat_vals))
290300
vptrindices = Expr(:tuple)
291301
stridedpointerLHS = Symbol[]
292302
loopvalueLHS = Symbol[]
@@ -304,7 +314,7 @@ function setup_call_noinline(ls::LoopSet, U = zero(Int8), T = zero(Int8))
304314
@assert array loopvalueLHS
305315
push!(vptrarrays.args, -1)
306316
end
307-
push!(vptrsubsetdims.args, nothing)
317+
push!(vptrsubsetdims.args, Expr(:call, Expr(:curly, :Val, nothing)))
308318
vp = first(ex.args)::Symbol
309319
push!(stridedpointerLHS, vp)
310320
push!(vptrindices.args, findfirst(a -> vptr(a) == vp, ls.refs_aliasing_syms))
@@ -316,7 +326,7 @@ function setup_call_noinline(ls::LoopSet, U = zero(Int8), T = zero(Int8))
316326
@assert vptrarrayid isa Int
317327
end
318328
push!(vptrarrays.args, vptrarrayid::Int)
319-
push!(vptrsubsetdims.args, ex.args[2].args[3].args[1].args[2])
329+
push!(vptrsubsetdims.args, ex.args[2].args[3])#.args[1].args[2])
320330
push!(vptrsubsetvals.args, ex.args[2].args[4])
321331
vp = first(ex.args)::Symbol
322332
push!(stridedpointerLHS, vp)
@@ -327,7 +337,8 @@ function setup_call_noinline(ls::LoopSet, U = zero(Int8), T = zero(Int8))
327337
push!(q.args, ex)
328338
end
329339
insert!(call.args, 8, Expr(:call, Expr(:curly, :Val, vptrarrays)))
330-
insert!(call.args, 9, Expr(:call, Expr(:curly, :Val, vptrsubsetdims)))
340+
# insert!(call.args, 9, Expr(:call, Expr(:curly, :Val, vptrsubsetdims)))
341+
insert!(call.args, 9, vptrsubsetdims)
331342
insert!(call.args, 10, Expr(:call, Expr(:curly, :Val, vptrindices)))
332343
insert!(call.args, 11, vptrsubsetvals)
333344
if hasouterreductions

src/memory_ops_common.jl

Lines changed: 37 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -29,17 +29,46 @@ function add_vptr!(ls::LoopSet, array::Symbol, vptrarray::Symbol = vptr(array),
2929
end
3030
nothing
3131
end
32-
function subset_vptr!(ls::LoopSet, vptr::Symbol, indnum::Int, ind::Union{Symbol,Int})
32+
33+
@inline valsum() = Val{0}()
34+
@inline valsum(::Val{M}) where {M} = Val{M}()
35+
@generated valsum(::Val{M}, ::Val{N}) where {M,N} = Val{M+N}()
36+
@inline valsum(::Val{M}, ::Val{N}, ::Val{K}, args...) where {M,N,K} = valsum(valsum(Val{M}(), Val{N}()), Val{K}(), args...)
37+
@inline valdims(::Any) = Val{1}()
38+
@inline valdims(::CartesianIndices{N}) where {N} = Val{N}()
39+
40+
function append_loop_valdims!(valcall::Expr, loop::Loop)
41+
if isstaticloop(loop)
42+
push!(valcall.args, :(Val{1}()))
43+
else
44+
push!(valcall.args, Expr(:call, lv(:valdims), loop_boundary(loop)))
45+
end
46+
nothing
47+
end
48+
function subset_vptr!(ls::LoopSet, vptr::Symbol, indnum::Int, ind, previndices, loopindex)
3349
subsetvptr = Symbol(vptr, "_subset_$(indnum)_with_$(ind)##")
34-
inde = ind isa Symbol ? Expr(:call, :-, ind, 1) : ind - 1
35-
pushpreamble!(ls, Expr(:(=), subsetvptr, Expr(:call, lv(:subsetview), vptr, Expr(:call, Expr(:curly, :Val, indnum)), inde)))
50+
valcall = Expr(:call, Expr(:curly, :Val, 1))
51+
if indnum > 1
52+
valcall = Expr(:call, lv(:valsum), valcall)
53+
for i 1:indnum-1
54+
if loopindex[i]
55+
append_loop_valdims!(valcall, getloop(ls, previndices[i]))
56+
else
57+
for loopdep loopdependencies(ls.opdict[previndices[i]])
58+
append_loop_valdims!(valcall, getloop(ls, loopdep))
59+
end
60+
end
61+
end
62+
end
63+
# @show valcall
64+
indm1 = ind isa Integer ? ind - 1 : Expr(:call, :-, ind, 1)
65+
pushpreamble!(ls, Expr(:(=), subsetvptr, Expr(:call, lv(:subsetview), vptr, valcall, indm1)))
3666
subsetvptr
3767
end
3868
const DISCONTIGUOUS = Symbol("##DISCONTIGUOUSSUBARRAY##")
3969
function array_reference_meta!(ls::LoopSet, array::Symbol, rawindices, elementbytes::Int)
4070
vptrarray = vptr(array)
4171
add_vptr!(ls, array, vptrarray) # now, subset
42-
4372
indices = Symbol[]
4473
loopedindex = Bool[]
4574
parents = Operation[]
@@ -49,7 +78,7 @@ function array_reference_meta!(ls::LoopSet, array::Symbol, rawindices, elementby
4978
ninds = 1
5079
for ind rawindices
5180
if ind isa Integer # subset
52-
vptrarray = subset_vptr!(ls, vptrarray, ninds, ind)
81+
vptrarray = subset_vptr!(ls, vptrarray, ninds, ind, indices, loopedindex)
5382
length(indices) == 0 && push!(indices, DISCONTIGUOUS)
5483
elseif ind isa Expr
5584
#FIXME: position (in loopnest) wont be length(ls.loopsymbols) in general
@@ -66,12 +95,11 @@ function array_reference_meta!(ls::LoopSet, array::Symbol, rawindices, elementby
6695
else
6796
indop = get(ls.opdict, ind, nothing)
6897
if indop !== nothing && !isconstant(indop)
69-
pushparent!(parents, loopdependencies, reduceddeps, parent) # FIXME where does `parent` come from?
70-
# var = get(ls.opdict, ind, nothing)
71-
push!(indices, name(parent)); ninds += 1
98+
pushparent!(parents, loopdependencies, reduceddeps, indop)
99+
push!(indices, name(indop)); ninds += 1
72100
push!(loopedindex, false)
73101
else
74-
vptrarray = subset_vptr!(ls, vptrarray, ninds, ind)
102+
vptrarray = subset_vptr!(ls, vptrarray, ninds, ind, indices, loopedindex)
75103
length(indices) == 0 && push!(indices, DISCONTIGUOUS)
76104
end
77105
end

src/reconstruct_loopset.jl

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
const NOpsType = Union{Int,Vector{Int}}
1+
const NOpsType = Int#Union{Int,Vector{Int}}
22

33
function Loop(ls::LoopSet, ex::Expr, sym::Symbol, ::Type{<:AbstractUnitRange})
44
ssym = String(sym)
@@ -239,11 +239,12 @@ function calcnops(ls::LoopSet, os::OperationStruct)
239239
offsets = ls.loopsymbol_offsets
240240
idxs = loopindex(ls, os.loopdeps, 0x04) # FIXME DRY
241241
iszero(length(idxs)) && return 1
242-
return map(i->offsets[i+1]-offsets[i], idxs)
242+
return maximum(i->offsets[i+1]-offsets[i], idxs)
243243
end
244244
function isexpanded(ls::LoopSet, ops::Vector{OperationStruct}, nopsv::Vector{NOpsType}, i::Int)
245245
nops = nopsv[i]
246-
(nops === 1 || nops == [1]) && return false
246+
# nops isa Vector{Int} only if accesses_memory(os), which means isexpanded must be false
247+
(nops === 1 || isa(nops, Vector{Int})) && return false
247248
os = ops[i]
248249
optyp = optype(os)
249250
if optyp == compute
@@ -260,7 +261,6 @@ function add_op!(
260261
mrefs::Vector{ArrayReferenceMeta}, opsymbol, elementbytes::Int
261262
)
262263
os = ops[i]
263-
nops = nopsv[i]
264264
# opsymbol = (isconstant(os) && instr != LOOPCONSTANT) ? instr.instr : opsymbol
265265
# If it's a CartesianIndex add or subtract, we may have to add multiple operations
266266
expanded = expandedv[i]# isexpanded(ls, ops, nopsv, i)
@@ -278,12 +278,7 @@ function add_op!(
278278
push!(opoffsets, opoffsets[end] + 1)
279279
return
280280
end
281-
if isa(nops, Vector)
282-
n = first(nops)
283-
if all(isequal(n), nops)
284-
nops = n
285-
end
286-
end
281+
nops = (nopsv[i])::Int # if it were a vector, it would have to have been expanded
287282
# if expanded, optyp must be either loopvalue, or compute (with loopvalues in its ancestry, not cutoff by loads)
288283
for offset = 0:nops-1
289284
sym = nops === 1 ? opsymbol : expandedopname(opsymbol, offset)

test/miscellaneous.jl

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -680,30 +680,46 @@ using Test
680680
function smoothdim_avx!(s, x, α, Rpre, irng::AbstractUnitRange, Rpost)
681681
ifirst, ilast = first(irng), last(irng)
682682
ifirst > ilast && return s
683-
@avx for Ipost in Rpost
684-
# Initialize the first value along the filtered dimension
683+
@avx tile=(1,1) for Ipost in Rpost
684+
# Handle all other entries
685685
for Ipre in Rpre
686686
s[Ipre, ifirst, Ipost] = x[Ipre, ifirst, Ipost]
687+
for i = ifirst+1:ilast
688+
s[Ipre, i, Ipost] = α*x[Ipre, i, Ipost] + (1-α)*x[Ipre, i-1, Ipost]
689+
end
687690
end
691+
end
692+
s
693+
end
694+
function smoothdim_ifelse_avx!(s, x, α, Rpre, irng::AbstractUnitRange, Rpost)
695+
ifirst, ilast = first(irng), last(irng)
696+
ifirst > ilast && return s
697+
@avx tile=(1,1) for Ipost in Rpost
688698
# Handle all other entries
689-
for i = ifirst+1:ilast
699+
for i = ifirst:ilast
690700
for Ipre in Rpre
691-
s[Ipre, i, Ipost] = α*x[Ipre, i, Ipost] + (1-α)*x[Ipre, i-1, Ipost]
701+
xi = x[Ipre, i, Ipost]
702+
xim = ifelse(i == ifirst, xi, x[Ipre, i-1, Ipost])
703+
s[Ipre, i, Ipost] = α*xi + (1-α)*xim
692704
end
693705
end
694706
end
695707
s
696708
end
697709

698-
x = rand(11,11,11) # ,11,11)
699-
dest1, dest2 = similar(x), similar(x)
710+
x = rand(11,11,11,11,11);
711+
dest1, dest2 = similar(x), similar(x);
700712
α = 0.3
701713
for d = 1:ndims(x)
702-
Rpre = CartesianIndices(axes(x)[1:d-1])
703-
Rpost = CartesianIndices(axes(x)[d+1:end])
714+
# @show d
715+
Rpre = CartesianIndices(axes(x)[1:d-1]);
716+
Rpost = CartesianIndices(axes(x)[d+1:end]);
704717
smoothdim!(dest1, x, α, Rpre, axes(x, d), Rpost)
705718
smoothdim_avx!(dest2, x, α, Rpre, axes(x, d), Rpost)
706719
@test dest1 dest2
720+
fill!(dest2, NaN); smoothdim_ifelse_avx!(dest2, x, α, Rpre, axes(x, d), Rpost)
721+
@test dest1 dest2
707722
end
708723
end
709724
end
725+

0 commit comments

Comments
 (0)