Skip to content

Commit bf35864

Browse files
committed
Add CloseOpen to make working with zero-indexed arrays easier and allow avoiding some +1/-1s when using them (note it isn't used internally at the moment), refactor code related to interval definitions.
1 parent 18f0cd4 commit bf35864

10 files changed

+290
-184
lines changed

Project.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "LoopVectorization"
22
uuid = "bdcacae8-1622-11e9-2a5c-532679323890"
33
authors = ["Chris Elrod <[email protected]>"]
4-
version = "0.9.13"
4+
version = "0.9.14"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
@@ -14,13 +14,13 @@ UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
1414
VectorizationBase = "3d5dd08c-fd9d-11e8-17fa-ed2836048c2f"
1515

1616
[compat]
17-
ArrayInterface = "2.14.9"
17+
ArrayInterface = "2.14.10"
1818
DocStringExtensions = "0.8"
1919
IfElse = "0.1"
2020
OffsetArrays = "1.4.1"
2121
SLEEFPirates = "0.6"
2222
UnPack = "1"
23-
VectorizationBase = "0.14"
23+
VectorizationBase = "0.14.10"
2424
julia = "1.5"
2525

2626
[extras]

src/LoopVectorization.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ import LinearAlgebra # for check_args
3333
using Base.FastMath: add_fast, sub_fast, mul_fast, div_fast
3434

3535
using ArrayInterface
36-
using ArrayInterface: OptionallyStaticUnitRange, Zero, One
36+
using ArrayInterface: OptionallyStaticUnitRange, Zero, One, static_length
3737
const Static = ArrayInterface.StaticInt
3838

3939

@@ -52,6 +52,7 @@ const VECTORWIDTHSYMBOL, ELTYPESYMBOL = Symbol("##Wvecwidth##"), Symbol("##Tloop
5252

5353
include("vectorizationbase_compat/contract_pass.jl")
5454
include("vectorizationbase_compat/subsetview.jl")
55+
include("closeopen.jl")
5556
include("getconstindexes.jl")
5657
include("predicates.jl")
5758
include("map.jl")

src/broadcast.jl

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -102,14 +102,16 @@ function add_broadcast!(
102102
@nospecialize(prod::Type{<:Product}), elementbytes::Int
103103
)
104104
A, B = prod.parameters
105-
K = gensym!(ls, "K")
105+
Krange = gensym!(ls, "K")
106+
Klen = gensym!(ls, "K")
106107
mA = gensym!(ls, "Aₘₖ")
107108
mB = gensym!(ls, "Bₖₙ")
108109
pushprepreamble!(ls, Expr(:(=), mA, Expr(:(.), bcname, QuoteNode(:a))))
109110
pushprepreamble!(ls, Expr(:(=), mB, Expr(:(.), bcname, QuoteNode(:b))))
110-
pushprepreamble!(ls, Expr(:(=), K, Expr(:macrocall, Symbol("@inbounds"), LineNumberNode(@__LINE__,Symbol(@__FILE__)), Expr(:ref, Expr(:call, :size, mB), 1))))
111+
pushprepreamble!(ls, Expr(:(=), Klen, Expr(:macrocall, Symbol("@inbounds"), LineNumberNode(@__LINE__,Symbol(@__FILE__)), Expr(:ref, Expr(:call, :size, mB), 1))))
112+
pushpreamble!(ls, Expr(:(=), Krange, Expr(:call, :(:), staticexpr(1), Klen)))
111113
k = gensym!(ls, "k")
112-
add_loop!(ls, Loop(k, 1, K), k)
114+
add_loop!(ls, Loop(k, 1, Klen, Krange, Klen), k)
113115
m = loopsyms[1];
114116
if numdims(B) == 1
115117
bloopsyms = Symbol[k]
@@ -335,24 +337,34 @@ function add_broadcast!(
335337
pushop!(ls, op, destname)
336338
end
337339

340+
# Register one loop per entry of `loopsyms` on `ls`, with bounds taken from the
# axes of the array bound to `destsym`.
#
# The preamble receives a tuple-destructuring assignment
# `(Nrange_1, ..., Nrange_N) = axes(destsym)`, followed, per loop, by assignments
# of the lower bound, upper bound and length derived from that loop's range symbol.
# The `axes` assignment is pushed first so each range symbol is defined before the
# per-loop assignments that read it. Elements are appended to `axes_tuple` after it
# has been pushed, which works because the `Expr` is mutated in place.
function add_broadcast_loops!(ls::LoopSet, loopsyms::Vector{Symbol}, destsym::Symbol)
    axes_tuple = Expr(:tuple)
    pushpreamble!(ls, Expr(:(=), axes_tuple, Expr(:call, :axes, destsym)))
    for itersym in loopsyms
        Nrange = gensym!(ls, "N")
        Nlower = gensym!(ls, "N")
        Nupper = gensym!(ls, "N")
        Nlen = gensym!(ls, "N")
        add_loop!(ls, Loop(itersym, Nlower, Nupper, Nrange, Nlen), itersym)
        push!(axes_tuple.args, Nrange)
        pushpreamble!(ls, Expr(:(=), Nlower, Expr(:call, lv(:maybestaticfirst), Nrange)))
        pushpreamble!(ls, Expr(:(=), Nupper, Expr(:call, lv(:maybestaticlast), Nrange)))
        pushpreamble!(ls, Expr(:(=), Nlen, Expr(:call, lv(:static_length), Nrange)))
    end
end
338355
# size of dest determines loops
339356
# function vmaterialize!(
340357
@generated function vmaterialize!(
341358
dest::AbstractArray{T,N}, bc::BC, ::Val{Mod}
342359
) where {T <: NativeTypes, N, BC <: Union{Broadcasted,Product}, Mod}
360+
# 2+1
343361
# we have an N dimensional loop.
344362
# need to construct the LoopSet
345363
# @show typeof(dest)
346364
ls = LoopSet(Mod)
347365
loopsyms = [gensym!(ls, "n") for n ∈ 1:N]
348366
ls.isbroadcast[] = true
349-
sizes = Expr(:tuple)
350-
for (n,itersym) ∈ enumerate(loopsyms)
351-
Nsym = gensym!(ls, "N")
352-
add_loop!(ls, Loop(itersym, 1, Nsym), itersym)
353-
push!(sizes.args, Nsym)
354-
end
355-
pushpreamble!(ls, Expr(:(=), sizes, Expr(:call, :size, :dest)))
367+
add_broadcast_loops!(ls, loopsyms, :dest)
356368
elementbytes = sizeof(T)
357369
add_broadcast!(ls, :dest, :bc, loopsyms, BC, elementbytes)
358370
add_simple_store!(ls, :dest, ArrayReference(:dest, loopsyms), elementbytes)
@@ -381,13 +393,7 @@ end
381393
loopsyms = [gensym!(ls, "n") for n ∈ 1:N]
382394
ls.isbroadcast[] = true
383395
pushprepreamble!(ls, Expr(:(=), :dest, Expr(:call, :parent, :dest′)))
384-
sizes = Expr(:tuple)
385-
for (n,itersym) ∈ enumerate(loopsyms)
386-
Nsym = gensym!(ls, "N")
387-
add_loop!(ls, Loop(itersym, 1, Nsym), itersym)
388-
push!(sizes.args, Nsym)
389-
end
390-
pushpreamble!(ls, Expr(:(=), sizes, Expr(:call, :size, :dest′)))
396+
add_broadcast_loops!(ls, loopsyms, :dest′)
391397
elementbytes = sizeof(T)
392398
add_broadcast!(ls, :dest, :bc, loopsyms, BC, elementbytes)
393399
add_simple_store!(ls, :dest, ArrayReference(:dest, reverse(loopsyms)), elementbytes)

src/closeopen.jl

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
2+
struct CloseOpen{L <: Union{Int,StaticInt}, U <: Union{Int,StaticInt}} <: AbstractUnitRange{Int}
3+
start::L
4+
upper::U
5+
@inline CloseOpen(s::StaticInt{L}, u::StaticInt{U}) where {L,U} = new{StaticInt{L},StaticInt{U}}(s, u)
6+
@inline CloseOpen(s::Integer, u::StaticInt{U}) where {U} = new{Int,StaticInt{U}}(s % Int, u)
7+
@inline CloseOpen(s::StaticInt{L}, u::Integer) where {L} = new{StaticInt{L},Int}(s, u % Int)
8+
@inline CloseOpen(s::Integer, u::Integer) = new{Int,Int}(s % Int, u % Int)
9+
end
10+
@inline CloseOpen(len::Integer) = CloseOpen(Zero(), len)
11+
12+
@inline Base.first(r::CloseOpen) = r.start
13+
@inline Base.step(::CloseOpen) = One()
14+
@inline Base.last(r::CloseOpen) = r.upper - One()
15+
@inline Base.length(r::CloseOpen) = r.upper - r.start
16+
@inline Base.length(r::CloseOpen{Zero}) = r.upper
17+
18+
# Begin iteration at `first(r)`. The interval is half-open, so when `start >= upper`
# the range is empty and iteration ends immediately instead of yielding `start`.
@inline Base.iterate(r::CloseOpen) = (i = Int(first(r)); i < Int(r.upper) ? (i, i) : nothing)
19+
# Advance the state by one and stop once it reaches the excluded upper bound.
# `>=` rather than `==` guarantees termination even if the state is already past `upper`.
@inline Base.iterate(r::CloseOpen, i::Int) = (i += 1) >= Int(r.upper) ? nothing : (i, i)
20+
21+
ArrayInterface.known_first(::Type{<:CloseOpen{StaticInt{F}}}) where {F} = F
22+
ArrayInterface.known_step(::Type{<:CloseOpen}) = 1
23+
ArrayInterface.known_last(::Type{<:CloseOpen{<:Any,StaticInt{L}}}) where {L} = L - 1
24+
ArrayInterface.known_length(::Type{CloseOpen{StaticInt{F},StaticInt{L}}}) where {F,L} = L - F
25+
26+
@inline canonicalize_range(r::OptionallyStaticUnitRange) = r
27+
@inline canonicalize_range(r::CloseOpen) = r
28+
@inline canonicalize_range(r::AbstractUnitRange) = maybestaticfirst(r):maybestaticlast(r)
29+
@inline canonicalize_range(r::CartesianIndices) = CartesianIndices(map(canonicalize_range, r.indices))
30+
31+

src/condense_loopset.jl

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -122,23 +122,26 @@ function OperationStruct!(varnames::Vector{Symbol}, ids::Vector{Int}, ls::LoopSe
122122
end
123123
## turn a LoopSet into a type object which can be used to reconstruct the LoopSet.
124124

125-
function loop_boundary(loop::Loop)
126-
startexact = loop.startexact
127-
stopexact = loop.stopexact
128-
if startexact & stopexact
129-
Expr(:call, lv(:OptionallyStaticUnitRange), staticexpr(loop.starthint), staticexpr(loop.stophint))
130-
elseif startexact
131-
Expr(:call, lv(:OptionallyStaticUnitRange), staticexpr(loop.starthint), loop.stopsym)
132-
elseif stopexact
133-
Expr(:call, lv(:OptionallyStaticUnitRange), loop.startsym, staticexpr(loop.stophint))
125+
# Append to `q.args` the expression describing the iteration range of `loop`.
#
# - Both bounds exact: an `OptionallyStaticUnitRange` built from the two hints as
#   static expressions.
# - `loop.rangesym` is the empty symbol: the range is assembled from the start/stop
#   information, using a static expression for whichever bound is exact and the
#   bound's symbol otherwise (a plain `start:stop` call when neither is exact).
# - Otherwise: the range symbol itself is pushed.
#
# The exactness flags are read from `loop`; there are no `startexact`/`stopexact`
# locals in this function.
function loop_boundary!(q::Expr, loop::Loop)
    if loop.startexact & loop.stopexact
        push!(q.args, Expr(:call, lv(:OptionallyStaticUnitRange), staticexpr(loop.starthint), staticexpr(loop.stophint)))
    elseif loop.rangesym === Symbol("")
        lb = if loop.startexact
            Expr(:call, lv(:OptionallyStaticUnitRange), staticexpr(loop.starthint), loop.stopsym)
        elseif loop.stopexact
            Expr(:call, lv(:OptionallyStaticUnitRange), loop.startsym, staticexpr(loop.stophint))
        else
            Expr(:call, :(:), loop.startsym, loop.stopsym)
        end
        push!(q.args, lb)
    else
        push!(q.args, loop.rangesym)
    end
end
138141

139142
function loop_boundaries(ls::LoopSet)
140143
lbd = Expr(:tuple)
141-
foreach(loop -> push!(lbd.args, loop_boundary(loop)), ls.loops)
144+
foreach(loop -> loop_boundary!(lbd, loop), ls.loops)
142145
lbd
143146
end
144147

0 commit comments

Comments
 (0)