Skip to content

Commit 1f9e073

Browse files
committed
Some more changes for 1.6...
1 parent bf688f8 commit 1f9e073

14 files changed

+162
-58
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
1414
VectorizationBase = "3d5dd08c-fd9d-11e8-17fa-ed2836048c2f"
1515

1616
[compat]
17-
ArrayInterface = "2.13.7"
17+
ArrayInterface = "2.13.8"
1818
DocStringExtensions = "0.8"
1919
IfElse = "0"
2020
OffsetArrays = "1"

src/LoopVectorization.jl

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,9 @@ using VectorizationBase: REGISTER_SIZE, REGISTER_COUNT, data,
1111
Zero, maybestaticrange, offsetprecalc,
1212
maybestaticfirst, maybestaticlast, scalar_less, gep, gesp, pointerforcomparison, NativeTypes,
1313
vfmadd, vfmsub, vfnmadd, vfnmsub, vfmadd231, vfmsub231, vfnmadd231, vfnmsub231, vadd, vsub, vmul,
14-
relu, stridedpointer, StridedPointer,
15-
reduced_add, reduced_prod, reduce_to_add, reduce_to_prod, reduced_max, reduced_min, reduce_to_max, reduce_to_min
14+
relu, stridedpointer, StridedPointer, AbstractStridedPointer,
15+
reduced_add, reduced_prod, reduce_to_add, reduce_to_prod, reduced_max, reduced_min, reduce_to_max, reduce_to_min,
16+
vsum, vprod, vmaximum, vminimum
1617

1718
using IfElse: ifelse
1819

@@ -35,13 +36,19 @@ using ArrayInterface
3536
using ArrayInterface: OptionallyStaticUnitRange, Zero
3637
const Static = ArrayInterface.StaticInt
3738

39+
# TODO: this is type piracy; move this elsewhere!
40+
VectorizationBase.memory_reference(A::OffsetArray) = VectorizationBase.memory_reference(parent(A))
41+
# ArrayInterface.parent_type(::Type{O}) where {T,N,A<:AbstractArray{T,N},O<:OffsetArray{T,N,A}} = A
42+
43+
3844
export LowDimArray, stridedpointer,
3945
@avx, @_avx, *ˡ, _avx_!,
4046
vmap, vmap!, vmapt, vmapt!, vmapnt, vmapnt!, vmapntt, vmapntt!,
4147
vfilter, vfilter!, vmapreduce, vreduce
4248

4349
@inline unwrap(::Val{N}) where {N} = N
4450
@inline unwrap(::Static{N}) where {N} = N
51+
@inline unwrap(x) = x
4552

4653
const VECTORWIDTHSYMBOL, ELTYPESYMBOL = Symbol("##Wvecwidth##"), Symbol("##Tloopeltype##")
4754

