Skip to content

Commit 9bfa9a5

Browse files
committed
Added possibility of splitting loops.
1 parent 1b9c97e commit 9bfa9a5

File tree

4 files changed

+91
-3
lines changed

4 files changed

+91
-3
lines changed

src/LoopVectorization.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ include("lower_memory_common.jl")
4646
include("lower_load.jl")
4747
include("lower_store.jl")
4848
include("lowering.jl")
49+
include("split_loops.jl")
4950
include("condense_loopset.jl")
5051
include("reconstruct_loopset.jl")
5152
include("constructors.jl")

src/determinestrategy.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -659,10 +659,10 @@ function choose_order(ls::LoopSet)
659659
end
660660
uorder, uvec, uc = choose_unroll_order(ls, tc)
661661
if num_loops(ls) > 1 && tc uc
662-
return torder, tunroll, ttile, tvec, min(tU, tT), tT
662+
return torder, tunroll, ttile, tvec, min(tU, tT), tT, tc
663663
# return torder, tvec, 4, 4#5, 5
664664
else
665-
return uorder, first(uorder), Symbol("##undefined##"), uvec, determine_unroll_factor(ls, uorder, first(uorder), uvec), -1
665+
return uorder, first(uorder), Symbol("##undefined##"), uvec, determine_unroll_factor(ls, uorder, first(uorder), uvec), -1, uc
666666
end
667667
end
668668

src/reconstruct_loopset.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -391,7 +391,7 @@ function avx_loopset(instr, ops, arf, AM, LPSYM, LB, vargs)
391391
end
392392
function avx_body(ls, UT)
393393
U, T = UT
394-
q = iszero(U) ? lower(ls) : lower(ls, U, T)
394+
q = iszero(U) ? lower_and_split_loops(ls) : lower(ls, U, T)
395395
length(ls.outer_reductions) == 0 ? push!(q.args, nothing) : push!(q.args, loopset_return_value(ls, Val(true)))
396396
# @show q
397397
q

src/split_loops.jl

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
2+
3+
"""
    add_operation!(ls_new::LoopSet, included::Vector{Int}, ls::LoopSet, op::Operation)

Copy `op`, preceded recursively by its parents, into `ls_new` and return the copy.

`included` is indexed by `identifier(op)` (the identifier within `ls`) and stores the
identifier of the matching operation in `ls_new`; `0` means "not copied yet". When an
entry is already non-zero, the previously created copy is looked up in
`operations(ls_new)` and returned instead of copying the operation a second time.
"""
function add_operation!(ls_new::LoopSet, included::Vector{Int}, ls::LoopSet, op::Operation)
    newid = included[identifier(op)]
    iszero(newid) || return operations(ls_new)[newid]
    # Copy the parents first so the new operation refers to operations of `ls_new`.
    vparents = Operation[]
    for opp in parents(op)
        push!(vparents, add_operation!(ls_new, included, ls, opp))
    end
    opnew = Operation(
        length(operations(ls_new)), name(op), op.elementbytes, instruction(op), op.node_type,
        loopdependencies(op), reduceddependencies(op), vparents, ref(op), reducedchildren(op)
    )
    # Store the copy in `ls_new`: the memoised lookup above indexes `operations(ls_new)`,
    # and the id passed to the constructor is derived from its current length.
    # NOTE(review): assumes the `Operation` constructor does not register the operation
    # itself — confirm against its definition.
    push!(operations(ls_new), opnew)
    included[identifier(op)] = identifier(opnew)
    return opnew
end
17+
18+
"""
    append_if_included!(vnew, vold, included)

Push onto `vnew` the entries of `vold` whose index is marked in `included`.

Each element of `vold` is destructured as `(i, v)`. `included[i]` is the replacement
index; entries for which it is `0` are skipped, the others are pushed as
`(included[i], v)`. Returns `nothing`.
"""
function append_if_included!(vnew, vold, included)
    for (i, v) in vold
        id = included[i]
        # A zero entry means index `i` has no counterpart, so the pair is dropped.
        iszero(id) && continue
        push!(vnew, (id, v))
    end
    return nothing
