Skip to content

Commit 9a509f0

Browse files
committed
Change handling of reductions and passing of loop constants defined by functions.
1 parent f53d6c6 commit 9a509f0

11 files changed

+50
-57
lines changed

src/LoopVectorization.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ using VectorizationBase: REGISTER_SIZE, REGISTER_COUNT, extract_data, num_vector
88
AbstractColumnMajorStridedPointer, AbstractRowMajorStridedPointer, AbstractSparseStridedPointer, AbstractStaticStridedPointer,
99
PackedStridedPointer, SparseStridedPointer, RowMajorStridedPointer, StaticStridedPointer, StaticStridedStruct,
1010
maybestaticfirst, maybestaticlast, scalar_less, scalar_greater
11-
using SIMDPirates: VECTOR_SYMBOLS, evadd, evsub, evmul, evfdiv, vrange, reduced_add, reduced_prod, reduced_max, reduced_min, vsum, vprod, vmaximum, vminimum,
11+
using SIMDPirates: VECTOR_SYMBOLS, evadd, evsub, evmul, evfdiv, vrange, reduced_add, reduced_prod, reduce_to_add, reduced_max, reduced_min, vsum, vprod, vmaximum, vminimum,
1212
sizeequivalentfloat, sizeequivalentint, vadd!, vsub!, vmul!, vfdiv!, vfmadd!, vfnmadd!, vfmsub!, vfnmsub!,
1313
vfmadd231, vfmsub231, vfnmadd231, vfnmsub231, sizeequivalentfloat, sizeequivalentint, #prefetch,
1414
vmullog2, vmullog10, vdivlog2, vdivlog10, vmullog2add!, vmullog10add!, vdivlog2add!, vdivlog10add!, vfmaddaddone