src/add_compute.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -340,22 +340,22 @@ function add_pow!(
340340
elseif pint == 1
341341
return add_compute!(ls, var, :identity, [xop], elementbytes)
342342
elseif pint == 2
343-
return add_compute!(ls, var, :vabs2, [xop], elementbytes)
343+
return add_compute!(ls, var, :abs2, [xop], elementbytes)
344344
end
345345

346346
# Implementation from https://github.com/JuliaLang/julia/blob/a965580ba7fd0e8314001521df254e30d686afbf/base/intfuncs.jl#L216
347347
t = trailing_zeros(pint) + 1
348348
pint >>= t
349349
while (t -= 1) > 0
350350
varname = (iszero(pint) && isone(t)) ? var : gensym(:pbs)
351-
xop = add_compute!(ls, varname, :vabs2, [xop], elementbytes)
351+
xop = add_compute!(ls, varname, :abs2, [xop], elementbytes)
352352
end
353353
yop = xop
354354
while pint > 0
355355
t = trailing_zeros(pint) + 1
356356
pint >>= t
357357
while (t -= 1) >= 0
358-
xop = add_compute!(ls, gensym(:pbs), :vabs2, [xop], elementbytes)
358+
xop = add_compute!(ls, gensym(:pbs), :abs2, [xop], elementbytes)
359359
end
360360
yop = add_compute!(ls, iszero(pint) ? var : gensym(:pbs), :vmul, [xop, yop], elementbytes)
361361
end

src/broadcast.jl

Lines changed: 63 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,34 @@
1+
2+
@inline stridedpointer_for_broadcast(A) = stridedpointer_for_broadcast(ArrayInterface.size(A), stridedpointer(A))
3+
@inline stridedpointer_for_broadcast(s, ptr) = ptr
4+
function stridedpointer_for_broadcast(s, ptr::VectorizationBase.AbstractStridedPointer)
5+
# FIXME: this is unsafe for AbstractStridedPointers
6+
throw("Broadcasting not currently supported for arrays where typeof(stridedpointer(A)) === $(typeof(ptr))")
7+
end
8+
@generated function stridedpointer_for_broadcast(s::Tuple{Vararg{Any,N}}, ptr::StridedPointer{T,N,C,B,R,X,O}) where {T,N,C,B,R,X,O}
9+
q = Expr(:block, Expr(:meta,:inline), :(strd = ptr.strd))
10+
strd_tup = Expr(:tuple)
11+
for n ∈ 1:N
12+
s_type = s.parameters[n]
13+
if s_type <: Static
14+
if s_type === Static{1}
15+
push!(strd_tup.args, Expr(:call, lv(:Zero)))
16+
else
17+
push!(strd_tup.args, :(strd[$n]))
18+
end
19+
else
20+
Xₙ_type = X.parameters[n]
21+
if Xₙ_type <: Static # FIXME; what to do here? Dynamic dispatch?
22+
push!(strd_tup.args, :(strd[$n]))
23+
else
24+
push!(strd_tup.args, :(Base.ifelse(isone(s[$n]), one($Xₙ_type), strd[$n])))
25+
end
26+
end
27+
end
28+
push!(q.args, :(@inbounds StridedPointer{$T,$N,$C,$B,$R}(ptr.p, $strd_tup, ptr.offsets)))
29+
q
30+
end
31+
132
struct Product{A,B}
233
a::A
334
b::B
@@ -114,25 +145,44 @@ end
114145
struct LowDimArray{D,T,N,A<:DenseArray{T,N}} <: DenseArray{T,N}
115146
data::A
116147
end
117-
@inline Base.pointer(A::LowDimArray) = pointer(A.data)
118148
Base.@propagate_inbounds Base.getindex(A::LowDimArray, i...) = getindex(A.data, i...)
119149
@inline Base.size(A::LowDimArray) = Base.size(A.data)
120150
@inline Base.size(A::LowDimArray, i) = Base.size(A.data, i)
121-
@generated function VectorizationBase.stridedpointer(A::LowDimArray{D,T,N}) where {D,T,N}
122-
smul = Expr(:(.), Expr(:(.), :LoopVectorization, QuoteNode(:VectorizationBase)), QuoteNode(:staticmul))
123-
multup = Expr(:tuple)
124-
for n ∈ D[1]+1:N
125-
if length(D) < n
126-
push!(multup.args, Expr(:call, :ifelse, :(isone(size(A,$n))), 0, Expr(:ref, :strideA, n)))
127-
elseif D[n]
128-
push!(multup.args, Expr(:ref, :strideA, n))
151+
@inline Base.strides(A::LowDimArray) = strides(A.data)
152+
@inline ArrayInterface.strides(A::LowDimArray) = ArrayInterface.strides(A.data)
153+
@generated function ArrayInterface.size(A::LowDimArray{D,T,N}) where {D,T,N}
154+
t = Expr(:tuple)
155+
for n ∈ 1:N
156+
if D[n]
157+
push!(t.args, Expr(:ref, :s, n))
158+
else
159+
push!(t.args, Expr(:call, Expr(:curly, lv(:Static), 1)))
129160
end
130161
end
131-
s = Expr(:call, smul, T, multup)
132-
f = D[1] ? :PackedStridedPointer : :SparseStridedPointer
133-
Expr(:block, Expr(:meta,:inline), Expr(:(=), :strideA, Expr(:call, :strides, Expr(:(.), :A, QuoteNode(:data)))),
134-
Expr(:call, Expr(:(.), :VectorizationBase, QuoteNode(f)), Expr(:call, :pointer, :A), s))
162+
Expr(:block, Expr(:meta,:inline), :(s = size(A)), t)
135163
end
164+
# Expose the wrapped array as the parent; `data` is the only field of `LowDimArray`.
Base.parent(A::LowDimArray) = A.data
165+
Base.unsafe_convert(::Type{Ptr{T}}, A::LowDimArray{D,T}) where {D,T} = pointer(A.data)
166+
ArrayInterface.contiguous_axis(A::LowDimArray) = ArrayInterface.contiguous_axis(A.data)
167+
ArrayInterface.contiguous_batch_size(A::LowDimArray) = ArrayInterface.contiguous_batch_size(A.data)
168+
ArrayInterface.stride_rank(A::LowDimArray) = ArrayInterface.stride_rank(A.data)
169+
ArrayInterface.offsets(A::LowDimArray) = ArrayInterface.offsets(A.data)
170+
171+
# @generated function VectorizationBase.stridedpointer(A::LowDimArray{D,T,N}) where {D,T,N}
172+
# smul = Expr(:(.), Expr(:(.), :LoopVectorization, QuoteNode(:VectorizationBase)), QuoteNode(:staticmul))
173+
# multup = Expr(:tuple)
174+
# for n ∈ D[1]+1:N
175+
# if length(D) < n
176+
# push!(multup.args, Expr(:call, :ifelse, :(isone(size(A,$n))), 0, Expr(:ref, :strideA, n)))
177+
# elseif D[n]
178+
# push!(multup.args, Expr(:ref, :strideA, n))
179+
# end
180+
# end
181+
# s = Expr(:call, smul, T, multup)
182+
# f = D[1] ? :PackedStridedPointer : :SparseStridedPointer
183+
# Expr(:block, Expr(:meta,:inline), Expr(:(=), :strideA, Expr(:call, :strides, Expr(:(.), :A, QuoteNode(:data)))),
184+
# Expr(:call, Expr(:(.), :VectorizationBase, QuoteNode(f)), Expr(:call, :pointer, :A), s))
185+
# end
136186
function LowDimArray{D}(data::A) where {D,T,N,A <: AbstractArray{T,N}}
137187
LowDimArray{D,T,N,A}(data)
138188
end

src/costs.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ const COST = Dict{Symbol,InstructionCost}(
152152
:reduce_to_prod => InstructionCost(0,0.0,0.0,0),
153153
:abs => InstructionCost(1, 0.5),
154154
:abs2 => InstructionCost(4,0.5),
155-
:vabs2 => InstructionCost(4,0.5),
155+
# :vabs2 => InstructionCost(4,0.5),
156156
:(==) => InstructionCost(1, 0.5),
157157
:(!=) => InstructionCost(1, 0.5),
158158
:(isnan) => InstructionCost(1, 0.5),

src/filter.jl

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,20 +6,21 @@ if (Base.libllvm_version ≥ v"7" && VectorizationBase.AVX512F) || Base.libllvm_
66
Nrep = N >>> Wshift
77
Nrem = N & (W - 1)
88
j = 0
9+
st = VectorizationBase.static_sizeof(T)
910
GC.@preserve x y begin
1011
ptr_x = pointer(x)
1112
ptr_y = pointer(y)
1213
for _ ∈ 1:Nrep
13-
vy = vload(Vec{W,T}, ptr_y)
14-
mask = f(Vec(vy))
15-
VectorizationBase.compressstore!(gep(ptr_x, j), vy, mask)
16-
ptr_y = gepbyte(ptr_y, VectorizationBase.REGISTER_SIZE)
14+
vy = vload(ptr_y, MM{W}(Static(0), st))
15+
mask = f(vy)
16+
VectorizationBase.compressstore!(gep(ptr_x, VectorizationBase.lazymul(st, j)), vy, mask)
17+
ptr_y = gep(ptr_y, VectorizationBase.REGISTER_SIZE)
1718
j = vadd(j, count_ones(mask))
1819
end
1920
rem_mask = VectorizationBase.mask(T, Nrem)
20-
vy = vload(Vec{W,T}, ptr_y, rem_mask)
21-
mask = rem_mask & f(Vec(vy))
22-
VectorizationBase.compressstore!(gep(ptr_x, j), vy, mask)
21+
vy = vload(ptr_y, MM{W}(Static(0), st), rem_mask)
22+
mask = rem_mask & f(vy)
23+
VectorizationBase.compressstore!(gep(ptr_x, VectorizationBase.lazymul(st, j)), vy, mask)
2324
j = vadd(j, count_ones(mask))
2425
Base._deleteend!(x, N-j) # resize!(x, j)
2526
end

src/loopstartstopmanager.jl

Lines changed: 40 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,10 @@ function indices_calculated_by_pointer_offsets(ls::LoopSet, ar::ArrayReferenceMe
5151
out
5252
end
5353

54+
@inline onetozeroindexgephack(sptr::AbstractStridedPointer) = gesp(sptr, (Static{-1}(),)) # go backwards
55+
@inline onetozeroindexgephack(sptr::AbstractStridedPointer{T,1}) where {T} = sptr
56+
@inline onetozeroindexgephack(x) = x
57+
5458
"""
5559
Returns a vector of length equal to the number of indices.
5660
A value > 0 indicates which loop number that index corresponds to when incrementing the pointer.
@@ -82,7 +86,8 @@ function use_loop_induct_var!(ls::LoopSet, q::Expr, ar::ArrayReferenceMeta, alla
8286
# else
8387
if (!li[i])
8488
uliv[i] = 0
85-
push!(gespinds.args, Expr(:call, lv(:Zero)))
89+
# push!(gespinds.args, Expr(:call, lv(:Zero)))
90+
push!(gespinds.args, Expr(:call, Expr(:curly, lv(:Static), 1)))
8691
push!(offsetprecalc_descript.args, 0)
8792
elseif isbroadcast ||
8893
((isone(ii) && (last(looporder) === ind)) && !(otherindexunrolled(ls, ind, ar)) ||
@@ -92,16 +97,23 @@ function use_loop_induct_var!(ls::LoopSet, q::Expr, ar::ArrayReferenceMeta, alla
9297

9398
# Not doing normal offset indexing
9499
uliv[i] = -findfirst(isequal(ind), looporder)::Int
95-
push!(gespinds.args, Expr(:call, lv(:Zero)))
100+
# push!(gespinds.args, Expr(:call, lv(:Zero)))
101+
push!(gespinds.args, Expr(:call, Expr(:curly, lv(:Static), 1)))
102+
96103
push!(offsetprecalc_descript.args, 0) # not doing offset indexing, so push 0
97104
else
98105
uliv[i] = findfirst(isequal(ind), looporder)::Int
99106
loop = getloop(ls, ind)
100107
if loop.startexact
101-
push!(gespinds.args, Expr(:call, Expr(:curly, lv(:Static), loop.starthint - 1)))
108+
push!(gespinds.args, Expr(:call, Expr(:curly, lv(:Static), loop.starthint)))
102109
else
103-
push!(gespinds.args, Expr(:call, lv(:staticm1), loop.startsym))
110+
push!(gespinds.args, loop.startsym)
104111
end
112+
# if loop.startexact
113+
# push!(gespinds.args, Expr(:call, Expr(:curly, lv(:Static), loop.starthint - 1)))
114+
# else
115+
# push!(gespinds.args, Expr(:call, lv(:staticm1), loop.startsym))
116+
# end
105117
if ind === names(ls)[us.vectorizedloopnum]
106118
push!(offsetprecalc_descript.args, 0)
107119
elseif (ind === names(ls)[us.u₁loopnum]) & (us.u₁ > 3)
@@ -115,11 +127,17 @@ function use_loop_induct_var!(ls::LoopSet, q::Expr, ar::ArrayReferenceMeta, alla
115127
end
116128
end
117129
end
130+
vptr_ar = if isone(length(li))
131+
# Workaround for fact that 1-d OffsetArrays are offset when using 1 index, but multi-dim ones are not
132+
Expr(:call, lv(:onetozeroindexgephack), vptr(ar))
133+
else
134+
vptr(ar)
135+
end
118136
if use_offsetprecalc
119-
push!(q.args, Expr(:(=), vptr(ar), Expr(:call, lv(:offsetprecalc), Expr(:call, lv(:gesp), vptr(ar), gespinds), Expr(:call, Expr(:curly, :Val, offsetprecalc_descript)))))
137+
push!(q.args, Expr(:(=), vptr(ar), Expr(:call, lv(:offsetprecalc), Expr(:call, lv(:gesp), vptr_ar, gespinds), Expr(:call, Expr(:curly, :Val, offsetprecalc_descript)))))
120138
else
121-
push!(q.args, Expr(:(=), vptr(ar), Expr(:call, lv(:gesp), vptr(ar), gespinds)))
122-
end
139+
push!(q.args, Expr(:(=), vptr(ar), Expr(:call, lv(:gesp), vptr_ar, gespinds)))
140+
end
123141
uliv
124142
end
125143

@@ -181,8 +199,10 @@ function pointermax(ls::LoopSet, ar::ArrayReferenceMeta, n::Int, sub::Int, isvec
181199
# @unpack u₁loopnum, u₂loopnum, vectorizedloopnum, u₁, u₂ = us
182200
loopsym = names(ls)[n]
183201
index = Expr(:tuple)
202+
found_loop_sym = false
184203
for i ∈ getindicesonly(ar)
185204
if i === loopsym
205+
found_loop_sym = true
186206
if iszero(sub)
187207
push!(index.args, stophint)
188208
elseif isvectorized
@@ -195,36 +215,42 @@ function pointermax(ls::LoopSet, ar::ArrayReferenceMeta, n::Int, sub::Int, isvec
195215
push!(index.args, staticexpr(stophint - sub))
196216
end
197217
ptr = vptr(ar)
198-
return Expr(:call, lv(:pointerforcomparison), ptr, index)
218+
# return
199219
else
200220
push!(index.args, Expr(:call, lv(:Zero)))
201221
end
202222
end
203-
@show ar, loopsym
223+
@assert found_loop_sym "Failed to find $loopsym"
224+
Expr(:call, lv(:pointerforcomparison), ptr, index)
225+
# @show ar, loopsym
204226
end
205227
function pointermax(ls::LoopSet, ar::ArrayReferenceMeta, n::Int, sub::Int, isvectorized::Bool, stopsym)::Expr
206228
# @unpack u₁loopnum, u₂loopnum, vectorizedloopnum, u₁, u₂ = us
207229
loopsym = names(ls)[n]
208230
index = Expr(:tuple)
231+
found_loop_sym = false
209232
for i ∈ getindicesonly(ar)
210233
if i === loopsym
234+
found_loop_sym = true
211235
if iszero(sub)
212236
push!(index.args, stopsym)
213237
elseif isvectorized
214238
if isone(sub)
215-
push!(index.args, Expr(:call, lv(:valsub), stopsym, VECTORWIDTHSYMBOL))
239+
push!(index.args, Expr(:call, lv(:vsub), stopsym, VECTORWIDTHSYMBOL))
216240
else
217-
push!(index.args, Expr(:call, lv(:vsub), stopsym, Expr(:call, lv(:valmul), VECTORWIDTHSYMBOL, sub)))
241+
push!(index.args, Expr(:call, lv(:vsub), stopsym, Expr(:call, lv(:vmul), VECTORWIDTHSYMBOL, sub)))
218242
end
219243
else
220244
push!(index.args, Expr(:call, lv(:vsub), stopsym, sub))
221245
end
222-
return Expr(:call, lv(:pointerforcomparison), vptr(ar), index)
246+
# return
223247
else
224248
push!(index.args, Expr(:call, lv(:Zero)))
225249
end
226250
end
227-
@show ar, loopsym
251+
@assert found_loop_sym "Failed to find $loopsym"
252+
Expr(:call, lv(:pointerforcomparison), vptr(ar), index)
253+
# @show ar, loopsym
228254
end
229255

230256
function defpointermax(ls::LoopSet, ar::ArrayReferenceMeta, n::Int, sub::Int, isvectorized::Bool)::Expr
@@ -280,7 +306,7 @@ function offset_ptr(ar::ArrayReferenceMeta, us::UnrollSpecification, loopsym::Sy
280306
else
281307
push!(gespinds.args, Expr(:call, lv(:Zero)))
282308
end
283-
ind == loopsym && break
309+
# ind == loopsym && break
284310
end
285311
Expr(:(=), vptr(ar), Expr(:call, lv(:gesp), vptr(ar), gespinds))
286312
end

src/lower_constant.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,23 @@ function should_broadcast_op(op::Operation)
77
true
88
end
99

10+
11+
@inline sizeequivalentfloat(::Type{T}) where {T<:Union{Float16,Float32,Float64}} = T
12+
@inline sizeequivalentfloat(::Type{T}) where {T <: Union{Int8,UInt8}} = Float32
13+
@inline sizeequivalentfloat(::Type{T}) where {T <: Union{Int16,UInt16}} = Float16
14+
@inline sizeequivalentfloat(::Type{T}) where {T <: Union{Int32,UInt32}} = Float32
15+
@inline sizeequivalentfloat(::Type{T}) where {T <: Union{Int64,UInt64}} = Float64
16+
@inline sizeequivalentint(::Type{T}) where {T <: Integer} = T
17+
@inline sizeequivalentint(::Type{Float16}) = Int16
18+
@inline sizeequivalentint(::Type{Float32}) = Int32
19+
@inline sizeequivalentfloat(::Type{T}, x) where {T} = sizeequivalentfloat(T)(x)
20+
@inline sizeequivalentint(::Type{T}, x) where {T} = sizeequivalentint(T)(x)
21+
if VectorizationBase.AVX512DQ || !((Sys.ARCH === :x86_64) || (Sys.ARCH === :i686))
22+
@inline sizeequivalentint(::Type{Float64}) = Int64
23+
else
24+
@inline sizeequivalentint(::Type{Float64}) = Int32
25+
end
26+
1027
# @inline onefloat(::Type{T}) where {T} = one(sizeequivalentfloat(T))
1128
# @inline oneinteger(::Type{T}) where {T} = one(sizeequivalentint(T))
1229
@inline zerofloat(::Type{T}) where {T} = zero(sizeequivalentfloat(T))

src/lowering.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -544,7 +544,7 @@ function reduce_expr!(q::Expr, ls::LoopSet, U::Int)
544544
reduce_expr!(q, mvar, instr, U)
545545
if !iszero(length(ls.opdict))
546546
if (isu₁unrolled(op) | isu₂unrolled(op))
547-
push!(q.args, Expr(:(=), var, Expr(:call, lv(reduction_scalar_combine(instr)), var, Symbol(mvar, 0))))
547+
push!(q.args, Expr(:(=), var, Expr(:call, lv(reduction_scalar_combine(instr)), Symbol(mvar, 0), var)))
548548
else
549549
push!(q.args, Expr(:(=), var, mvar))
550550
end

src/mapreduce.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ function vmapreduce(f::F, op::OP, arg1::DenseArray{T}, args::Vararg{DenseArray{T
4242
_vmapreduce(f, op, V, N, T, arg1, args...)
4343
end
4444
end
45-
function _vmapreduce(f::F, op::OP, ::Val{W}, N, ::Type{T}, args::Vararg{DenseArray{<:NativeTypes},A}) where {F,OP,A,W,T}
45+
function _vmapreduce(f::F, op::OP, ::StaticInt{W}, N, ::Type{T}, args::Vararg{DenseArray{<:NativeTypes},A}) where {F,OP,A,W,T}
4646
ptrargs = pointer.(args)
4747
a_0 = f(vload.(Val{W}(), ptrargs)...); i = W
4848
if N ≥ 4W

0 commit comments

Comments
 (0)