Skip to content

Commit accdebd

Browse files
committed
Update to Polyester 0.4, fixes #330.
1 parent 83f1e75 commit accdebd

File tree

11 files changed

+359
-294
lines changed

11 files changed

+359
-294
lines changed

Project.toml

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
name = "LoopVectorization"
22
uuid = "bdcacae8-1622-11e9-2a5c-532679323890"
33
authors = ["Chris Elrod <[email protected]>"]
4-
version = "0.12.66"
4+
version = "0.12.67"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
88
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
99
IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173"
10+
LayoutPointers = "10f19ff3-798f-405d-979b-55457f8fc047"
1011
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1112
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
1213
Polyester = "f517fe37-dbe3-4b94-8317-1923a5111588"
@@ -19,16 +20,17 @@ UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
1920
VectorizationBase = "3d5dd08c-fd9d-11e8-17fa-ed2836048c2f"
2021

2122
[compat]
22-
ArrayInterface = "3.1.9 - 3.1.23"
23+
ArrayInterface = "3.1.9 - 3.1.23, 3.1.25"
2324
DocStringExtensions = "0.8"
2425
IfElse = "0.1"
26+
LayoutPointers = "0.1.2"
2527
OffsetArrays = "1.4.1"
26-
Polyester = "0.3"
28+
Polyester = "0.4.0"
2729
Requires = "1"
2830
SLEEFPirates = "0.6.23"
2931
Static = "0.2, 0.3"
30-
StrideArraysCore = "0.1.12"
32+
StrideArraysCore = "0.2"
3133
ThreadingUtilities = "0.4.5"
3234
UnPack = "1"
33-
VectorizationBase = "0.20.36"
35+
VectorizationBase = "0.21.1"
3436
julia = "1.5"

src/LoopVectorization.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ using VectorizationBase: register_size, register_count, cache_linesize, cache_si
1010
maybestaticfirst, maybestaticlast, gep, gesp, NativeTypes, #llvmptr,
1111
vfmadd, vfmsub, vfnmadd, vfnmsub, vfmadd_fast, vfmsub_fast, vfnmadd_fast, vfnmsub_fast, vfmadd231, vfmsub231, vfnmadd231, vfnmsub231,
1212
vfma_fast, vmuladd_fast, vdiv_fast, vadd_fast, vsub_fast, vmul_fast,
13-
relu, stridedpointer, stridedpointer_preserve, StridedPointer, StridedBitPointer, AbstractStridedPointer, _vload, _vstore!,
13+
relu, stridedpointer, StridedPointer, StridedBitPointer, AbstractStridedPointer, _vload, _vstore!,
1414
reduced_add, reduced_prod, reduce_to_add, reduce_to_prod, reduced_max, reduced_min, reduce_to_max, reduce_to_min,
1515
reduced_all, reduced_any, reduce_to_all, reduce_to_any,
1616
vsum, vprod, vmaximum, vminimum, vany, vall, unwrap, Unroll, VecUnroll,
@@ -22,9 +22,11 @@ using VectorizationBase: register_size, register_count, cache_linesize, cache_si
2222
contract_and, collapse_and,
2323
contract_or, collapse_or,
2424
num_threads, num_cores,
25-
max_mask#,zero_mask
25+
max_mask, maybestaticsize#,zero_mask
2626

27-
using VectorizationBase: maybestaticsize # for compatibility
27+
28+
29+
using LayoutPointers: stridedpointer_preserve, GroupedStridedPointers
2830

2931
using IfElse: ifelse
3032

@@ -39,7 +41,7 @@ using Base.FastMath: add_fast, sub_fast, mul_fast, div_fast, inv_fast, abs2_fast
3941
using SLEEFPirates: log_fast, log2_fast, log10_fast, pow, sin_fast, cos_fast, sincos_fast
4042

