Skip to content

Commit b588263

Browse files
committed
Improvement of stridepenalty calc. Split loops currently seems fragile; need to correctly split load->stores into same memory location.
1 parent 7c86064 commit b588263

File tree

2 files changed

+21
-3
lines changed

2 files changed

+21
-3
lines changed

src/determinestrategy.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -499,8 +499,11 @@ function stride_penalty(ls::LoopSet, op::Operation, order::Vector{Symbol})
499499
iter = 1
500500
for i ∈ 0:num_loops - 1
501501
loopsym = order[num_loops - i]
502-
loopsym === contigsym && return iter
503-
iter *= length(getloop(ls, loopsym))
502+
if loopsym === contigsym
503+
return iter
504+
elseif loopsym ∈ loopdependencies(op)
505+
iter *= length(getloop(ls, loopsym))
506+
end
504507
end
505508
iter
506509
end

src/split_loops.jl

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,21 @@ function add_operation!(ls_new::LoopSet, included::Vector{Int}, ls::LoopSet, op:
55
iszero(newid) || return operations(ls_new)[newid]
66
vparents = Operation[]
77
for opp ∈ parents(op)
8+
# TODO: get it so that
9+
# a[i] = f(a[i]) will split into one loop computing and storing f(a[i]), and the other loading from that storage if it needs it.
10+
# if iscompute(opp) && (!isstore(op)) # search for stores
11+
# found = false
12+
# for oppp ∈ operations(ls)
13+
# isstore(oppp) || continue
14+
# if first(parents(oppp)) === op
15+
# found = true
16+
17+
# push!(vparents, add_operation!(ls_new, included, ls, oppp))
18+
# break
19+
# end
20+
# end
21+
# found && continue
22+
# end
823
push!(vparents, add_operation!(ls_new, included, ls, opp))
924
end
1025
opnew = Operation(
@@ -84,7 +99,7 @@ function lower_and_split_loops(ls::LoopSet, inline::Int)
8499
order_2, unrolled_2, tiled_2, vectorized_2, U_2, T_2, cost_2, shouldinline_2 = choose_order_cost(ls_2)
85100
# U_1 = T_1 = U_2 = T_2 = 2
86101
if cost_1 + cost_2 ≤ cost_fused
87-
# @show cost_1, cost_2 cost_fused
102+
@show cost_1, cost_2 cost_fused
88103
ls_2_lowered = if length(remaining_ops) > 1
89104
inline = iszero(inline) ? (shouldinline_1 % Int) : inline
90105
lower_and_split_loops(ls_2, inline)

0 commit comments

Comments
 (0)