Skip to content

Commit 3b33759

Browse files
committed
Factor out check for ref already being in ref_aliasing_syms.
1 parent dc97301 commit 3b33759

File tree

3 files changed

+13
-7
lines changed

3 files changed

+13
-7
lines changed

src/add_loads.jl

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,22 @@
1-
function add_load!(ls::LoopSet, op::Operation, actualarray::Bool = true, broadcast::Bool = false)
2-
@assert isload(op)
3-
ref = op.ref
1+
function maybeaddref!(ls::LoopSet, op, ref)
42
id = findfirst(r -> r == ref, ls.refs_aliasing_syms)
53
# try to CSE
6-
if id === nothing
4+
if isnothing(id)
75
push!(ls.syms_aliasing_refs, name(op))
86
push!(ls.refs_aliasing_syms, ref)
7+
0
98
else
9+
id
10+
end
11+
end
12+
13+
function add_load!(ls::LoopSet, op::Operation, actualarray::Bool = true, broadcast::Bool = false)
14+
@assert isload(op)
15+
ref = op.ref
16+
if (id = maybeaddref!(ls, op, ref)) > 0 # try to CSE
1017
opp = ls.opdict[ls.syms_aliasing_refs[id]] # throw an error if not found.
1118
return isstore(opp) ? getop(ls, first(parents(opp))) : opp
12-
end
19+
end
1320
add_vptr!(ls, op.ref.ref.array, vptr(op), actualarray, broadcast)
1421
pushop!(ls, op, name(op))
1522
end

src/add_stores.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ function cse_store!(ls::LoopSet, op::Operation)
77
ls.operations[id] = op
88
ls.opdict[op.variable] = op
99
end
10-
function add_store!(ls::LoopSet, op::Operation, add_pvar::Bool = name(first(parents(op))) ls.syms_aliasing_refs)
10+
function add_store!(ls::LoopSet, op::Operation, add_pvar::Bool = !any(r -> r == op.ref, ls.refs_aliasing_syms))
1111
@assert isstore(op)
1212
if add_pvar
1313
push!(ls.syms_aliasing_refs, name(first(parents(op))))

src/split_loops.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,6 @@ function lower_and_split_loops(ls::LoopSet, inline::Int)
101101
# U_1 = T_1 = U_2 = T_2 = 2
102102
# @show cost_1 + cost_2 ≤ cost_fused, cost_1, cost_2, cost_fused
103103
if cost_1 + cost_2 cost_fused
104-
# @show cost_1, cost_2, cost_fused
105104
ls_2_lowered = if length(remaining_ops) > 1
106105
inline = iszero(inline) ? (shouldinline_1 % Int) : inline
107106
lower_and_split_loops(ls_2, inline)

0 commit comments

Comments
 (0)