Skip to content

Commit 3d33af4

Browse files
committed
LoopVectorization now checks array arguments and falls back to the loop as written unless each array's innermost parent (found by calling `parent` recursively) is a StridedArray or AbstractRange whose element type is a Union{Bool,Base.HWReal}.
1 parent 7204ebe commit 3d33af4

File tree

9 files changed

+91
-17
lines changed

9 files changed

+91
-17
lines changed

src/LoopVectorization.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ using LinearAlgebra: Adjoint, Transpose
1818
using Base.Meta: isexpr
1919
using DocStringExtensions
2020

21+
using Base.FastMath: add_fast, sub_fast, mul_fast, div_fast
2122

2223
const NativeTypes = Union{Bool, Base.HWReal}
2324

src/condense_loopset.jl

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -296,11 +296,24 @@ concat_vals() = Val{()}()
296296
Expr(:call, Expr(:curly, :Val, tup))
297297
end
298298

299-
# check_valid_args() = true
300-
# check_valid_args(::Any) = false
301-
# check_valid_args(::T) where {T <: Union{Base.HWReal, Bool}} = true
302-
# check_valid_args(::StridedArray{T}) where {T <: Union{Base.HWReal, Bool}} = true
303-
# check_valid_args(a, b, args...) = check_valid_args(a) && check_valid_args(b) && check_valid_args(args....)
299+
300+
# Courtesy of mcabbott
301+
@inline function check_args(A::AbstractArray)
302+
P = parent(A)
303+
if typeof(P) === typeof(A)
304+
eltype(A) <: NativeTypes && typeof(A) <: Union{StridedArray, AbstractRange}
305+
else
306+
check_args(P)
307+
end
308+
end
309+
# Base cases for the variadic check below.
# - No arguments: nothing can be invalid, so the check passes (a loop that references
#   no arrays would otherwise hit a MethodError on `check_args()`).
# - A single argument that no more specific method accepts: reject it. Without this
#   method, `check_args(x)` would dispatch to the varargs method with `Bs = ()`, which
#   calls `check_args(x)` again and recurses until the stack overflows.
@inline check_args() = true
@inline check_args(::Any) = false
# Every argument must pass; `&&` short-circuits on the first failure.
@inline check_args(A, Bs...) = check_args(A) && check_args(Bs...)
310+
311+
function check_args_call(ls::LoopSet)
312+
q = Expr(:call, lv(:check_args))
313+
append!(q.args, ls.includedactualarrays)
314+
q
315+
end
316+
304317

305318
function setup_call_noinline(ls::LoopSet, U = zero(Int8), T = zero(Int8))
306319
call = generate_call(ls, (false,U,T))
@@ -410,15 +423,17 @@ function setup_call_debug(ls::LoopSet)
410423
pushpreamble!(ls, generate_call(ls, (true,zero(Int8),zero(Int8)), true))
411424
ls.preamble
412425
end
413-
function setup_call(ls::LoopSet, inline::Bool = true, u₁ = zero(Int8), u₂ = zero(Int8))
426+
function setup_call(ls::LoopSet, q = nothing, inline::Bool = true, u₁ = zero(Int8), u₂ = zero(Int8))
414427
# We outline/inline at the macro level by creating/not creating an anonymous function.
415428
# The old API instead was based on inlining or not inlining the generated function, but
416429
# the generated function must be inlined into the initial loop preamble for performance reasons.
417430
# Creating an anonymous function and calling it also achieves the outlining, while still
418431
# inlining the generated function into the loop preamble.
419-
if inline
432+
call = if inline
420433
setup_call_inline(ls, u₁, u₂)
421434
else
422435
setup_call_noinline(ls, u₁, u₂)
423436
end
437+
isnothing(q) && return call
438+
Expr(:if, check_args_call(ls), call, q)
424439
end

src/constructors.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ and `uᵢ=-1` disables unrolling for the correspond loop.
121121
macro avx(q)
122122
q = macroexpand(__module__, q)
123123
q2 = if q.head === :for
124-
setup_call(LoopSet(q, __module__))
124+
setup_call(LoopSet(q, __module__), q)
125125
else# assume broadcast
126126
substitute_broadcast(q, Symbol(__module__))
127127
end
@@ -173,14 +173,14 @@ macro avx(arg, q)
173173
q = macroexpand(__module__, q)
174174
inline, u₁, u₂ = check_macro_kwarg(arg)
175175
ls = LoopSet(q, __module__)
176-
esc(setup_call(ls, inline, u₁, u₂))
176+
esc(setup_call(ls, q, inline, u₁, u₂))
177177
end
178178
macro avx(arg1, arg2, q)
179179
@assert q.head === :for
180180
q = macroexpand(__module__, q)
181181
inline, u₁, u₂ = check_macro_kwarg(arg1)
182182
inline, u₁, u₂ = check_macro_kwarg(arg2, inline, u₁, u₂)
183-
esc(setup_call(LoopSet(q, __module__), inline, u₁, u₂))
183+
esc(setup_call(LoopSet(q, __module__), q, inline, u₁, u₂))
184184
end
185185

186186

src/costs.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,13 +127,17 @@ const COST = Dict{Symbol,InstructionCost}(
127127
:(*) => InstructionCost(4,0.5),
128128
:(/) => InstructionCost(13,4.0,-2.0),
129129
:vadd => InstructionCost(4,0.5),
130+
:add_fast => InstructionCost(4,0.5),
130131
:vsub => InstructionCost(4,0.5),
132+
:sub_fast => InstructionCost(4,0.5),
131133
:vadd! => InstructionCost(4,0.5),
132134
:vsub! => InstructionCost(4,0.5),
133135
:vmul! => InstructionCost(4,0.5),
134136
:vmul => InstructionCost(4,0.5),
137+
:mul_fast => InstructionCost(4,0.5),
135138
:vfdiv => InstructionCost(13,4.0,-2.0),
136139
:vfdiv! => InstructionCost(13,4.0,-2.0),
140+
:div_fast => InstructionCost(13,4.0,-2.0),
137141
:evadd => InstructionCost(4,0.5),
138142
:evsub => InstructionCost(4,0.5),
139143
:evmul => InstructionCost(4,0.5),

test/fallback.jl

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
2+
@testset "Fall back behavior" begin
3+
4+
function msd(x)
5+
s = zero(eltype(x))
6+
for i in eachindex(x)
7+
s += x[i] * x[i]
8+
end
9+
s
10+
end
11+
function msdavx(x)
12+
s = zero(eltype(x))
13+
@avx for i in eachindex(x)
14+
s += x[i] * x[i]
15+
end
16+
s
17+
end
18+
19+
x = fill(sqrt(63.0), 128);
20+
x[1] = 1e9
21+
22+
@test LoopVectorization.check_args(x)
23+
@test LoopVectorization.check_args(x, x)
24+
@test LoopVectorization.check_args(x, x, x)
25+
@test !LoopVectorization.check_args(FallbackArrayWrapper(x))
26+
@test !LoopVectorization.check_args(FallbackArrayWrapper(x), x, x)
27+
@test !LoopVectorization.check_args(x, FallbackArrayWrapper(x))
28+
@test !LoopVectorization.check_args(x, FallbackArrayWrapper(x), x)
29+
@test !LoopVectorization.check_args(x, x, FallbackArrayWrapper(x))
30+
@test !LoopVectorization.check_args(x, x, FallbackArrayWrapper(x), FallbackArrayWrapper(x))
31+
32+
@test msdavx(FallbackArrayWrapper(x)) == 1e18
33+
@test msd(x) == msdavx(FallbackArrayWrapper(x))
34+
@test msdavx(x) != msdavx(FallbackArrayWrapper(x))
35+
end
36+

test/gemm.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -568,15 +568,16 @@
568568
# end)
569569
# lsmul2x2q = LoopVectorization.LoopSet(mul2x2q)
570570

571-
struct SizedMatrix{M,N,T} <: AbstractMatrix{T}
571+
struct SizedMatrix{M,N,T} <: DenseMatrix{T}
572572
data::Matrix{T}
573573
end
574+
Base.parent(A::SizedMatrix) = A.data
574575
SizedMatrix{M,N}(A::Matrix{T}) where {M,N,T} = SizedMatrix{M,N,T}(A)
575-
Base.@propagate_inbounds Base.getindex(A::SizedMatrix, i...) = getindex(A.data, i...)
576-
Base.@propagate_inbounds Base.setindex!(A::SizedMatrix, v, i...) = setindex!(A.data, v, i...)
576+
Base.@propagate_inbounds Base.getindex(A::SizedMatrix, i...) = getindex(parent(A), i...)
577+
Base.@propagate_inbounds Base.setindex!(A::SizedMatrix, v, i...) = setindex!(parent(A), v, i...)
577578
Base.size(::SizedMatrix{M,N}) where {M,N} = (M,N)
578579
@inline function LoopVectorization.stridedpointer(A::SizedMatrix{M,N,T}) where {M,N,T}
579-
LoopVectorization.StaticStridedPointer{T,Tuple{1,M}}(pointer(A.data))
580+
LoopVectorization.StaticStridedPointer{T,Tuple{1,M}}(pointer(parent(A)))
580581
end
581582
@inline function LoopVectorization.stridedpointer(A::LinearAlgebra.Adjoint{T,SizedMatrix{M,N,T}}) where {M,N,T}
582583
LoopVectorization.StaticStridedPointer{T,Tuple{M,1}}(pointer(parent(A).data))

test/gemv.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,8 +153,11 @@ using Test
153153
end
154154
end)
155155
lsp = LoopVectorization.LoopSet(pq);
156-
# @test LoopVectorization.choose_order(lsp) == ([:d1, :d2], :d2, :d1, :d2, Unum, Tnum)
157-
@test LoopVectorization.choose_order(lsp) == ([:d1, :d2], :d1, :d2, :d2, Unum, Tnum)
156+
if LoopVectorization.VectorizationBase.REGISTER_COUNT == 16
157+
@test LoopVectorization.choose_order(lsp) == ([:d1, :d2], :d2, :d1, :d2, Unum, Tnum)
158+
else
159+
@test LoopVectorization.choose_order(lsp) == ([:d1, :d2], :d1, :d2, :d2, Unum, Tnum)
160+
end
158161
# lsp.preamble_symsym
159162

160163
function hhavx!(A::AbstractVector{T}, B, C, D) where {T}

test/offsetarrays.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,11 +89,12 @@ T = Float64
8989
data::Matrix{T}
9090
end
9191
Base.axes(::SizedOffsetMatrix{T,LR,UR,LC,UC}) where {T,LR,UR,LC,UC} = (StaticUnitRange{LR,UR}(),StaticUnitRange{LC,UC}())
92+
Base.parent(A::SizedOffsetMatrix) = A.data
9293
@generated function LoopVectorization.stridedpointer(A::SizedOffsetMatrix{T,LR,UR,LC,RC}) where {T,LR,UR,LC,RC}
9394
quote
9495
$(Expr(:meta,:inline))
9596
LoopVectorization.OffsetStridedPointer(
96-
LoopVectorization.StaticStridedPointer{$T,Tuple{1,$(UR-LR+1)}}(pointer(A.data)),
97+
LoopVectorization.StaticStridedPointer{$T,Tuple{1,$(UR-LR+1)}}(pointer(parent(A))),
9798
($(LR-2), $(LC-2))
9899
)
99100
end

test/runtests.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,25 @@ function clenshaw(x,coeff)
1414
return ret
1515
end
1616

17+
"""
18+
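Plain `AbstractArray` wrapper (not a `StridedArray`) used to make `check_args` return `false`, so that `@avx` falls back to the loop as written.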
19+
"""
20+
struct FallbackArrayWrapper{T,N} <: AbstractArray{T,N}
21+
data::Array{T,N}
22+
end
23+
Base.size(A::FallbackArrayWrapper) = size(A.data)
24+
Base.@propagate_inbounds Base.getindex(A::FallbackArrayWrapper, i::Vararg{Int, N}) where {N} = getindex(A.data, i...)
25+
Base.@propagate_inbounds Base.setindex!(A::FallbackArrayWrapper, v, i::Vararg{Int, N}) where {N} = setindex!(A.data, v, i...)
26+
Base.IndexStyle(::Type{<:FallbackArrayWrapper}) = IndexLinear()
27+
1728
@show LoopVectorization.VectorizationBase.REGISTER_COUNT
1829

1930
@time @testset "LoopVectorization.jl" begin
2031

2132
@time include("printmethods.jl")
2233

34+
@time include("fallback.jl")
35+
2336
@time include("utils.jl")
2437

2538
@time include("offsetarrays.jl")

0 commit comments

Comments
 (0)