Skip to content

Commit 97daeca

Browse files
committed
Fixed a bug where loops of the form a[i] = f(a[i]) weren't handled correctly.
1 parent 58db05b commit 97daeca

File tree

5 files changed

+69
-27
lines changed

5 files changed

+69
-27
lines changed

src/determinestrategy.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,12 @@ function determine_unroll_factor(
114114

115115
# The strategy is to use an unroll factor of 1, unless there appear to be loop-carried dependencies (i.e., num_reductions > 0)
116116
# The assumption here is that unrolling provides no real benefit, unless it is needed to enable OOO execution by breaking up these dependency chains
117-
num_reductions = sum(isreduction, operations(ls))
117+
num_reductions = 0#sum(isreduction, operations(ls))
118+
for op ∈ operations(ls)
119+
if isreduction(op) & iscompute(op)
120+
num_reductions += 1
121+
end
122+
end
118123
# @show num_reductions
119124
if iszero(num_reductions) # the 4 is a hack, based on the idea that there is some cost to moving through columns
120125
return length(order) == 1 ? 1 : 4

src/graphs.jl

Lines changed: 41 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,13 @@ struct LoopSet
8888
# ref_to_sym_aliases::Dict{ArrayReference,Symbol}
8989
end
9090

91+
# function op_to_ref(ls::LoopSet, op::Operation)
92+
# s = op.variable
93+
# id = findfirst(ls.syms_aliasing_regs)
94+
# @assert id !== nothing
95+
# ls.refs_aliasing_syms[id]
96+
# end
97+
9198
function includesarray(ls::LoopSet, array::Symbol)
9299
for (a,i) ∈ ls.includedarrays
93100
a === array && return i
@@ -235,7 +242,7 @@ function add_load!(
235242
:getindex, memload, loopdependencies(ref),
236243
NODEPENDENCY, NOPARENTS, ref
237244
)
238-
add_vptr!(ls, indexed, identifier(op))
245+
add_vptr!(ls, ref.array, identifier(op))
239246
pushop!(ls, op, var)
240247
end
241248
function add_load_ref!(ls::LoopSet, var::Symbol, ex::Expr, elementbytes::Int = 8)
@@ -311,12 +318,16 @@ function maybe_cse_load!(ls::LoopSet, expr::Expr, elementbytes::Int = 8)
311318
else
312319
return add_operation!(ls, gensym(:temporary), expr, elementbytes)
313320
end
314-
ref = ArrayReference( ex.args[1+offset], @view(ex.args[2+offset:end]) )::ArrayReference
321+
ref = ArrayReference(
322+
expr.args[1+offset],
323+
@view(expr.args[2+offset:end]),
324+
Ref(false)
325+
)::ArrayReference
315326
id = findfirst(r -> r == ref, ls.refs_aliasing_syms)
316327
if id === nothing
317-
add_load!( ls, gensym(:temporary), array, args, elementbytes )
328+
add_load!( ls, gensym(:temporary), ref, elementbytes )
318329
else
319-
ls.syms_aliasing_refs[id]
330+
getop(ls, ls.syms_aliasing_refs[id])
320331
end
321332
# id = includesarray(ls, array)
322333
# if id > 0
@@ -371,7 +382,7 @@ function add_compute!(ls::LoopSet, var::Symbol, ex::Expr, elementbytes::Int = 8,
371382
# op = Operation( length(operations(ls)), var, elementbytes, instr, compute )
372383
reduction = false
373384
for arg ∈ args
374-
if arg === var
385+
if var === arg
375386
reduction = true
376387
add_reduction!(parents, deps, reduceddeps, ls, arg, elementbytes)
377388
elseif ref == arg
@@ -389,18 +400,20 @@ function add_compute!(ls::LoopSet, var::Symbol, ex::Expr, elementbytes::Int = 8,
389400
end
390401
end
391402
function add_store!(
392-
ls::LoopSet, indexed::Symbol, var::Symbol, indices::AbstractVector, elementbytes::Int = 8
403+
ls::LoopSet, var::Symbol, ref::ArrayReference, elementbytes::Int = 8
393404
)
394405
parent = getop(ls, var)
395-
op = Operation( length(operations(ls)), indexed, elementbytes, :setindex!, memstore, indices, reduceddependencies(parent), [parent] )
396-
add_vptr!(ls, indexed, identifier(op))
406+
op = Operation( length(operations(ls)), ref.array, elementbytes, :setindex!, memstore, loopdependencies(ref), reduceddependencies(parent), [parent], ref )
407+
add_vptr!(ls, ref.array, identifier(op))
397408
pushop!(ls, op, var)
398409
end
399410
function add_store_ref!(ls::LoopSet, var::Symbol, ex::Expr, elementbytes::Int = 8)
400-
add_store!(ls, ex.args[1], var, @view(ex.args[2:end]), elementbytes)
411+
ref = ref_from_ref(ex)
412+
add_store!(ls, var, ref, elementbytes)
401413
end
402414
function add_store_setindex!(ls::LoopSet, ex::Expr, elementbytes::Int = 8)
403-
add_store!(ls, ex.args[2], ex.args[3], @view(ex.args[4:end]), elementbytes)
415+
ref = ref_from_setindex(ex)
416+
add_store!(ls, var, ref, elementbytes)
404417
end
405418
# add operation assigns X to var
406419
function add_operation!(
@@ -418,6 +431,21 @@ function add_operation!(
418431
throw("Expression not recognized:\n$x")
419432
end
420433
end
434+
function add_operation!(
435+
ls::LoopSet, LHS_sym::Symbol, RHS::Expr, LHS_ref::ArrayReference, elementbytes::Int = 8
436+
)
437+
if RHS.head === :ref# || (RHS.head === :call && first(RHS.args) === :getindex)
438+
add_load!(ls, LHS_sym, LHS_ref, elementbytes)
439+
elseif RHS.head === :call
440+
if first(RHS.args) === :getindex
441+
add_load!(ls, LHS_sym, LHS_ref, elementbytes)
442+
else
443+
add_compute!(ls, LHS_sym, RHS, elementbytes, LHS_ref)
444+
end
445+
else
446+
throw("Expression not recognized:\n$x")
447+
end
448+
end
421449
function Base.push!(ls::LoopSet, ex::Expr, elementbytes::Int = 8)
422450
if ex.head === :call
423451
finex = first(ex.args)::Symbol
@@ -446,7 +474,9 @@ function Base.push!(ls::LoopSet, ex::Expr, elementbytes::Int = 8)
446474
ref = ArrayReference(LHS)
447475
id = findfirst(r -> r == ref, ls.refs_aliasing_syms)
448476
lrhs = id === nothing ? gensym(:RHS) : ls.syms_aliasing_refs[id]
449-
add_operation!(ls, lrhs, RHS, elementbytes, ref)
477+
# we pass ref, so it can compare references within RHS, and realize
478+
# they equal lrhs
479+
add_operation!(ls, lrhs, RHS, ref, elementbytes)
450480
end
451481
add_store_ref!(ls, lrhs, LHS, elementbytes)
452482
else

src/lowering.jl

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,13 @@
55
function mem_offset(op::Operation, incr::Int = 0)
66
@assert accesses_memory(op) "Computing memory offset only makes sense for operations that access memory."
77
ret = Expr(:tuple, )
8-
deps = op.dependencies
8+
deps = op.ref.ref
99
if incr == 0
1010
append!(ret.args, deps)
1111
else
12-
push!(ret.args, Expr(:call, :+, first(deps), incr))
13-
for n ∈ 2:length(deps)
12+
dep = first(deps)
13+
push!(ret.args, dep isa Symbol ? Expr(:call, :+, dep, incr) : dep + incr)
14+
for n ∈ 2:length(deps)
1415
push!(ret.args, deps[n])
1516
end
1617
end
@@ -19,7 +20,7 @@ end
1920
function mem_offset(op::Operation, incr::Int, unrolled::Symbol)
2021
@assert accesses_memory(op) "Computing memory offset only makes sense for operations that access memory."
2122
ret = Expr(:tuple, )
22-
deps = op.dependencies
23+
deps = op.ref.ref
2324
if incr == 0
2425
append!(ret.args, deps)
2526
else
@@ -55,7 +56,7 @@ function lower_load_scalar!(
5556
if suffix !== nothing
5657
var = Symbol(var, :_, suffix)
5758
end
58-
ptr = Symbol("##vptr##_", first(op.reduced_deps))
59+
ptr = refname(op)
5960
push!(q.args, Expr(:(=), Symbol("##", var), Expr(:call, lv(:load), ptr, mem_offset(op))))
6061
nothing
6162
end
@@ -69,7 +70,7 @@ function lower_load_unrolled!(
6970
if suffix !== nothing
7071
var = Symbol(var, :_, suffix)
7172
end
72-
ptr = Symbol("##vptr##_", first(op.reduced_deps))
73+
ptr = refname(op)
7374
val = Expr(:call, Expr(:curly, :Val, W))
7475
if first(loopdependencies(op)) === unrolled # vload
7576
for u ∈ 0:U-1
@@ -157,7 +158,7 @@ function lower_store_reduction!(
157158
if suffix !== nothing
158159
var = Symbol(var, :_, suffix)
159160
end
160-
ptr = Symbol("##vptr##_", op.variable)
161+
ptr = refname(op)
161162
# need to find out reduction type
162163
instr = first(parents(op)).instruction
163164
reduce_expr!(q, var, instr, U) # assigns reduction to storevar
@@ -174,7 +175,7 @@ function lower_store_scalar!(
174175
if suffix !== nothing
175176
var = Symbol(var, :_, suffix)
176177
end
177-
ptr = Symbol("##vptr##_", op.variable)
178+
ptr = refname(op)
178179
push!(q.args, Expr(:call, lv(:store!), ptr, Symbol("##", var), mem_offset(op)))
179180
nothing
180181
end
@@ -188,7 +189,7 @@ function lower_store_unrolled!(
188189
if suffix !== nothing
189190
var = Symbol(var, :_, suffix)
190191
end
191-
ptr = Symbol("##vptr##_", op.variable)
192+
ptr = refname(op)
192193
if first(loopdependencies(op)) === unrolled # vstore!
193194
for u ∈ 0:U-1
194195
instrcall = Expr(:call,lv(:vstore!), ptr, Symbol("##",var,:_,u), mem_offset(op, u*W))

src/operations.jl

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,12 @@ end
2323

2424
Base.:(==)(x::ArrayReference, y::ArrayReference) = isequal(x, y)
2525

26-
function ref_from_ref(ex::Expr)
27-
ArrayReference( ex.args[1], @view(ex.args[2:end]), Ref(false) )
28-
end
29-
function ref_from_getindex(ex::Expr)
30-
ArrayReference( ex.args[2], @view(ex.args[3:end]), Ref(false) )
26+
function ref_from_expr(ex, offset1::Int = 0, offset2 = 0)
27+
ArrayReference( ex.args[1 + offset1], @view(ex.args[2 + offset2:end]), Ref(false) )
3128
end
29+
ref_from_ref(ex::Expr) = ref_from_expr(ex, 0, 0)
30+
ref_from_getindex(ex::Expr) = ref_from_expr(ex, 1, 1)
31+
ref_from_setindex(ex::Expr) = ref_from_expr(ex, 1, 2)
3232
function ArrayReference(ex::Expr)
3333
ex.head === :ref ? ref_from_ref(ex) : ref_from_getindex(ex)
3434
end
@@ -46,7 +46,7 @@ Base.:(==)(x::ArrayReference, y) = false
4646

4747

4848
# Avoid memory allocations by accessing this
49-
const NOTAREFERENCE = ArrayReference(Symbol(""), Union{Symbol,Int}[])
49+
const NOTAREFERENCE = ArrayReference(Symbol(""), Union{Symbol,Int}[], Ref(false))
5050

5151
@enum OperationType begin
5252
constant
@@ -128,6 +128,9 @@ identifier(op::Operation) = op.identifier + 1
128128
name(op::Operation) = op.variable
129129
instruction(op::Operation) = op.instruction
130130

131+
refname(op::Operation) = Symbol("##vptr##_", op.ref.array)
132+
133+
131134
"""
132135
Returns `0` if the op is the declaration of the constant outerreduction variable.
133136
Returns `n`, where `n` is the constant declaration's index among parents(op), if op is an outer reduction.

test/runtests.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ function mygemmavx!(C, A, B)
7777
end
7878

7979
M, K, N = rand(70:81, 3);
80+
M, K, N = 72, 75, 71;
8081
C = Matrix{Float64}(undef, M, N); A = randn(M, K); B = randn(K, N);
8182
C2 = similar(C);
8283
mygemmavx!(C, A, B)
@@ -328,6 +329,7 @@ lscolsum = LoopVectorization.LoopSet(colsumq);
328329
lscolsum
329330
lscolsum.operations
330331

332+
LoopVectorization.choose_order(lscolsum)
331333
@test LoopVectorization.choose_order(lscolsum) == (Symbol[:j,:i], 4, -1)
332334

333335
function mycolsum!(x, A)
@@ -370,6 +372,7 @@ varq = :(for j ∈ eachindex(s²), i ∈ 1:size(A,2)
370372
s²[j] += δ*δ
371373
end)
372374
lsvar = LoopVectorization.LoopSet(varq);
375+
LoopVectorization.choose_order(lsvar)
373376
@test LoopVectorization.choose_order(lsvar) == (Symbol[:j,:i], 4, -1)
374377

375378
function myvar!(s², A, x̄)

0 commit comments

Comments
 (0)