src/add_compute.jl

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -161,19 +161,9 @@ function add_reduction_update_parent!(
161161
reductinit = add_constant!(ls, gensym(:reductzero), loopdependencies(parent), reductsym, elementbytes, :numericconstant)
162162
if reduct_zero === :zero
163163
push!(ls.preamble_zeros, (identifier(reductinit), IntOrFloat))
164-
elseif reduct_zero === :one
165-
push!(ls.preamble_ones, (identifier(reductinit), IntOrFloat))
166164
else
167-
if reductzero === :true || reductzero === :false
168-
pushpreamble!(ls, Expr(:(=), name(reductinit), reductzero))
169-
else
170-
pushpreamble!(ls, Expr(:(=), name(reductinit), Expr(:call, reductzero, ls.T)))
171-
end
172-
pushpreamble!(ls, op, name, reductinit)
165+
push!(ls.preamble_funcofeltypes, (identifier(reductinit), reduct_zero))
173166
end
174-
# if
175-
# reductcombine = reduction_combine_to(instrclass)
176-
# end
177167
else
178168
reductinit = parent
179169
reductsym = var
@@ -328,7 +318,7 @@ function add_pow!(
328318
end
329319
if pint == 0
330320
op = Operation(length(operations(ls)), var, elementbytes, LOOPCONSTANT, constant, NODEPENDENCY, Symbol[], NOPARENTS)
331-
push!(ls.preamble_ones, (identifier(op),IntOrFloat))
321+
push!(ls.preamble_funcofeltypes, (identifier(op),:one))
332322
return pushop!(ls, op)
333323
elseif pint == 1
334324
return add_compute!(ls, var, :identity, [xop], elementbytes)

src/add_constants.jl

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,6 @@ function add_constant!(ls::LoopSet, var::Number, elementbytes::Int = 8)
1717
(instruction(ops[id]) === LOOPCONSTANT && typ == typ_) && return ops[id]
1818
end
1919
push!(ls.preamble_zeros, (identifier(op),typ))
20-
elseif isone(var)
21-
for (id,typ_) ∈ ls.preamble_ones
22-
(instruction(ops[id]) === LOOPCONSTANT && typ == typ_) && return ops[id]
23-
end
24-
push!(ls.preamble_ones, (identifier(op),typ))
2520
elseif var isa Integer
2621
for (id,ivar) ∈ ls.preamble_symint
2722
(instruction(ops[id]) === LOOPCONSTANT && ivar == var) && return ops[id]

src/condense_loopset.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ function argmeta_and_consts_description(ls::LoopSet, arraysymbolinds)
150150
Expr(:curly, :Tuple, ls.preamble_symint...),
151151
Expr(:curly, :Tuple, ls.preamble_symfloat...),
152152
Expr(:curly, :Tuple, ls.preamble_zeros...),
153-
Expr(:curly, :Tuple, ls.preamble_ones...)
153+
Expr(:curly, :Tuple, ls.preamble_funcofeltypes...)
154154
)
155155
end
156156

src/costs.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,8 @@ const COST = Dict{Symbol,InstructionCost}(
146146
:vprod => InstructionCost(6,2.0),
147147
:reduced_add => InstructionCost(4,0.5),# ignoring reduction part of cost, might be nop
148148
:reduced_prod => InstructionCost(4,0.5),# ignoring reduction part of cost, might be nop
149+
:reduced_max => InstructionCost(4,0.5),# ignoring reduction part of cost, might be nop
150+
:reduced_min => InstructionCost(4,0.5),# ignoring reduction part of cost, might be nop
149151
:reduce_to_add => InstructionCost(0,0.0,0.0,0),
150152
:reduce_to_prod => InstructionCost(0,0.0,0.0,0),
151153
:abs => InstructionCost(1, 0.5),

src/graphs.jl

Lines changed: 13 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ struct LoopSet
172172
preamble_symint::Vector{Tuple{Int,Int}}
173173
preamble_symfloat::Vector{Tuple{Int,Float64}}
174174
preamble_zeros::Vector{Tuple{Int,NumberType}}
175-
preamble_ones::Vector{Tuple{Int,NumberType}}
175+
preamble_funcofeltypes::Vector{Tuple{Int,Symbol}}
176176
includedarrays::Vector{Symbol}
177177
includedactualarrays::Vector{Symbol}
178178
syms_aliasing_refs::Vector{Symbol}
@@ -220,7 +220,7 @@ function pushpreamble!(ls::LoopSet, op::Operation, v::Number)
220220
if iszero(v)
221221
push!(ls.preamble_zeros, (id, typ))
222222
elseif isone(v)
223-
push!(ls.preamble_ones, (id, typ))
223+
push!(ls.preamble_funcofeltypes, (id, :one))
224224
elseif v isa Integer
225225
push!(ls.preamble_symint, (id, convert(Int,v)))
226226
else
@@ -233,7 +233,7 @@ function pushpreamble!(ls::LoopSet, op::Operation, RHS::Expr)
233233
if RHS.head === :call && first(RHS.args) === :zero
234234
push!(ls.preamble_zeros, (identifier(op), IntOrFloat))
235235
elseif RHS.head === :call && first(RHS.args) === :one
236-
push!(ls.preamble_ones, (identifier(op), IntOrFloat))
236+
push!(ls.preamble_funcofeltypes, (identifier(op), :one))
237237
else
238238
pushpreamble!(ls, Expr(:(=), c, RHS))
239239
pushpreamble!(ls, op, c)
@@ -247,20 +247,6 @@ function zerotype(ls::LoopSet, op::Operation)
247247
end
248248
INVALID
249249
end
250-
# function Base.iszero(ls::LoopSet, op::Operation)
251-
# opid = identifier(op)
252-
# for (id,_) ∈ ls.preamble_zeros
253-
# opid == id && return true
254-
# end
255-
# false
256-
# end
257-
# function Base.isone(ls::LoopSet, op::Operation)
258-
# opid = identifier(op)
259-
# for (id,_) ∈ ls.preamble_ones
260-
# opid == id && return true
261-
# end
262-
# false
263-
# end
264250

265251
includesarray(ls::LoopSet, array::Symbol) = array ∈ ls.includedarrays
266252

@@ -496,7 +482,11 @@ function add_operation!(
496482
elseif f === :zero || f === :one
497483
c = gensym(f)
498484
op = add_constant!(ls, c, ls.loopsymbols[1:position], LHS, elementbytes, :numericconstant)
499-
push!(f === :zero ? ls.preamble_zeros : ls.preamble_ones, (identifier(op), IntOrFloat))
485+
if f === :zero
486+
push!(ls.preamble_zeros, (identifier(op), IntOrFloat))
487+
else
488+
push!(ls.preamble_funcofeltypes, (identifier(op), :one))
489+
end
500490
op
501491
else
502492
add_compute!(ls, LHS, RHS, elementbytes, position)
@@ -524,7 +514,11 @@ function add_operation!(
524514
elseif f === :zero || f === :one
525515
c = gensym(f)
526516
op = add_constant!(ls, c, ls.loopsymbols[1:position], LHS_sym, elementbytes, :numericconstant)
527-
push!(f === :zero ? ls.preamble_zeros : ls.preamble_ones, (identifier(op), IntOrFloat))
517+
if f === :zero
518+
push!(ls.preamble_zeros, (identifier(op), IntOrFloat))
519+
else
520+
push!(ls.preamble_funcofeltypes, (identifier(op), :one))
521+
end
528522
op
529523
else
530524
add_compute!(ls, LHS_sym, RHS, elementbytes, position, LHS_ref)

src/lower_constant.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11

2-
@inline onefloat(::Type{T}) where {T} = one(sizeequivalentfloat(T))
3-
@inline oneinteger(::Type{T}) where {T} = one(sizeequivalentint(T))
2+
# @inline onefloat(::Type{T}) where {T} = one(sizeequivalentfloat(T))
3+
# @inline oneinteger(::Type{T}) where {T} = one(sizeequivalentint(T))
44
@inline zerofloat(::Type{T}) where {T} = zero(sizeequivalentfloat(T))
55
@inline zerointeger(::Type{T}) where {T} = zero(sizeequivalentint(T))
66

@@ -64,6 +64,12 @@ function lower_constant!(
6464
Expr(:call, Expr(:(.), Expr(:(.), :LoopVectorization, QuoteNode(:SIMDPirates)), QuoteNode(:addscalar)), Expr(:call, lv(:vzero), VECTORWIDTHSYMBOL, ELTYPESYMBOL), constsym)
6565
elseif instrclass == MULTIPLICATIVE_IN_REDUCTIONS
6666
Expr(:call, Expr(:(.), Expr(:(.), :LoopVectorization, QuoteNode(:SIMDPirates)), QuoteNode(:mulscalar)), Expr(:call, lv(:vbroadcast), VECTORWIDTHSYMBOL, Expr(:call, :one, ELTYPESYMBOL)), constsym)
67+
elseif instrclass == MAX
68+
Expr(:call, Expr(:(.), Expr(:(.), :LoopVectorization, QuoteNode(:SIMDPirates)), QuoteNode(:maxscalar)), Expr(:call, lv(:vbroadcast), VECTORWIDTHSYMBOL, Expr(:call, :typemin, ELTYPESYMBOL)), constsym)
69+
70+
elseif instrclass == MIN
71+
Expr(:call, Expr(:(.), Expr(:(.), :LoopVectorization, QuoteNode(:SIMDPirates)), QuoteNode(:minscalar)), Expr(:call, lv(:vbroadcast), VECTORWIDTHSYMBOL, Expr(:call, :typemax, ELTYPESYMBOL)), constsym)
72+
6773
else
6874
throw("Reductions of type $(reduction_zero(reinstrclass)) not yet supported; please file an issue as a reminder to take care of this.")
6975
end
@@ -132,14 +138,8 @@ function lower_licm_constants!(ls::LoopSet)
132138
setconstantop!(ls, ops[id], Expr(:call, lv(:zerofloat), ELTYPESYMBOL))
133139
end
134140
end
135-
for (id,typ) ∈ ls.preamble_ones
136-
if typ == IntOrFloat
137-
setop!(ls, ops[id], Expr(:call, :one, ELTYPESYMBOL))
138-
elseif typ == HardInt
139-
setop!(ls, ops[id], Expr(:call, lv(:oneinteger), ELTYPESYMBOL))
140-
else#if typ == HardFloat
141-
setop!(ls, ops[id], Expr(:call, lv(:onefloat), ELTYPESYMBOL))
142-
end
141+
for (id,f) ∈ ls.preamble_funcofeltypes
142+
setop!(ls, ops[id], Expr(:call, f, ELTYPESYMBOL))
143143
end
144144
end
145145

src/reconstruct_loopset.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ function process_metadata!(ls::LoopSet, AM, num_arrays::Int)
203203
expandbyoffset!(ls.preamble_symint, AM[4].parameters, opoffsets)
204204
expandbyoffset!(ls.preamble_symfloat, AM[5].parameters, opoffsets)
205205
expandbyoffset!(ls.preamble_zeros, AM[6].parameters, opoffsets)
206-
expandbyoffset!(ls.preamble_ones, AM[7].parameters, opoffsets)
206+
expandbyoffset!(ls.preamble_funcofeltypes, AM[7].parameters, opoffsets)
207207
nothing
208208
end
209209
function expandbyoffset!(indexpand::Vector{T}, inds, offsets::Vector{Int}, expand::Bool = true) where {T <: Union{Int,Tuple{Int,<:Any}}}

src/split_loops.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ function split_loopset(ls::LoopSet, ids)
4343
append_if_included!(ls_new.preamble_symint, ls.preamble_symint, included)
4444
append_if_included!(ls_new.preamble_symfloat, ls.preamble_symfloat, included)
4545
append_if_included!(ls_new.preamble_zeros, ls.preamble_zeros, included)
46-
append_if_included!(ls_new.preamble_ones, ls.preamble_ones, included)
46+
append_if_included!(ls_new.preamble_funcofeltypes, ls.preamble_funcofeltypes, included)
4747
ls_new
4848
end
4949

test/mapreduce.jl

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11

2+
23
@testset "mapreduce" begin
34
function maximum_avx(x)
45
s = typemin(eltype(x))
@@ -8,23 +9,34 @@
89
s
910
end
1011
for T ∈ (Int32, Int64, Float32, Float64)
12+
@show T, @__LINE__
1113
if T <: Integer
1214
R = T(1):T(100)
1315
x7 = rand(R, 7); y7 = rand(R, 7);
1416
x = rand(R, 127); y = rand(R, 127);
1517
else
1618
x7 = rand(T, 7); y7 = rand(T, 7);
1719
x = rand(T, 127); y = rand(T, 127);
18-
@test vmapreduce(hypot, +, x, y) ≈ mapreduce(hypot, +, x, y)
19-
@test vmapreduce(^, (a,b) -> a + b, x7, y7) ≈ mapreduce(^, (a,b) -> a + b, x7, y7)
20+
if VERSION ≥ v"1.4"
21+
@test vmapreduce(hypot, +, x, y) ≈ mapreduce(hypot, +, x, y)
22+
@test vmapreduce(^, (a,b) -> a + b, x7, y7) ≈ mapreduce(^, +, x7, y7)
23+
else
24+
@test vmapreduce(hypot, +, x, y) ≈ sum(hypot.(x, y))
25+
@test vmapreduce(^, (a,b) -> a + b, x7, y7) ≈ sum(x7 .^ y7)
26+
end
2027
end
2128
@test vreduce(+, x7) ≈ sum(x7)
2229
@test vreduce(+, x) ≈ sum(x)
2330
@test_throws AssertionError vmapreduce(hypot, +, x7, x)
24-
@test vmapreduce(a -> 2a, *, x) ≈ mapreduce(a -> 2a, *, x)
25-
@test vmapreduce(sin, +, x7) ≈ mapreduce(sin, +, x7)
26-
@test vmapreduce(log, +, x) ≈ mapreduce(log, +, x)
27-
@test vmapreduce(abs2, +, x) ≈ mapreduce(abs2, +, x)
31+
if VERSION ≥ v"1.4"
32+
@test vmapreduce(a -> 2a, *, x) ≈ mapreduce(a -> 2a, *, x)
33+
@test vmapreduce(sin, +, x7) ≈ mapreduce(sin, +, x7)
34+
else
35+
@test vmapreduce(a -> 2a, *, x) ≈ prod(2 .* x)
36+
@test vmapreduce(sin, +, x7) ≈ sum(sin.(x7))
37+
end
38+
@test vmapreduce(log, +, x) ≈ sum(log, x)
39+
@test vmapreduce(abs2, +, x) ≈ sum(abs2, x)
2840
@test maximum(x) == vreduce(max, x) == maximum_avx(x)
2941
end
3042

0 commit comments

Comments
 (0)