Skip to content

Commit 1f3e9d2

Browse files
committed
More tweaks to correctly identifying family structure.
1 parent b356d5d commit 1f3e9d2

File tree

7 files changed

+109
-37
lines changed

7 files changed

+109
-37
lines changed

src/LoopVectorization.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ using VectorizationBase: REGISTER_SIZE, REGISTER_COUNT, extract_data, num_vector
88
AbstractColumnMajorStridedPointer, AbstractRowMajorStridedPointer, AbstractSparseStridedPointer, AbstractStaticStridedPointer,
99
PackedStridedPointer, SparseStridedPointer, RowMajorStridedPointer, StaticStridedPointer, StaticStridedStruct,
1010
maybestaticfirst, maybestaticlast, scalar_less, scalar_greater
11-
using SIMDPirates: VECTOR_SYMBOLS, evadd, evsub, evmul, evfdiv, vrange, reduced_add, reduced_prod, reduce_to_add, reduce_to_prod,
11+
using SIMDPirates: VECTOR_SYMBOLS, evadd, evsub, evmul, evfdiv, vrange, reduced_add, reduced_prod, reduce_to_add, reduce_to_prod, vsum, vprod, vmaximum, vminimum,
1212
sizeequivalentfloat, sizeequivalentint, vadd!, vsub!, vmul!, vfdiv!, vfmadd!, vfnmadd!, vfmsub!, vfnmsub!,
1313
vfmadd231, vfmsub231, vfnmadd231, vfnmsub231, sizeequivalentfloat, sizeequivalentint, #prefetch,
1414
vmullog2, vmullog10, vdivlog2, vdivlog10, vmullog2add!, vmullog10add!, vdivlog2add!, vdivlog10add!, vfmaddaddone

src/add_compute.jl

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -48,25 +48,30 @@ function pushparent!(mpref::ArrayReferenceMetaPosition, parent::Operation)
4848
pushparent!(mpref.parents, mpref.loopdependencies, mpref.reduceddeps, parent)
4949
end
5050
function add_parent!(
51-
parents::Vector{Operation}, deps::Vector{Symbol}, reduceddeps::Vector{Symbol}, ls::LoopSet, var, elementbytes::Int, position::Int
51+
vparents::Vector{Operation}, deps::Vector{Symbol}, reduceddeps::Vector{Symbol}, ls::LoopSet, var, elementbytes::Int, position::Int
5252
)
5353
parent = if var isa Symbol
54-
getop(ls, var, elementbytes)
54+
opp = getop(ls, var, elementbytes)
55+
if iscompute(opp) && instruction(opp).instr === :identity && length(loopdependencies(opp)) < position && isone(length(parents(opp))) && name(opp) === name(first(parents(opp)))
56+
first(parents(opp))
57+
else
58+
opp
59+
end
5560
elseif var isa Expr #CSE candidate
5661
add_operation!(ls, gensym(:temporary), var, elementbytes, position)
5762
else # assumed constant
5863
add_constant!(ls, var, elementbytes)
5964
# add_constant!(ls, var, deps, gensym(:loopredefconst), elementbytes)
6065
end
61-
pushparent!(parents, deps, reduceddeps, parent)
62-
end
63-
function add_reduction!(
64-
parents::Vector{Operation}, deps::Vector{Symbol}, reduceddeps::Vector{Symbol}, ls::LoopSet, var::Symbol, elementbytes::Int
65-
)
66-
get!(ls.opdict, var) do
67-
add_constant!(ls, var, elementbytes)
68-
end
66+
pushparent!(vparents, deps, reduceddeps, parent)
6967
end
68+
# function add_reduction!(
69+
# vparents::Vector{Operation}, deps::Vector{Symbol}, reduceddeps::Vector{Symbol}, ls::LoopSet, var::Symbol, elementbytes::Int
70+
# )
71+
# get!(ls.opdict, var) do
72+
# add_constant!(ls, var, elementbytes)
73+
# end
74+
# end
7075
function search_tree(opv::Vector{Operation}, var::Symbol) # relies on cycles being forbidden
7176
for opp ∈ opv
7277
name(opp) === var && return true
@@ -118,13 +123,6 @@ function add_reduced_deps!(op::Operation, reduceddeps::Vector{Symbol})
118123
nothing
119124
end
120125

