Skip to content

Commit 8e8fbcc

Browse files
committed
Added prod tests.
1 parent 5ec4989 commit 8e8fbcc

File tree

4 files changed

+38
-8
lines changed

4 files changed

+38
-8
lines changed

src/add_compute.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,9 @@ function setdiffv!(s4::AbstractVector{T}, s3::AbstractVector{T}, s1::AbstractVec
3434
end
3535
function update_deps!(deps::Vector{Symbol}, reduceddeps::Vector{Symbol}, parent::Operation)
3636
mergesetv!(deps, loopdependencies(parent))#, reduceddependencies(parent))
37-
if !(isload(parent) || isconstant(parent)) && parent.instruction.instr (:reduced_add, :reduced_prod, :reduce_to_add, :reduce_to_prod)
37+
if !(isload(parent) || isconstant(parent)) && !isreductcombineinstr(parent)
3838
mergesetv!(reduceddeps, reduceddependencies(parent))
3939
end
40-
#
4140
nothing
4241
end
4342

@@ -139,7 +138,6 @@ function add_reduction_update_parent!(
139138
end
140139
combineddeps = copy(deps); mergesetv!(combineddeps, reduceddeps)
141140
directdependency && pushparent!(vparents, deps, reduceddeps, reductinit)#parent) # deps and reduced deps will not be disjoint
142-
# update_reduction_status!(vparents, combineddeps, name(reductinit))
143141
update_reduction_status!(vparents, reduceddeps, name(reductinit))
144142
# this is the op added by add_compute
145143
op = Operation(length(operations(ls)), reductsym, elementbytes, instr, compute, deps, reduceddeps, vparents)

src/costs.jl

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -224,26 +224,36 @@ const REDUCTION_CLASS = Dict{Symbol,Float64}(
224224
reduction_instruction_class(instr::Symbol) = get(REDUCTION_CLASS, instr, NaN)
225225
reduction_instruction_class(instr::Instruction) = get(REDUCTION_CLASS, instr.instr, NaN)
226226
function reduction_to_single_vector(x::Float64)
227-
x == 1.0 ? :evadd : x == 2.0 ? :evmul : x == 3.0 ? :vor : x == 4.0 ? :vand : x == 5.0 ? :max : x == 6.0 ? :min : throw("Reduction not found.")
227+
# x == 1.0 ? :evadd : x == 2.0 ? :evmul : x == 3.0 ? :vor : x == 4.0 ? :vand : x == 5.0 ? :max : x == 6.0 ? :min : throw("Reduction not found.")
228+
x == 1.0 ? :evadd : x == 2.0 ? :evmul : x == 5.0 ? :max : x == 6.0 ? :min : throw("Reduction not found.")
228229
end
229230
reduction_to_single_vector(x) = reduction_to_single_vector(reduction_instruction_class(x))
230231
function reduction_to_scalar(x::Float64)
231-
x == 1.0 ? :vsum : x == 2.0 ? :vprod : x == 3.0 ? :vany : x == 4.0 ? :vall : x == 5.0 ? :maximum : x == 6.0 ? :minimum : throw("Reduction not found.")
232+
# x == 1.0 ? :vsum : x == 2.0 ? :vprod : x == 3.0 ? :vany : x == 4.0 ? :vall : x == 5.0 ? :maximum : x == 6.0 ? :minimum : throw("Reduction not found.")
233+
x == 1.0 ? :vsum : x == 2.0 ? :vprod : x == 5.0 ? :maximum : x == 6.0 ? :minimum : throw("Reduction not found.")
232234
end
233235
reduction_to_scalar(x) = reduction_to_scalar(reduction_instruction_class(x))
234236
function reduction_scalar_combine(x::Float64)
235-
x == 1.0 ? :reduced_add : x == 2.0 ? :reduced_prod : x == 3.0 ? :reduced_any : x == 4.0 ? :reduced_all : x == 5.0 ? :reduced_max : x == 6.0 ? :reduced_min : throw("Reduction not found.")
237+
# x == 1.0 ? :reduced_add : x == 2.0 ? :reduced_prod : x == 3.0 ? :reduced_any : x == 4.0 ? :reduced_all : x == 5.0 ? :reduced_max : x == 6.0 ? :reduced_min : throw("Reduction not found.")
238+
x == 1.0 ? :reduced_add : x == 2.0 ? :reduced_prod : x == 5.0 ? :reduced_max : x == 6.0 ? :reduced_min : throw("Reduction not found.")
236239
end
237240
reduction_scalar_combine(x) = reduction_scalar_combine(reduction_instruction_class(x))
238241
function reduction_combine_to(x::Float64)
239-
x == 1.0 ? :reduce_to_add : x == 2.0 ? :reduce_to_prod : x == 3.0 ? :reduce_to_any : x == 4.0 ? :reduce_to_all : x == 5.0 ? :reduce_to_max : x == 6.0 ? :reduce_to_min : throw("Reduction not found.")
242+
# x == 1.0 ? :reduce_to_add : x == 2.0 ? :reduce_to_prod : x == 3.0 ? :reduce_to_any : x == 4.0 ? :reduce_to_all : x == 5.0 ? :reduce_to_max : x == 6.0 ? :reduce_to_min : throw("Reduction not found.")
243+
x == 1.0 ? :reduce_to_add : x == 2.0 ? :reduce_to_prod : x == 5.0 ? :reduce_to_max : x == 6.0 ? :reduce_to_min : throw("Reduction not found.")
240244
end
241245
reduction_combine_to(x) = reduction_combine_to(reduction_instruction_class(x))
242246
function reduction_zero(x::Float64)
243-
x == 1.0 ? :zero : x == 2.0 ? :one : x == 3.0 ? :false : x == 4.0 ? :true : x == 5.0 ? :typemin : x == 6.0 ? :typemax : throw("Reduction not found.")
247+
# x == 1.0 ? :zero : x == 2.0 ? :one : x == 3.0 ? :false : x == 4.0 ? :true : x == 5.0 ? :typemin : x == 6.0 ? :typemax : throw("Reduction not found.")
248+
x == 1.0 ? :zero : x == 2.0 ? :one : x == 5.0 ? :typemin : x == 6.0 ? :typemax : throw("Reduction not found.")
244249
end
245250
reduction_zero(x) = reduction_zero(reduction_instruction_class(x))
246251

252+
function isreductcombineinstr(instr::Symbol)
253+
instr (:reduced_add, :reduced_prod, :reduce_to_add, :reduce_to_prod, :reduced_max, :reduced_min, :reduce_to_max, :reduce_to_min)
254+
end
255+
isreductcombineinstr(instr::Instruction) = isreductcombineinstr(instr.instr)
256+
247257
const FUNCTIONSYMBOLS = Dict{Type{<:Function},Instruction}(
248258
typeof(+) => :(+),
249259
typeof(SIMDPirates.vadd) => :(+),

src/operations.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,7 @@ name(op::Operation) = op.variable
153153
instruction(op::Operation) = op.instruction
154154
isreductionzero(op::Operation, instr::Symbol) = op.instruction.mod === REDUCTION_ZERO[instr]
155155
refname(op::Operation) = op.ref.ptr
156+
isreductcombineinstr(op::Operation) = iscompute(op) && isreductcombineinstr(instruction(op))
156157
"""
157158
mvar = mangledvar(op)
158159

test/runtests.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1009,6 +1009,20 @@ end
10091009
end
10101010
s
10111011
end
1012+
function myprodavx(x)
1013+
p = one(eltype(x))
1014+
@avx for i eachindex(x)
1015+
p *= x[i]
1016+
end
1017+
p
1018+
end
1019+
function myprod_avx(x)
1020+
p = one(eltype(x))
1021+
@_avx for i eachindex(x)
1022+
p *= x[i]
1023+
end
1024+
p
1025+
end
10121026

10131027
function test_bit_shift(counter)
10141028
accu = zero(first(counter))
@@ -1140,13 +1154,20 @@ end
11401154
@test q1 q2
11411155
@test sum(q2; dims=3) ones(T,ni,nj)
11421156

1157+
x .+= 0.545;
11431158
s = sum(x)
11441159
@test s mysumavx(x)
11451160
@test s mysum_avx(x)
1161+
p = prod(x)
1162+
@test p myprodavx(x)
1163+
@test p myprod_avx(x)
11461164
r = T == Float32 ? (Int32(-10):Int32(234)) : -10:234
11471165
s = sum(r)
11481166
@test s mysumavx(r)
11491167
@test s mysum_avx(r)
1168+
p = prod(r)
1169+
@test p myprodavx(r)
1170+
@test p myprod_avx(r)
11501171

11511172
@test test_bit_shift(r) == test_bit_shiftavx(r)
11521173
@test test_bit_shift(r) == test_bit_shift_avx(r)

0 commit comments

Comments
 (0)