4143
using ArrayInterface
42-
using ArrayInterface: OptionallyStaticUnitRange, OptionallyStaticRange, Zero, One, StaticBool, True, False, reduce_tup, indices, UpTri, LoTri
44+
using ArrayInterface: OptionallyStaticUnitRange, OptionallyStaticRange, Zero, One, StaticBool, True, False, reduce_tup, indices, UpTri, LoTri, strides, offsets, size, StrideIndex
4345
using StrideArraysCore: CloseOpen, PtrArray
4446
# @static if VERSION ≥ v"1.6.0-rc1" #TODO: delete `else` when dropping 1.5 support
4547
# using ArrayInterface: static_step

src/broadcast.jl

Lines changed: 140 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -7,21 +7,21 @@ end
77
Base.@propagate_inbounds Base.getindex(A::LowDimArray, i::Vararg{Union{Integer,CartesianIndex},K}) where {K} = getindex(A.data, i...)
88
@inline Base.size(A::LowDimArray) = Base.size(A.data)
99
@inline Base.size(A::LowDimArray, i) = Base.size(A.data, i)
10-
@inline Base.strides(A::LowDimArray) = strides(A.data)
10+
1111
@inline ArrayInterface.parent_type(::Type{LowDimArray{D,T,N,A}}) where {T,D,N,A} = A
12-
@inline ArrayInterface.strides(A::LowDimArray) = ArrayInterface.strides(A.data)
12+
@inline Base.strides(A::LowDimArray) = map(Int, strides(A))
1313
@inline ArrayInterface.device(A::LowDimArray) = ArrayInterface.CPUPointer()
1414
@generated function ArrayInterface.size(A::LowDimArray{D,T,N}) where {D,T,N}
15-
t = Expr(:tuple)
16-
gf = GlobalRef(Core,:getfield)
17-
for n ∈ 1:N
18-
if n > length(D) || D[n]
19-
push!(t.args, Expr(:call, gf, :s, n, false))
20-
else
21-
push!(t.args, Expr(:call, Expr(:curly, lv(:StaticInt), 1)))
22-
end
15+
t = Expr(:tuple)
16+
gf = GlobalRef(Core,:getfield)
17+
for n ∈ 1:N
18+
if n > length(D) || D[n]
19+
push!(t.args, Expr(:call, gf, :s, n, false))
20+
else
21+
push!(t.args, Expr(:call, Expr(:curly, lv(:StaticInt), 1)))
2322
end
24-
Expr(:block, Expr(:meta,:inline), :(s = ArrayInterface.size(parent(A))), t)
23+
end
24+
Expr(:block, Expr(:meta,:inline), :(s = ArrayInterface.size(parent(A))), t)
2525
end
2626
Base.parent(A::LowDimArray) = getfield(A, :data)
2727
Base.unsafe_convert(::Type{Ptr{T}}, A::LowDimArray{D,T}) where {D,T} = pointer(parent(A))
@@ -31,103 +31,174 @@ ArrayInterface.stride_rank(A::LowDimArray) = ArrayInterface.stride_rank(parent(A
3131
ArrayInterface.offsets(A::LowDimArray) = ArrayInterface.offsets(parent(A))
3232

3333
@generated function _lowdimfilter(::Val{D}, tup::Tuple{Vararg{Any,N}}) where {D,N}
34-
t = Expr(:tuple)
35-
gf = GlobalRef(Core,:getfield)
36-
for n ∈ 1:N
37-
if n > length(D) || D[n]
38-
push!(t.args, Expr(:call, gf, :tup, n, false))
39-
end
34+
t = Expr(:tuple)
35+
gf = GlobalRef(Core,:getfield)
36+
for n ∈ 1:N
37+
if n > length(D) || D[n]
38+
push!(t.args, Expr(:call, gf, :tup, n, false))
4039
end
41-
Expr(:block, Expr(:meta,:inline), t)
40+
end
41+
Expr(:block, Expr(:meta,:inline), t)
4242
end
4343

4444
struct ForBroadcast{T,N,A<:AbstractArray{T,N}} <: AbstractArray{T,N}
45-
data::A
45+
data::A
4646
end
4747
@inline Base.parent(fb::ForBroadcast) = getfield(fb, :data)
48+
@inline ArrayInterface.parent_type(::Type{ForBroadcast{T,N,A}}) where {T,N,A} = A
4849
Base.@propagate_inbounds Base.getindex(A::ForBroadcast, i::Vararg{Any,K}) where {K} = parent(A)[i...]
4950
const LowDimArrayForBroadcast{D,T,N,A} = ForBroadcast{T,N,LowDimArray{D,T,N,A}}
5051
@inline function VectorizationBase.contiguous_axis(fb::LowDimArrayForBroadcast{D,T,N,A}) where {D,T,N,A}
51-
_contiguous_axis(Val{D}(), VectorizationBase.contiguous_axis(parent(parent(fb))))
52+
_contiguous_axis(Val{D}(), VectorizationBase.contiguous_axis(parent(parent(fb))))
5253
end
5354
@inline forbroadcast(A::AbstractArray) = ForBroadcast(A)
5455
@inline forbroadcast(A::AbstractRange) = A
5556
@inline forbroadcast(A) = A
5657
@inline forbroadcast(A::Adjoint) = forbroadcast(parent(A))
5758
@inline forbroadcast(A::Transpose) = forbroadcast(parent(A))
58-
Base.size(fb::ForBroadcast) = size(parent(fb))
59+
@inline function ArrayInterface.strides(A::Union{LowDimArray,ForBroadcast})
60+
B = parent(A)
61+
_strides(size(A), strides(B), VectorizationBase.val_stride_rank(B), VectorizationBase.val_dense_dims(B))
62+
end
63+
5964

6065
# @inline function VectorizationBase.contiguous_batch_size(fb::LowDimArrayForBroadcast{D,T,N,A}) where {D,T,N,A}
6166
# _contiguous_axis(Val{D}(), VectorizationBase.contiguous_batch_size(parent(parent(fb))))
6267
# end
6368
@generated function _contiguous_axis(::Val{D}, ::StaticInt{C}) where {D,C}
64-
Dlen = length(D)
65-
(C > 0) || return Expr(:block,Expr(:meta,:inline), staticexpr(C))
66-
if C ≤ Dlen
67-
D[C] || return Expr(:block,Expr(:meta,:inline), staticexpr(-1))
68-
end
69-
Cnew = 0
70-
for n ∈ 1:C
71-
Cnew += ((n > Dlen)) || D[n]
72-
end
73-
Expr(:block,Expr(:meta,:inline), staticexpr(Cnew))
69+
Dlen = length(D)
70+
(C > 0) || return Expr(:block,Expr(:meta,:inline), staticexpr(C))
71+
if C ≤ Dlen
72+
D[C] || return Expr(:block,Expr(:meta,:inline), staticexpr(-1))
73+
end
74+
Cnew = 0
75+
for n ∈ 1:C
76+
Cnew += ((n > Dlen)) || D[n]
77+
end
78+
Expr(:block,Expr(:meta,:inline), staticexpr(Cnew))
7479
end
75-
@inline function VectorizationBase.val_stride_rank(fb::LowDimArrayForBroadcast{D}) where {D}
76-
VectorizationBase.asvalint(_lowdimfilter(Val(D), ArrayInterface.stride_rank(parent(parent(fb)))))
80+
@inline function ArrayInterface.stride_rank(::Type{LowDimArrayForBroadcast{D,T,N,A}}) where {D,T,N,A}
81+
_lowdimfilter(Val(D), ArrayInterface.stride_rank(A))
7782
end
78-
@inline function VectorizationBase.val_dense_dims(fb::LowDimArrayForBroadcast{D}) where {D}
79-
VectorizationBase.asvalbool(_lowdimfilter(Val(D), ArrayInterface.dense_dims(parent(parent(fb)))))
83+
@inline function ArrayInterface.dense_dims(::Type{LowDimArrayForBroadcast{D,T,N,A}}) where {D,T,N,A}
84+
_lowdimfilter(Val(D), ArrayInterface.dense_dims(A))
8085
end
81-
@inline function VectorizationBase.bytestrides(fb::LowDimArrayForBroadcast{D}) where {D}
82-
p = parent(parent(fb))
83-
s = _bytestrides(ArrayInterface.size(p), ArrayInterface.strides(p), p)
84-
_lowdimfilter(Val(D), s)
86+
@inline function ArrayInterface.strides(fb::LowDimArrayForBroadcast{D}) where {D}
87+
_lowdimfilter(Val(D), strides(parent(fb)))
8588
end
8689
@inline function ArrayInterface.offsets(fb::LowDimArrayForBroadcast{D}) where {D}
87-
_lowdimfilter(Val(D), ArrayInterface.offsets(parent(parent(fb))))
90+
_lowdimfilter(Val(D), ArrayInterface.offsets(parent(parent(fb))))
91+
end
92+
@inline function ArrayInterface.StrideIndex(a::A) where {A<:LowDimArrayForBroadcast}
93+
_stride_index(ArrayInterface.stride_rank(A), ArrayInterface.contiguous_axis(A), a)
94+
end
95+
@inline function _stride_index(r::Tuple{Vararg{StaticInt,N}}, ::StaticInt{C}, A) where {N,C}
96+
StrideIndex{N,ArrayInterface.known(r),C}(A)
8897
end
8998

99+
for f ∈ [ # groupedstridedpointer support
100+
:(ArrayInterface.contiguous_axis),
101+
:(ArrayInterface.contiguous_batch_size),
102+
:(ArrayInterface.device),
103+
:(ArrayInterface.stride_rank),
104+
]
105+
@eval @inline $f(::Type{ForBroadcast{T,N,A}}) where {T,N,A} = $f(A)
106+
end
90107
for f ∈ [ # groupedstridedpointer support
91108
:(VectorizationBase.memory_reference),
92109
:(ArrayInterface.contiguous_axis),
93110
:(ArrayInterface.contiguous_batch_size),
94111
:(ArrayInterface.device),
95-
:(VectorizationBase.val_stride_rank),
112+
:(ArrayInterface.stride_rank),
96113
:(VectorizationBase.val_dense_dims),
97-
:(ArrayInterface.offsets)
114+
:(ArrayInterface.offsets),
115+
:(Base.size)#, :(ArrayInterface.strides)
98116
]
99117
@eval @inline $f(fb::ForBroadcast) = $f(getfield(fb, :data))
100118
end
101-
@inline _bytestrides(s,paren) = VectorizationBase.bytestrides(paren)
102-
@generated function _bytestrides(s::Tuple{Vararg{Integer,N}}, x::Tuple{Vararg{Integer,N}}, paren::AbstractArray{T,N}) where {T,N}
103-
q = Expr(:block, Expr(:meta,:inline))
104-
strd_tup = Expr(:tuple)
105-
gf = GlobalRef(Core, :getfield)
106-
ifel = GlobalRef(Core, :ifelse)
107-
st = staticexpr(sizeof(T))
108-
for n ∈ 1:N
109-
s_type = s.parameters[n]
110-
if s_type <: Static
111-
if s_type === One
112-
push!(strd_tup.args, Expr(:call, lv(:Zero)))
113-
else
114-
push!(strd_tup.args, :($gf(x, $n, false) * $st))
115-
end
119+
120+
function is_column_major(x)
121+
for (i, j) ∈ enumerate(x)
122+
i == j || return false
123+
end
124+
true
125+
end
126+
is_row_major(x) = is_column_major(reverse(x))
127+
# @inline _bytestrides(s,paren) = VectorizationBase.bytestrides(paren)
128+
function _strides_expr(@nospecialize(s), @nospecialize(x), R::Vector{Int}, D::Vector{Bool})
129+
N = length(R)
130+
q = Expr(:block, Expr(:meta,:inline))
131+
strd_tup = Expr(:tuple)
132+
gf = GlobalRef(Core, :getfield)
133+
ifel = GlobalRef(Core, :ifelse)
134+
Nrange = 1:1:N # type stability w/ respect to reverse
135+
use_stride_acc = true
136+
stride_acc::Int = 1
137+
if is_column_major(R)
138+
elseif is_row_major(R)
139+
Nrange = reverse(Nrange)
140+
else # not worth my time optimizing this case at the moment...
141+
# will write something generic stride-rank agnostic eventually
142+
use_stride_acc = false
143+
stride_acc = 0
144+
end
145+
sₙ_value::Int = 0
146+
for n ∈ Nrange
147+
xₙ_type = x[n]
148+
# xₙ_type = typeof(x).parameters[n]
149+
xₙ_static = xₙ_type <: Static
150+
xₙ_value::Int = xₙ_static ? (xₙ_type.parameters[1])::Int : 0
151+
s_type = s[n]
152+
# s_type = typeof(s).parameters[n]
153+
sₙ_static = s_type <: Static
154+
if sₙ_static
155+
sₙ_value = s_type.parameters[1]
156+
if s_type === One
157+
push!(strd_tup.args, Expr(:call, lv(:Zero)))
158+
elseif stride_acc ≠ 0
159+
push!(strd_tup.args, staticexpr(stride_acc))
160+
else
161+
push!(strd_tup.args, :($gf(x, $n, false)))
162+
end
163+
else
164+
if xₙ_static
165+
push!(strd_tup.args, staticexpr(xₙ_value))
166+
elseif stride_acc ≠ 0
167+
push!(strd_tup.args, staticexpr(stride_acc))
168+
else
169+
push!(strd_tup.args, :($ifel(isone($gf(s, $n, false)), zero($xₙ_type), $gf(x, $n, false))))
170+
end
171+
end
172+
if (n ≠ last(Nrange)) && use_stride_acc
173+
nnext = n + step(Nrange)
174+
if D[nnext]
175+
if xₙ_static & sₙ_static
176+
stride_acc = xₙ_value * sₙ_value
177+
elseif sₙ_static
178+
if stride_acc ≠ 0
179+
stride_acc *= sₙ_value
180+
end
116181
else
117-
Xₙ_type = x.parameters[n]
118-
if Xₙ_type <: Static # FIXME; what to do here? Dynamic dispatch?
119-
push!(strd_tup.args, :($gf(x, $n, false)*$st))
120-
else
121-
push!(strd_tup.args, :($ifel(isone($gf(s, $n, false)), zero($Xₙ_type), $gf(x, $n, false)*$st)))
122-
end
182+
stride_acc = 0
123183
end
184+
else
185+
stride_acc = 0
186+
end
124187
end
125-
push!(q.args, strd_tup)
126-
q
127-
end
128-
@inline function VectorizationBase.bytestrides(fb::ForBroadcast)
129-
p = getfield(fb,:data)
130-
_bytestrides(ArrayInterface.size(p), ArrayInterface.strides(p), p)
188+
end
189+
push!(q.args, strd_tup)
190+
q
191+
end
192+
@generated function _strides(
193+
s::Tuple{Vararg{Integer,N}}, x::Tuple{Vararg{Integer,N}}, ::Val{R}, ::Val{D}
194+
) where {N,R,D}
195+
Rv = Vector{Int}(undef, N)
196+
Dv = Vector{Bool}(undef, N)
197+
for n in 1:N
198+
Rv[n] = R[n]
199+
Dv[n] = D[n]
200+
end
201+
_strides_expr(s.parameters, x.parameters, Rv, Dv)
131202
end
132203

133204
struct Product{A,B}

0 commit comments

Comments
 (0)