121-
# function substitute_op_in_parents!(
122-
# vparents::Vector{Operation}, replacer::Operation, replacee::Operation, reduceddeps::Vector{Symbol}
123-
# )
124-
# @show replacer replacee
125-
# #
126-
# substitute_op_in_parents_recurse!(vparents, replacer, replacee)
127-
# end
128126
function substitute_op_in_parents!(
129127
vparents::Vector{Operation}, replacer::Operation, replacee::Operation, reduceddeps::Vector{Symbol}
130128
)
@@ -188,7 +186,7 @@ function add_reduction_update_parent!(
188186
if instr.instr ∈ (:-, :vsub!, :vsub, :/, :vfdiv!, :vfidiv!)
189187
update_deps!(deps, reduceddeps, reductinit)#parent) # deps and reduced deps will not be disjoint
190188
end
191-
elseif !isouterreduction
189+
elseif !isouterreduction && reductinit !== parent
192190
substitute_op_in_parents!(vparents, reductinit, parent, reduceddeps)
193191
end
194192
update_reduction_status!(vparents, reduceddeps, name(reductinit))
@@ -228,7 +226,8 @@ function add_compute!(
228226
for (ind,arg) ∈ enumerate(args)
229227
if var === arg
230228
reduction_ind = ind
231-
add_reduction!(vparents, deps, reduceddeps, ls, arg, elementbytes)
229+
# add_reduction!(vparents, deps, reduceddeps, ls, arg, elementbytes)
230+
getop(ls, arg, elementbytes)
232231
elseif arg isa Expr
233232
isref, argref = tryrefconvert(ls, arg, elementbytes, varname(mpref))
234233
if isref
@@ -270,7 +269,8 @@ function add_compute!(
270269
parent = ls.opdict[var]
271270
setdiffv!(reduceddeps, deps, loopdependencies(parent))
272271
# parent = getop(ls, var, elementbytes)
273-
if length(reduceddeps) == 0
272+
# if length(reduceddeps) == 0
273+
if all(!in(deps), reduceddeps)
274274
insert!(vparents, reduction_ind, parent)
275275
mergesetv!(deps, loopdependencies(parent))
276276
op = Operation(length(operations(ls)), var, elementbytes, instr, compute, deps, reduceddeps, vparents)

src/costs.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,8 @@ const COST = Dict{Symbol,InstructionCost}(
138138
:evsub => InstructionCost(4,0.5),
139139
:evmul => InstructionCost(4,0.5),
140140
:evfdiv => InstructionCost(13,4.0,-2.0),
141+
:vsum => InstructionCost(6,2.0),
142+
:vprod => InstructionCost(6,2.0),
141143
:reduced_add => InstructionCost(4,0.5),# ignoring reduction part of cost, might be nop
142144
:reduced_prod => InstructionCost(4,0.5),# ignoring reduction part of cost, might be nop
143145
:reduce_to_add => InstructionCost(0,0.0,0.0,0),
@@ -307,6 +309,10 @@ reduction_to_single_vector(x) = reduction_to_single_vector(reduction_instruction
307309
# x == 1.0 ? :vsum : x == 2.0 ? :vprod : x == 5.0 ? :maximum : x == 6.0 ? :minimum : throw("Reduction not found.")
308310
# end
309311
# reduction_to_scalar(x) = reduction_to_scalar(reduction_instruction_class(x))
312+
function reduction_to_scalar(x::Float64)
313+
x == ADDITIVE_IN_REDUCTIONS ? :vsum : x == MULTIPLICATIVE_IN_REDUCTIONS ? :vprod : x == MAX ? :vmaximum : x == MIN ? :vminimum : throw("Reduction not found.")
314+
end
315+
reduction_to_scalar(x) = reduction_to_scalar(reduction_instruction_class(x))
310316
function reduction_scalar_combine(x::Float64)
311317
# x == 1.0 ? :reduced_add : x == 2.0 ? :reduced_prod : x == 3.0 ? :reduced_any : x == 4.0 ? :reduced_all : x == 5.0 ? :reduced_max : x == 6.0 ? :reduced_min : throw("Reduction not found.")
312318
x == ADDITIVE_IN_REDUCTIONS ? :reduced_add : x == MULTIPLICATIVE_IN_REDUCTIONS ? :reduced_prod : x == MAX ? :reduced_max : x == MIN ? :reduced_min : throw("Reduction not found.")

src/lower_compute.jl

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,39 @@ function parent_unroll_status(op::Operation, u₁loop::Symbol, u₂loop::Symbol,
3636
parents_u₁syms, parents_u₂syms
3737
end
3838

39+
# """
40+
# Requires a parents_op argument, because it may `===` parents(op), due to previous divergence, e.g. to handle unrolling.
41+
# """
42+
# function isreducingidentity!(q::Expr, op::Operation, parents_op::Vector{Operation}, U::Int, u₁loop::Symbol, u₂loop::Symbol, vectorized::Symbol, suffix)
43+
# vparents = copy(parents_op) # don't mutate the original!
44+
# for (i,opp) ∈ enumerate(parents_op)
45+
# @show opp vectorized ∈ loopdependencies(opp), vectorized ∈ reducedchildren(opp) # must reduce
46+
# @show vectorized, loopdependencies(opp), reducedchildren(opp) # must reduce
47+
# if vectorized ∈ loopdependencies(opp) || vectorized ∈ reducedchildren(opp) # must reduce
48+
# loopdeps = [l for l ∈ loopdependencies(opp) if l !== vectorized]
49+
# @show opp
50+
# reductinstruct = reduction_to_scalar(instruction(opp))
51+
52+
# reducedparent = Operation(
53+
# opp.identifier, gensym(opp.variable), opp.elementbytes, Instruction(:LoopVectorization, reductinstruct), opp.node_type,
54+
# loopdeps, opp.reduced_deps, opp.parents, opp.ref, opp.reduced_children
55+
# )
56+
# pname, pu₁, pu₂ = variable_name_and_unrolled(opp, u₁loop, u₂loop, suffix)
57+
# rpname, rpu₁, rpu₂ = variable_name_and_unrolled(reducedparent, u₁loop, u₂loop, suffix)
58+
# @assert pu₁ == rpu₁ && pu₂ == rpu₂
59+
# if rpu₁
60+
# for u ∈ 0:U-1
61+
# push!(q.args, Expr(:(=), Symbol(rpname,u), Expr(:call, lv(reductinstruct), Symbol(pname,u))))
62+
# end
63+
# else
64+
# push!(q.args, Expr(:(=), rpname, Expr(:call, lv(reductinstruct), pname)))
65+
# end
66+
# vparents[i] = reducedparent
67+
# end
68+
# end
69+
# vparents
70+
# end
71+
3972
function lower_compute!(
4073
q::Expr, op::Operation, vectorized::Symbol, u₁loop::Symbol, u₂loop::Symbol, U::Int,
4174
suffix::Union{Nothing,Int}, mask::Union{Nothing,Symbol,Unsigned} = nothing,
@@ -65,7 +98,7 @@ function lower_compute!(
6598
# end
6699
# unrollsym = isunrolled_sym(op, unrolled)
67100
if !opunrolled && any(parents_u₁syms) # TODO: Clean up this mess, refactor the naming code, putting it in one place and have everywhere else use it for easy equivalence.
68-
parents_op = copy(parents_op)
101+
parents_op = copy(parents_op) # don't mutate the original!
69102
for i ∈ eachindex(parents_u₁syms)
70103
parents_u₁syms[i] || continue
71104
parents_u₁syms[i] = false
@@ -115,7 +148,13 @@ function lower_compute!(
115148
end
116149
end
117150
# @show instr.instr
118-
maskreduct = mask !== nothing && isreduct && vectorized ∈ reduceddependencies(op) #any(opp -> opp.variable === var, parents_op)
151+
reduceddeps = reduceddependencies(op)
152+
vecinreduceddeps = isreduct && vectorized ∈ reduceddeps
153+
maskreduct = mask !== nothing && vecinreduceddeps #any(opp -> opp.variable === var, parents_op)
154+
# if vecinreduceddeps && vectorized ∉ loopdependencies(op) # screen parent opps for those needing a reduction to scalar
155+
# # parents_op = reduce_vectorized_parents!(q, op, parents_op, U, u₁loop, u₂loop, vectorized, suffix)
156+
# isreducingidentity!(q, op, parents_op, U, u₁loop, u₂loop, vectorized, suffix) && return
157+
# end
119158
# if a parent is not unrolled, the compiler should handle broadcasting CSE.
120159
# because unrolled/tiled parents result in an unrolled/tiled dependency,
121160
# we handle both the tiled and untiled case here.

src/lowering.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -424,7 +424,7 @@ function lower(ls::LoopSet, us::UnrollSpecification)
424424
end
425425

426426
function lower(ls::LoopSet, order, u₁loop, u₂loop, vectorized, u₁, u₂)
427-
fillorder!(ls, order, u₁loop, u₂loop, u₂ != -1)
427+
fillorder!(ls, order, u₁loop, u₂loop, u₂ != -1, vectorized)
428428
q = lower(ls, UnrollSpecification(ls, u₁loop, u₂loop, vectorized, u₁, u₂))
429429
iszero(length(ls.opdict)) && pushfirst!(q.args, Expr(:meta, :inline))
430430
q

src/operation_evaluation_order.jl

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,31 @@
1515
# function iterate()
1616

1717
# end
18-
function isnopidentity(op::Operation, u₁loop::Symbol, u₂loop::Symbol, suffix)
18+
19+
function dependent_outer_reducts(ls::LoopSet, op)
20+
for i ∈ ls.outer_reductions
21+
search_tree(parents(operations(ls)[i]), name(op)) && return true
22+
end
23+
false
24+
end
25+
26+
function isnopidentity(ls::LoopSet, op::Operation, u₁loop::Symbol, u₂loop::Symbol, vectorized::Symbol, suffix)
1927
parents_op = parents(op)
2028
if iscompute(op) && instruction(op).instr === :identity && name(first(parents_op)) === name(op) && isone(length(parents_op))
2129
mvar, u₁unrolledsym, u₂unrolledsym = variable_name_and_unrolled(op, u₁loop, u₂loop, suffix)
2230
parents_u₁syms, parents_u₂syms = parent_unroll_status(op, u₁loop, u₂loop, suffix)
2331
if (u₁unrolledsym == first(parents_u₁syms)) && ((!isnothing(suffix)) == parents_u₂syms[1])
24-
true
32+
#TODO: identifier(first(parents_op)) ∉ ls.outer_reductions is going to miss a lot of cases
33+
#Should probably replace that with `DVec` (demoting Vec) types, that demote to scalar.
34+
if (vectorized ∈ loopdependencies(first(parents_op)) && vectorized ∉ loopdependencies(op)) && !dependent_outer_reducts(ls, op)
35+
op.instruction = reduction_to_scalar(instruction(first(parents_op)))
36+
op.mangledvariable = gensym(op.mangledvariable)
37+
false
38+
else
39+
true
40+
end
41+
else
42+
false
2543
end
2644
else
2745
false
@@ -30,7 +48,6 @@ end
3048

3149
function set_upstream_family!(adal::Vector{T}, op::Operation, val::T, ld::Vector{Symbol}, id::Int) where {T}
3250
adal[identifier(op)] == val && return # must already have been set
33-
# @show op
3451
if ld != loopdependencies(op) || id == identifier(op)
3552
(adal[identifier(op)] = val)
3653
end
@@ -41,26 +58,27 @@ function set_upstream_family!(adal::Vector{T}, op::Operation, val::T, ld::Vector
4158
end
4259

4360
function addoptoorder!(
44-
lo::LoopOrder, included_vars::Vector{Bool}, place_after_loop::Vector{Bool}, op::Operation, loopsym::Symbol, _n::Int, u₁loop::Symbol, u₂loop::Symbol, loopistiled::Bool
61+
ls::LoopSet, included_vars::Vector{Bool}, place_after_loop::Vector{Bool}, op::Operation,
62+
loopsym::Symbol, _n::Int, u₁loop::Symbol, u₂loop::Symbol, vectorized::Symbol, loopistiled::Bool
4563
)
64+
lo = ls.loop_order
4665
id = identifier(op)
4766
included_vars[id] && return nothing
4867
loopsym ∈ loopdependencies(op) || return nothing
4968
for opp ∈ parents(op) # ensure parents are added first
50-
addoptoorder!(lo, included_vars, place_after_loop, opp, loopsym, _n, u₁loop, u₂loop, loopistiled)
69+
addoptoorder!(ls, included_vars, place_after_loop, opp, loopsym, _n, u₁loop, u₂loop, vectorized, loopistiled)
5170
end
5271
included_vars[id] && return nothing
5372
included_vars[id] = true
5473
isunrolled = (u₁loop ∈ loopdependencies(op)) + 1
5574
istiled = u₂loop ∈ loopdependencies(op)
5675
# optype = Int(op.node_type) + 1
5776
after_loop = place_after_loop[id] + 1
58-
# @show place_after_loop[id], op
5977
if !isloopvalue(op)
6078
if istiled
61-
isnopidentity(op, u₁loop, u₂loop, 0) || push!(lo[isunrolled,2,after_loop,_n], op)
79+
isnopidentity(ls, op, u₁loop, u₂loop, vectorized, 0) || push!(lo[isunrolled,2,after_loop,_n], op)
6280
else
63-
isnopidentity(op, u₁loop, u₂loop, nothing) || push!(lo[isunrolled,1,after_loop,_n], op)
81+
isnopidentity(ls, op, u₁loop, u₂loop, vectorized, nothing) || push!(lo[isunrolled,1,after_loop,_n], op)
6482
end
6583
end
6684
# isloopvalue(op) || push!(lo[isunrolled,istiled,after_loop,_n], op)
@@ -69,7 +87,7 @@ function addoptoorder!(
6987
nothing
7088
end
7189

72-
function fillorder!(ls::LoopSet, order::Vector{Symbol}, unrolled::Symbol, tiled::Symbol, loopistiled::Bool)
90+
function fillorder!(ls::LoopSet, order::Vector{Symbol}, u₁loop::Symbol, u₂loop::Symbol, loopistiled::Bool, vectorized::Symbol)
7391
lo = ls.loop_order
7492
resize!(lo, length(ls.loopsymbols))
7593
ro = lo.loopnames # reverse order; will have same order as lo
@@ -85,7 +103,7 @@ function fillorder!(ls::LoopSet, order::Vector{Symbol}, unrolled::Symbol, tiled:
85103
ro[_n] = loopsym = order[n]
86104
#loopsym = order[n]
87105
for op ∈ ops
88-
addoptoorder!( lo, included_vars, place_after_loop, op, loopsym, _n, unrolled, tiled, loopistiled )
106+
addoptoorder!( ls, included_vars, place_after_loop, op, loopsym, _n, u₁loop, u₂loop, vectorized, loopistiled )
89107
end
90108
end
91109
end

test/gemm.jl

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
@testset "GEMM" begin
2-
# using LoopVectorization, LinearAlgebra, Test; T = Float64
2+
# using LoopVectorization, LinearAlgebra, Test; T = Float64
33
Unum, Tnum = LoopVectorization.VectorizationBase.REGISTER_COUNT == 16 ? (3, 4) : (5, 5)
44
AmulBtq1 = :(for m ∈ 1:size(A,1), n ∈ 1:size(B,2)
55
C[m,n] = zeroB
@@ -102,8 +102,17 @@
102102
ΔCₘₙ += A[m,k] * B[k,n]
103103
end
104104
C[m,n] += ΔCₘₙ * factor
105-
end)
105+
end);
106106
lsAmuladd = LoopVectorization.LoopSet(Amuladdq);
107+
Atmuladdq = :(for m ∈ 1:size(A,2), n ∈ 1:size(B,2)
108+
ΔCₘₙ = zero(eltype(C))
109+
for k ∈ 1:size(A,1)
110+
ΔCₘₙ += A[k,m] * B[k,n]
111+
end
112+
C[m,n] += ΔCₘₙ * factor
113+
end);
114+
lsAtmuladd = LoopVectorization.LoopSet(Atmuladdq);
115+
LoopVectorization.lower(lsAtmuladd, 2, 2)
107116
# lsAmuladd.operations
108117
# LoopVectorization.loopdependencies.(lsAmuladd.operations)
109118
# LoopVectorization.reduceddependencies.(lsAmuladd.operations)
@@ -520,7 +529,7 @@
520529
C[m,n] = Cmn_hi
521530
end
522531
end
523-
532+
524533
function threegemms!(Ab, Bb, Cb, A, B, C)
525534
M, N = size(Cb); K = size(B,1)
526535
@avx for m in 1:M, k in 1:K, n in 1:N

0 commit comments

Comments
 (0)