Skip to content

Commit e13c990

Browse files
committed
Fix bug in add_or and add test.
1 parent 8e8fbcc commit e13c990

File tree

4 files changed

+51
-45
lines changed

4 files changed

+51
-45
lines changed

src/LoopVectorization.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@ using VectorizationBase: REGISTER_SIZE, REGISTER_COUNT, extract_data, num_vector
66
maybestaticlength, maybestaticsize, staticm1, subsetview, vzero, stridedpointer_for_broadcast,
77
Static, StaticUnitRange, StaticLowerUnitRange, StaticUpperUnitRange,
88
PackedStridedPointer, SparseStridedPointer, RowMajorStridedPointer, StaticStridedPointer, StaticStridedStruct
9-
using SIMDPirates: VECTOR_SYMBOLS, evadd, evmul, vrange, reduced_add, reduced_prod, reduce_to_add, reduce_to_prod#,
9+
using SIMDPirates: VECTOR_SYMBOLS, evadd, evmul, vrange, reduced_add, reduced_prod, reduce_to_add, reduce_to_prod,
10+
sizeequivalentfloat, sizeequivalentint
1011
# vmullog2, vmullog10, vdivlog2, vdivlog2add, vdivlog10, vdivlog10add, vfmaddaddone
1112
using Base.Broadcast: Broadcasted, DefaultArrayStyle
1213
using LinearAlgebra: Adjoint, Transpose

