@@ -16,6 +16,25 @@ const CACHELINE_SIZE = something(VectorizationBase.L₁CACHE.linesize, 64)
16
16
# factor = instruction(op).instr ∈ (:+, :vadd, :add_fast, :evadd) ? 1 : 10
17
17
# newapp * factor
18
18
# end
19
+ function check_linear_parents (ls:: LoopSet , op:: Operation , s:: Symbol )
20
+ (s ∈ loopdependencies (op)) || return true
21
+ if isload (op) # TODO : handle loading from ranges.
22
+ return false
23
+ elseif ! iscompute (op)
24
+ return true
25
+ end
26
+ op_is_linear = false
27
+ instr_op = instruction (op). instr
28
+ for instr ∈ (:(+ ), :vadd , :vadd1 , :add_fast , :(- ), :vsub , :sub_fast )
29
+ (op_is_linear = instr === instr_op) && break
30
+ end
31
+ op_is_linear || return false
32
+ for opp ∈ parents (op)
33
+ check_linear_parents (ls, opp, s) || return false
34
+ end
35
+ true
36
+ end
37
+
19
38
function findparent (ls:: LoopSet , s:: Symbol )# opdict isn't filled when reconstructing
20
39
id = findfirst (op -> name (op) === s, operations (ls))
21
40
id === nothing && throw (" $s not found" )
@@ -33,6 +52,10 @@ function unitstride(ls::LoopSet, op::Operation, s::Symbol)
33
52
# parent = findparent(ls, fi)
34
53
# indexappearences(parent, s) > 1 && return false
35
54
end
55
+ if length (li) > 0 && ! first (li)
56
+ parent = findparent (ls, first (inds))
57
+ check_linear_parents (ls, parent, s) || return false
58
+ end
36
59
for i ∈ 2 : length (inds)
37
60
if li[i]
38
61
s === inds[i] && return false
@@ -164,6 +187,7 @@ function evaluate_cost_unroll(
164
187
end
165
188
end
166
189
included_vars[id] = true
190
+ # @show op, cost(ls, op, vectorized, Wshift, size_T)
167
191
total_cost += iter * first (cost (ls, op, vectorized, Wshift, size_T))
168
192
total_cost > max_cost && return total_cost # abort if more expensive; we only want to know the cheapest
169
193
end
0 commit comments