|
| 1 | + |
| 2 | + |
| 3 | +function add_operation!(ls_new::LoopSet, included::Vector{Int}, ls::LoopSet, op::Operation) |
| 4 | + newid = included[identifier(op)] |
| 5 | + iszero(newid) || return operations(ls_new)[newid] |
| 6 | + vparents = Operation[] |
| 7 | + for opp ∈ parents(op) |
| 8 | + push!(vparents, add_operation!(ls_new, included, ls, opp)) |
| 9 | + end |
| 10 | + opnew = Operation( |
| 11 | + length(operations(ls_new)), name(op), op.elementbytes, instruction(op), op.node_type, |
| 12 | + loopdependencies(op), reduceddependencies(op), vparents, ref(op), reducedchildren(op) |
| 13 | + ) |
| 14 | + included[identifier(op)] = identifier(opnew) |
| 15 | + opnew |
| 16 | +end |
| 17 | + |
| 18 | +function append_if_included!(vnew, vold, included) |
| 19 | + for (i, v) ∈ vold |
| 20 | + id = included[i] |
| 21 | + iszero(id) && continue |
| 22 | + push!(vnew, (id, v)) |
| 23 | + end |
| 24 | +end |
| 25 | + |
| 26 | +function split_loopset(ls::LoopSet, ids) |
| 27 | + ls_new = LoopSet(:LoopVectorization) |
| 28 | + included = zeros(Int, length(operations(ls))) |
| 29 | + for i ∈ ids |
| 30 | + add_operation!(ls_new, included, ls, operations(ls)[i]) |
| 31 | + end |
| 32 | + for op ∈ operations(ls_new) |
| 33 | + for l ∈ loopdependencies(op) |
| 34 | + if l ∉ ls_new.loopsymbols |
| 35 | + add_loop!(ls_new, getloop(ls, l)) |
| 36 | + end |
| 37 | + end |
| 38 | + length(ls_new.loopsymbols) == length(ls.loopsymbols) && break |
| 39 | + end |
| 40 | + append_if_included!(ls_new.preamble_symsym, ls.preamble_symsym, included) |
| 41 | + append_if_included!(ls_new.preamble_symint, ls.preamble_symint, included) |
| 42 | + append_if_included!(ls_new.preamble_symfloat, ls.preamble_symfloat, included) |
| 43 | + append_if_included!(ls_new.preamble_zeros, ls.preamble_zeros, included) |
| 44 | + append_if_included!(ls_new.preamble_ones, ls.preamble_ones, included) |
| 45 | + ls_new |
| 46 | +end |
| 47 | + |
| 48 | + |
| 49 | +function lower_and_split_loops(ls::LoopSet) |
| 50 | + ops = operations(ls) |
| 51 | + split_candidates = Int[] |
| 52 | + for op ∈ ops |
| 53 | + isstore(op) && push!(split_candidates, identifier(op)) |
| 54 | + end |
| 55 | + for i ∈ ls.outer_reductions |
| 56 | + push!(split_candidates, i) |
| 57 | + end |
| 58 | + length(split_candidates) > 1 || return lower(ls) |
| 59 | + order_fused, unrolled_fused, tiled_fused, vectorized_fused, U_fused, T_fused, cost_fused = choose_order(ls) |
| 60 | + remaining_ops = Vector{Int}(undef, length(split_candidates) - 1); split_1 = Int[0]; |
| 61 | + for (ind,i) ∈ enumerate(split_candidates) |
| 62 | + split_1[1] = i |
| 63 | + ls_1 = split_loopset(ls, split_1) |
| 64 | + order_1, unrolled_1, tiled_1, vectorized_1, U_1, T_1, cost_1 = choose_order(ls_1) |
| 65 | + reaminig_ops[1:ind-1] .= @view(split_candidates[1:ind-1]); reaminig_ops[ind:end] .= @view(split_candidates[ind+1:end]) |
| 66 | + ls_2 = split_loopset(ls, remaining_ops) |
| 67 | + order_2, unrolled_2, tiled_2, vectorized_2, U_2, T_2, cost_2 = choose_order(ls_2) |
| 68 | + if cost_1 + cost_2 < cost_fused |
| 69 | + ls_2_lowered = if length(remaining_ops) > 1 |
| 70 | + lower_and_split_loops(ls_2) |
| 71 | + else |
| 72 | + lower(ls_2, unrolled_2, tiled_2, vectorized_2, U_2, T_2) |
| 73 | + end |
| 74 | + Expr( |
| 75 | + :block, |
| 76 | + ls.preamble, |
| 77 | + lower(ls_1, unrolled_1, tiled_1, vectorized_1, U_1, T_1), |
| 78 | + ls_2_lowered |
| 79 | + ) |
| 80 | + end |
| 81 | + end |
| 82 | + lower(ls, order_fused, unrolled_fused, tiled_fused, vectorized_fused, U_fused, T_fused) |
| 83 | +end |
| 84 | + |
| 85 | + |
| 86 | + |
| 87 | + |
0 commit comments