src/add_ifelse.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ function add_orblock!(ls::LoopSet, condop::Operation, LHS, RHS::Expr, elementbyt
6565
add_orblock!(ls, condop, LHS, rhsop, elementbytes, position)
6666
end
6767
function add_orblock!(ls::LoopSet, condop::Operation, LHS, RHS, elementbytes::Int, position::Int)
68-
rhsop = getop(ls, RHS)
68+
rhsop = getop(ls, RHS, elementbytes)
6969
add_orblock!(ls, condop, LHS, rhsop, elementbytes, position)
7070
end
7171
function add_orblock!(ls::LoopSet, condexpr::Expr, condeval::Expr, elementbytes::Int, position::Int)

src/lower_constant.jl

Lines changed: 6 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,59 +1,22 @@
11

2-
@inline zerointeger(::Type{Float16}) = zero(Int16)
3-
@inline zerointeger(::Type{Float32}) = zero(Int32)
4-
@inline zerointeger(::Type{Float64}) = zero(Int64)
5-
@inline zerointeger(::Type{I}) where {I<:Integer} = zero(I)
6-
@inline zerofloat(::Type{Float16}) = zero(Float16)
7-
@inline zerofloat(::Type{Float32}) = zero(Float32)
8-
@inline zerofloat(::Type{Float64}) = zero(Float64)
9-
@inline zerofloat(::Type{UInt16}) = zero(Float16)
10-
@inline zerofloat(::Type{UInt32}) = zero(Float32)
11-
@inline zerofloat(::Type{UInt64}) = zero(Float64)
12-
@inline zerofloat(::Type{Int16}) = zero(Float16)
13-
@inline zerofloat(::Type{Int32}) = zero(Float32)
14-
@inline zerofloat(::Type{Int64}) = zero(Float64)
2+
@inline onefloat(::Type{T}) where {T} = one(sizeequivalentfloat(T))
3+
@inline oneinteger(::Type{T}) where {T} = one(sizeequivalentint(T))
4+
@inline zerofloat(::Type{T}) where {T} = zero(sizeequivalentfloat(T))
5+
@inline zerointeger(::Type{T}) where {T} = zero(sizeequivalentint(T))
156

167

17-
@inline oneinteger(::Type{Float16}) = one(Int16)
18-
@inline oneinteger(::Type{Float32}) = one(Int32)
19-
@inline oneinteger(::Type{Float64}) = one(Int64)
20-
@inline oneinteger(::Type{I}) where {I<:Integer} = one(I)
21-
@inline onefloat(::Type{Float16}) = one(Float16)
22-
@inline onefloat(::Type{Float32}) = one(Float32)
23-
@inline onefloat(::Type{Float64}) = one(Float64)
24-
@inline onefloat(::Type{UInt16}) = one(Float16)
25-
@inline onefloat(::Type{UInt32}) = one(Float32)
26-
@inline onefloat(::Type{UInt64}) = one(Float64)
27-
@inline onefloat(::Type{Int16}) = one(Float16)
28-
@inline onefloat(::Type{Int32}) = one(Float32)
29-
@inline onefloat(::Type{Int64}) = one(Float64)
30-
31-
@inline equivalentint(::Type{I}) where {I<:Integer} = I
32-
@inline equivalentint(::Type{Float16}) = Int16
33-
@inline equivalentint(::Type{Float32}) = Int32
34-
@inline equivalentint(::Type{Float64}) = Int64
35-
@inline equivalentfloat(::Type{Float16}) = Float16
36-
@inline equivalentfloat(::Type{Float32}) = Float64
37-
@inline equivalentfloat(::Type{Float64}) = Float64
38-
@inline equivalentfloat(::Type{Int16}) = Float16
39-
@inline equivalentfloat(::Type{Int32}) = Float64
40-
@inline equivalentfloat(::Type{Int64}) = Float64
41-
@inline equivalentfloat(::Type{UInt16}) = Float16
42-
@inline equivalentfloat(::Type{UInt32}) = Float64
43-
@inline equivalentfloat(::Type{UInt64}) = Float64
44-
458
function lower_zero!(
469
q::Expr, op::Operation, vectorized::Symbol, ls::LoopSet, unrolled::Symbol, U::Int, suffix::Union{Nothing,Int}, zerotyp::NumberType = zerotype(ls, op)
4710
)
4811
W = ls.W; typeT = ls.T
4912
mvar = variable_name(op, suffix)
5013
if zerotyp == HardInt
5114
newtypeT = gensym(:IntType)
52-
pushpreamble!(ls, Expr(:(=), newtypeT, Expr(:call, lv(:equivalentint), typeT)))
15+
pushpreamble!(ls, Expr(:(=), newtypeT, Expr(:call, lv(:sizeequivalentint), typeT)))
5316
typeT = newtypeT
5417
elseif zerotyp == HardFloat
5518
newtypeT = gensym(:FloatType)
56-
pushpreamble!(ls, Expr(:(=), newtypeT, Expr(:call, lv(:equivalentfloat), typeT)))
19+
pushpreamble!(ls, Expr(:(=), newtypeT, Expr(:call, lv(:sizeequivalentfloat), typeT)))
5720
typeT = newtypeT
5821
end
5922
if vectorized ∈ loopdependencies(op) || vectorized ∈ reducedchildren(op) || vectorized ∈ reduceddependencies(op)

test/runtests.jl

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1395,6 +1395,7 @@ end
13951395
fill!(D2, -999999); D2 = @avx C .+ At' *ˡ B;
13961396
@test D1 ≈ D2
13971397
if T <: Union{Float32,Float64}
1398+
@show T, @__LINE__
13981399
D3 = cos.(B');
13991400
D4 = @avx cos.(B');
14001401
@test D3 ≈ D4
@@ -1541,10 +1542,41 @@ end
15411542
C[m,n] > 0 && (C[m,n] = Cₘₙ)
15421543
end
15431544
end
1545+
function condstore!(y, x)
1546+
@inbounds for i ∈ eachindex(y, x)
1547+
x1 = x[i]
1548+
x2 = x1*x1
1549+
x3 = x2 + x1
1550+
y[i] = x1
1551+
(x1 < 30) && (y[i] = x2)
1552+
(x1 < 80) || (y[i] = x3)
1553+
end
1554+
end
1555+
function condstoreavx!(y, x)
1556+
@avx for i ∈ eachindex(y, x)
1557+
x1 = x[i]
1558+
x2 = x1*x1
1559+
x3 = x2 + x1
1560+
y[i] = x1
1561+
(x1 < 30) && (y[i] = x2)
1562+
(x1 < 80) || (y[i] = x3)
1563+
end
1564+
end
1565+
function condstore_avx!(y, x)
1566+
@_avx for i ∈ eachindex(y, x)
1567+
x1 = x[i]
1568+
x2 = x1*x1
1569+
x3 = x2 + x1
1570+
y[i] = x1
1571+
(x1 < 30) && (y[i] = x2)
1572+
(x1 < 80) || (y[i] = x3)
1573+
end
1574+
end
15441575

15451576

15461577
N = 117
15471578
@time for T ∈ (Float32, Float64, Int32, Int64)
1579+
@show T, @__LINE__
15481580
if T <: Integer
15491581
a = rand(-T(100):T(100), N); b = rand(-T(100):T(100), N);
15501582
else
@@ -1569,6 +1601,16 @@ end
15691601
fill!(c2, -999999999); maybewriteoravx!(c2, a, b)
15701602
@test c1 ≈ c2
15711603

1604+
if T <: Union{Float32,Float64}
1605+
a .*= 100;
1606+
end
1607+
b2 = similar(b);
1608+
condstore!(b, a)
1609+
condstoreavx!(b2, a)
1610+
@test b == b2
1611+
fill!(b2, -999999); condstore_avx!(b2, a)
1612+
@test b == b2
1613+
15721614
M, K, N = 83, 85, 79;
15731615
if T <: Integer
15741616
A = rand(T(-100):T(100), K, M);

0 commit comments

Comments
 (0)