end
25+
26+
"""
    split_loopset(ls::LoopSet, ids)

Build a new `LoopSet` holding the operations of `ls` whose identifiers are in `ids`,
together with the parent operations they depend on.

The loops referenced by the copied operations' loop dependencies are added to the new
set, and the preamble entries of `ls` that refer to a copied operation are carried
over with their operation index remapped to the new set.
"""
function split_loopset(ls::LoopSet, ids)
    ls_new = LoopSet(:LoopVectorization)
    # included[old_id] == new_id, with 0 meaning "not part of the new loop set".
    included = zeros(Int, length(operations(ls)))
    for i in ids
        add_operation!(ls_new, included, ls, operations(ls)[i])
    end
    for op in operations(ls_new)
        for l in loopdependencies(op)
            # Add each loop once, the first time an operation depends on it.
            if !(l in ls_new.loopsymbols)
                add_loop!(ls_new, getloop(ls, l))
            end
        end
        # Every loop of `ls` is already present; no need to look at further operations.
        length(ls_new.loopsymbols) == length(ls.loopsymbols) && break
    end
    append_if_included!(ls_new.preamble_symsym, ls.preamble_symsym, included)
    append_if_included!(ls_new.preamble_symint, ls.preamble_symint, included)
    append_if_included!(ls_new.preamble_symfloat, ls.preamble_symfloat, included)
    append_if_included!(ls_new.preamble_zeros, ls.preamble_zeros, included)
    append_if_included!(ls_new.preamble_ones, ls.preamble_ones, included)
    return ls_new
end
47+
48+
49+
"""
    lower_and_split_loops(ls::LoopSet)

Lower `ls`, splitting it into separate loop nests when `choose_order` reports a lower
combined cost for the pieces than for the fused loop set.

Split candidates are the identifiers of the store operations plus the entries of
`ls.outer_reductions`. Each candidate in turn is split into its own `LoopSet` while
the remaining candidates form a second one. The first split whose summed cost is
below the fused cost is used: the result is a `:block` expression containing
`ls.preamble`, the lowered single-candidate set, and the lowered remainder (itself
split recursively while it has more than one candidate). With at most one candidate,
or when no split is cheaper, the fused loop set is lowered.
"""
function lower_and_split_loops(ls::LoopSet)
    split_candidates = Int[]
    for op in operations(ls)
        isstore(op) && push!(split_candidates, identifier(op))
    end
    for i in ls.outer_reductions
        push!(split_candidates, i)
    end
    length(split_candidates) > 1 || return lower(ls)
    order_fused, unrolled_fused, tiled_fused, vectorized_fused, U_fused, T_fused, cost_fused =
        choose_order(ls)
    # Buffers reused across iterations: the single split-off candidate and all the others.
    remaining_ops = Vector{Int}(undef, length(split_candidates) - 1)
    split_1 = Int[0]
    for (ind, i) in enumerate(split_candidates)
        split_1[1] = i
        ls_1 = split_loopset(ls, split_1)
        order_1, unrolled_1, tiled_1, vectorized_1, U_1, T_1, cost_1 = choose_order(ls_1)
        # remaining_ops = split_candidates with position `ind` removed.
        remaining_ops[1:ind-1] .= @view(split_candidates[1:ind-1])
        remaining_ops[ind:end] .= @view(split_candidates[ind+1:end])
        ls_2 = split_loopset(ls, remaining_ops)
        order_2, unrolled_2, tiled_2, vectorized_2, U_2, T_2, cost_2 = choose_order(ls_2)
        if cost_1 + cost_2 < cost_fused
            ls_2_lowered = if length(remaining_ops) > 1
                lower_and_split_loops(ls_2)
            else
                # Same argument list as the fused call at the end of this function.
                lower(ls_2, order_2, unrolled_2, tiled_2, vectorized_2, U_2, T_2)
            end
            # Return the split form; without `return` the expression is discarded and
            # the fused lowering below would always win.
            return Expr(
                :block,
                ls.preamble,
                lower(ls_1, order_1, unrolled_1, tiled_1, vectorized_1, U_1, T_1),
                ls_2_lowered
            )
        end
    end
    return lower(ls, order_fused, unrolled_fused, tiled_fused, vectorized_fused, U_fused, T_fused)
end
84+
85+
86+
87+

0 commit comments

Comments
 (0)