Skip to content

Commit 8744fac

Browse files
committed
Removed two probably unreachable add_constant methods (constants in a loop-expr may be of type Expr/Symbol/Number), adjusted tests.
1 parent c7953e0 commit 8744fac

File tree

6 files changed

+89
-91
lines changed

6 files changed

+89
-91
lines changed

src/add_compute.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,9 +94,9 @@ function isreductzero(op::Operation, ls::LoopSet, reduct_zero::Symbol)
9494
isconstant(op) || return false
9595
reduct_zero === op.instruction.mod && return true
9696
if reduct_zero === :zero
97-
identifier(op) ∈ ls.preamble_zeros && return true
97+
iszero(ls, op) && return true
9898
elseif reduct_zero === :one
99-
identifier(op) ∈ ls.preamble_ones && return true
99+
isone(ls, op) && return true
100100
end
101101
false
102102
end

src/add_constants.jl

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,11 @@ function add_constant!(ls::LoopSet, var::Symbol, elementbytes::Int)
33
pushpreamble!(ls, op, var)
44
pushop!(ls, op, var)
55
end
6-
function add_constant!(ls::LoopSet, var, elementbytes::Int = 8)
7-
sym = gensym(:loopconstant)
8-
pushpreamble!(ls, Expr(:(=), sym, var))
9-
add_constant!(ls, sym, elementbytes)
10-
end
6+
# function add_constant!(ls::LoopSet, var, elementbytes::Int = 8)
7+
# sym = gensym(:loopconstant)
8+
# pushpreamble!(ls, Expr(:(=), sym, var))
9+
# add_constant!(ls, sym, elementbytes)
10+
# end
1111
function add_constant!(ls::LoopSet, var::Number, elementbytes::Int = 8)
1212
op = Operation(length(operations(ls)), gensym(:loopconstnumber), elementbytes, LOOPCONSTANT, constant, NODEPENDENCY, Symbol[], NOPARENTS)
1313
ops = operations(ls)
@@ -52,13 +52,13 @@ function add_constant!(
5252
op = Operation(length(operations(ls)), assignedsym, elementbytes, Instruction(f, value), constant, deps, NODEPENDENCY, NOPARENTS)
5353
pushop!(ls, op, assignedsym)
5454
end
55-
function add_constant!(
56-
ls::LoopSet, value, deps::Vector{Symbol}, assignedsym::Symbol, elementbytes::Int, f::Symbol = Symbol("")
57-
)
58-
intermediary = gensym(:intermediate) # hack, passing meta info here
59-
pushpreamble!(ls, Expr(:(=), intermediary, value))
60-
add_constant!(ls, intermediary, deps, assignedsym, f, elementbytes)
61-
end
55+
# function add_constant!(
56+
# ls::LoopSet, value, deps::Vector{Symbol}, assignedsym::Symbol, elementbytes::Int, f::Symbol = Symbol("")
57+
# )
58+
# intermediary = gensym(:intermediate) # hack, passing meta info here
59+
# pushpreamble!(ls, Expr(:(=), intermediary, value))
60+
# add_constant!(ls, intermediary, deps, assignedsym, f, elementbytes)
61+
# end
6262
function add_constant!(
6363
ls::LoopSet, value::Number, deps::Vector{Symbol}, assignedsym::Symbol, elementbytes::Int
6464
)

src/costs.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -228,11 +228,11 @@ function reduction_to_single_vector(x::Float64)
228228
x == 1.0 ? :evadd : x == 2.0 ? :evmul : x == 5.0 ? :max : x == 6.0 ? :min : throw("Reduction not found.")
229229
end
230230
reduction_to_single_vector(x) = reduction_to_single_vector(reduction_instruction_class(x))
231-
function reduction_to_scalar(x::Float64)
232-
# x == 1.0 ? :vsum : x == 2.0 ? :vprod : x == 3.0 ? :vany : x == 4.0 ? :vall : x == 5.0 ? :maximum : x == 6.0 ? :minimum : throw("Reduction not found.")
233-
x == 1.0 ? :vsum : x == 2.0 ? :vprod : x == 5.0 ? :maximum : x == 6.0 ? :minimum : throw("Reduction not found.")
234-
end
235-
reduction_to_scalar(x) = reduction_to_scalar(reduction_instruction_class(x))
231+
# function reduction_to_scalar(x::Float64)
232+
# # x == 1.0 ? :vsum : x == 2.0 ? :vprod : x == 3.0 ? :vany : x == 4.0 ? :vall : x == 5.0 ? :maximum : x == 6.0 ? :minimum : throw("Reduction not found.")
233+
# x == 1.0 ? :vsum : x == 2.0 ? :vprod : x == 5.0 ? :maximum : x == 6.0 ? :minimum : throw("Reduction not found.")
234+
# end
235+
# reduction_to_scalar(x) = reduction_to_scalar(reduction_instruction_class(x))
236236
function reduction_scalar_combine(x::Float64)
237237
# x == 1.0 ? :reduced_add : x == 2.0 ? :reduced_prod : x == 3.0 ? :reduced_any : x == 4.0 ? :reduced_all : x == 5.0 ? :reduced_max : x == 6.0 ? :reduced_min : throw("Reduction not found.")
238238
x == 1.0 ? :reduced_add : x == 2.0 ? :reduced_prod : x == 5.0 ? :reduced_max : x == 6.0 ? :reduced_min : throw("Reduction not found.")

src/graphs.jl

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,20 @@ function zerotype(ls::LoopSet, op::Operation)
234234
end
235235
INVALID
236236
end
237-
237+
function Base.iszero(ls::LoopSet, op::Operation)
238+
opid = identifier(op)
239+
for (id,_) ∈ ls.preamble_zeros
240+
opid == id && return true
241+
end
242+
false
243+
end
244+
function Base.isone(ls::LoopSet, op::Operation)
245+
opid = identifier(op)
246+
for (id,_) ∈ ls.preamble_ones
247+
opid == id && return true
248+
end
249+
false
250+
end
238251

239252

240253
includesarray(ls::LoopSet, array::Symbol) = array ∈ ls.includedarrays

src/lower_compute.jl

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,5 @@
11
# A compute op needs to know the unrolling and tiling status of each of its parents.
2-
#
3-
function lower_compute_scalar!(
4-
q::Expr, op::Operation, vectorized::Symbol, W::Symbol, unrolled::Symbol, tiled::Symbol, U::Int,
5-
suffix::Union{Nothing,Int}, mask::Union{Nothing,Symbol,Unsigned} = nothing
6-
)
7-
lower_compute!(q, op, vectorized, W, unrolled, tiled, U, suffix, mask, false)
8-
end
9-
function lower_compute_unrolled!(
10-
q::Expr, op::Operation, vectorized::Symbol, W::Symbol, unrolled::Symbol, tiled::Symbol, U::Int,
11-
suffix::Union{Nothing,Int}, mask::Union{Nothing,Symbol,Unsigned} = nothing
12-
)
13-
lower_compute!(q, op, vectorized, W, unrolled, tiled, U, suffix, mask, true)
14-
end
2+
153
struct FalseCollection end
164
Base.getindex(::FalseCollection, i...) = false
175
function lower_compute!(

test/runtests.jl

Lines changed: 55 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -995,33 +995,23 @@ end
995995
softmax3_core_avx4!(lse, qq, xx, tmpmax, maxk, nk)
996996
end
997997

998-
function mysumavx(x)
998+
function sumprodavx(x)
999999
s = zero(eltype(x))
1000+
p = one(eltype(x))
10001001
@avx for i ∈ eachindex(x)
10011002
s += x[i]
1003+
p *=x[i]
10021004
end
1003-
s
1005+
s, p
10041006
end
1005-
function mysum_avx(x)
1007+
function sumprod_avx(x)
10061008
s = zero(eltype(x))
1007-
@_avx for i ∈ eachindex(x)
1008-
s += x[i]
1009-
end
1010-
s
1011-
end
1012-
function myprodavx(x)
1013-
p = one(eltype(x))
1014-
@avx for i ∈ eachindex(x)
1015-
p *= x[i]
1016-
end
1017-
p
1018-
end
1019-
function myprod_avx(x)
10201009
p = one(eltype(x))
10211010
@_avx for i ∈ eachindex(x)
1022-
p *= x[i]
1011+
s += x[i]
1012+
p *=x[i]
10231013
end
1024-
p
1014+
s, p
10251015
end
10261016

10271017
function test_bit_shift(counter)
@@ -1155,19 +1145,21 @@ end
11551145
@test sum(q2; dims=3) ≈ ones(T,ni,nj)
11561146

11571147
x .+= 0.545;
1158-
s = sum(x)
1159-
@test s ≈ mysumavx(x)
1160-
@test s ≈ mysum_avx(x)
1161-
p = prod(x)
1162-
@test p ≈ myprodavx(x)
1163-
@test p ≈ myprod_avx(x)
1164-
r = T == Float32 ? (Int32(-10):Int32(234)) : -10:234
1165-
s = sum(r)
1166-
@test s ≈ mysumavx(r)
1167-
@test s ≈ mysum_avx(r)
1168-
p = prod(r)
1169-
@test p ≈ myprodavx(r)
1170-
@test p ≈ myprod_avx(r)
1148+
s = sum(x); p = prod(x)
1149+
s1, p1 = sumprodavx(x)
1150+
@test s ≈ s1
1151+
@test p ≈ p1
1152+
s1, p1 = sumprod_avx(x)
1153+
@test s ≈ s1
1154+
@test p ≈ p1
1155+
r = T == Float32 ? (Int32(-10):Int32(107)) : (Int64(-10):Int64(107))
1156+
s = sum(r); p = prod(r)
1157+
s1, p1 = sumprodavx(r)
1158+
@test s ≈ s1
1159+
@test p ≈ p1
1160+
s1, p1 = sumprod_avx(r)
1161+
@test s ≈ s1
1162+
@test p ≈ p1
11711163

11721164
@test test_bit_shift(r) == test_bit_shiftavx(r)
11731165
@test test_bit_shift(r) == test_bit_shift_avx(r)
@@ -1178,9 +1170,13 @@ end
11781170
else
11791171
sum(identity, r)
11801172
end
1181-
@test s ≈ mysumavx(r)
1182-
@test s ≈ mysum_avx(r)
1183-
1173+
p = prod(r);
1174+
s1, p1 = sumprodavx(r)
1175+
@test s ≈ s1
1176+
@test p ≈ p1
1177+
s1, p1 = sumprod_avx(r)
1178+
@test s ≈ s1
1179+
@test p ≈ p1
11841180
end
11851181
end
11861182

@@ -1541,34 +1537,34 @@ end
15411537
C[m,n] > 0 && (C[m,n] = Cₘₙ)
15421538
end
15431539
end
1544-
function condstore!(y, x)
1545-
@inbounds for i ∈ eachindex(y, x)
1546-
x1 = x[i]
1540+
function condstore!(x)
1541+
@inbounds for i ∈ eachindex(x)
1542+
x1 = 2*x[i]-100
15471543
x2 = x1*x1
15481544
x3 = x2 + x1
1549-
y[i] = x1
1550-
(x1 < 30) && (y[i] = x2)
1551-
(x1 < 80) || (y[i] = x3)
1545+
x[i] = x1
1546+
(x1 < -50) && (x[i] = x2)
1547+
(x1 < 60) || (x[i] = x3)
15521548
end
15531549
end
1554-
function condstoreavx!(y, x)
1555-
@avx for i ∈ eachindex(y, x)
1556-
x1 = x[i]
1550+
function condstoreavx!(x)
1551+
@avx for i ∈ eachindex(x)
1552+
x1 = 2*x[i]-100
15571553
x2 = x1*x1
15581554
x3 = x2 + x1
1559-
y[i] = x1
1560-
(x1 < 30) && (y[i] = x2)
1561-
(x1 < 80) || (y[i] = x3)
1555+
x[i] = x1
1556+
(x1 < -50) && (x[i] = x2)
1557+
(x1 < 60) || (x[i] = x3)
15621558
end
15631559
end
1564-
function condstore_avx!(y, x)
1565-
@_avx for i ∈ eachindex(y, x)
1566-
x1 = x[i]
1560+
function condstore_avx!(x)
1561+
@_avx for i ∈ eachindex(x)
1562+
x1 = 2*x[i]-100
15671563
x2 = x1*x1
15681564
x3 = x2 + x1
1569-
y[i] = x1
1570-
(x1 < 30) && (y[i] = x2)
1571-
(x1 < 80) || (y[i] = x3)
1565+
x[i] = x1
1566+
(x1 < -50) && (x[i] = x2)
1567+
(x1 < 60) || (x[i] = x3)
15721568
end
15731569
end
15741570

@@ -1603,12 +1599,13 @@ end
16031599
if T <: Union{Float32,Float64}
16041600
a .*= 100;
16051601
end
1606-
b2 = similar(b);
1607-
condstore!(b, a)
1608-
condstoreavx!(b2, a)
1609-
@test b == b2
1610-
fill!(b2, -999999); condstore_avx!(b2, a)
1611-
@test b == b2
1602+
b1 = copy(a);
1603+
b2 = copy(a);
1604+
condstore!(b1)
1605+
condstoreavx!(b2)
1606+
@test b1 == b2
1607+
copyto!(b2, a); condstore_avx!(b2)
1608+
@test b1 == b2
16121609

16131610
M, K, N = 83, 85, 79;
16141611
if T <: Integer

0 commit comments

Comments
 (0)