@@ -39,30 +39,27 @@ function lower_compute!(
39
39
parentsunrolled[i] || continue
40
40
parentsunrolled[i] = false
41
41
parentop = parents_op[i]
42
+ i == tiledouterreduction && isconstant (parentop) && continue
42
43
newparentop = Operation (
43
44
parentop. identifier, gensym (parentop. variable), parentop. elementbytes, parentop. instruction, parentop. node_type,
44
45
parentop. dependencies, parentop. reduced_deps, parentop. parents, parentop. ref, parentop. reduced_children
45
46
)
46
47
parentname = mangledvar (parentop)
47
48
newparentname = mangledvar (newparentop)
48
49
parents_op[i] = newparentop
49
- if i == tiledouterreduction && isconstant (newparentop)
50
- push! (q. args, Expr (:(= ), Symbol (newparentname, suffix), Symbol (parentname, suffix)))
50
+ if parentstiled[i]
51
+ parentname = Symbol (parentname, suffix_)
52
+ newparentname = Symbol (newparentname, suffix_)
53
+ end
54
+ if isconstant (newparentop)
55
+ # @show i, parentstiled[i], newparentname, parentname
56
+ push! (q. args, Expr (:(= ), newparentname, Symbol (parentname, 0 )))
51
57
else
52
- if parentstiled[i]
53
- parentname = Symbol (parentname, suffix_)
54
- newparentname = Symbol (newparentname, suffix_)
55
- end
56
- if isconstant (newparentop)
57
- # @show i, parentstiled[i], newparentname, parentname
58
- push! (q. args, Expr (:(= ), newparentname, Symbol (parentname, 0 )))
59
- else
60
- for u ∈ 0 : U- 1
61
- push! (q. args, Expr (:(= ), Symbol (newparentname, u), Symbol (parentname, u)))
62
- end
63
- reduce_expr! (q, newparentname, Instruction (reduction_to_single_vector (instruction (newparentop))), U)
64
- push! (q. args, Expr (:(= ), newparentname, Symbol (newparentname, 0 )))
58
+ for u ∈ 0 : U- 1
59
+ push! (q. args, Expr (:(= ), Symbol (newparentname, u), Symbol (parentname, u)))
65
60
end
61
+ reduce_expr! (q, newparentname, Instruction (reduction_to_single_vector (instruction (newparentop))), U)
62
+ push! (q. args, Expr (:(= ), newparentname, Symbol (newparentname, 0 )))
66
63
end
67
64
end
68
65
end
@@ -96,8 +93,8 @@ function lower_compute!(
96
93
for u ∈ 0 : Uiter
97
94
instrcall = Expr (instr) # Expr(:call, instr)
98
95
varsym = if tiledouterreduction > 0 # then suffix !== nothing
99
- # modsuffix = ((u + suffix*U) & 3)
100
- modsuffix = suffix # (suffix & 3)
96
+ modsuffix = ((u + suffix* U) & 3 )
97
+ # modsuffix = suffix # (suffix & 3)
101
98
Symbol (mvar, modsuffix)
102
99
elseif unrollsym
103
100
Symbol (mvar, u)
0 commit comments