@@ -82,16 +82,12 @@ function lower_block(
82
82
for prepost ∈ 1 : 2
83
83
# !U && !T
84
84
lower! (blockq, ops[1 ,1 ,prepost,n], vectorized, ls, unrolled, tiled, U, nothing , mask)
85
- # for u ∈ 0:U-1 # U && !T
86
- lower! (blockq, ops[2 ,1 ,prepost,n], vectorized, ls, unrolled, tiled, U, nothing , mask)
87
- # end
88
- if length (ops[1 ,2 ,prepost,n]) + length (ops[2 ,2 ,prepost,n]) > 0
85
+ opsv1 = ops[1 ,2 ,prepost,n]
86
+ opsv2 = ops[2 ,2 ,prepost,n]
87
+ if length (opsv1) + length (opsv2) > 0
89
88
for store ∈ (false ,true )
90
89
# let store = nothing
91
90
nstores = 0
92
- opsv1 = ops[1 ,2 ,prepost,n]
93
- opsv2 = ops[2 ,2 ,prepost,n]
94
- iszero (length (opsv1) + length (opsv2)) && continue
95
91
iszero (length (opsv1)) || (nstores += sum (isstore, opsv1))
96
92
iszero (length (opsv2)) || (nstores += sum (isstore, opsv2))
97
93
for t ∈ 0 : T- 1
@@ -102,21 +98,32 @@ function lower_block(
102
98
else
103
99
push! (blockq. args, Expr (:+= , tiled, 1 ))
104
100
end
105
- # !U && T
106
- if dontmaskfirsttiles && t < T - 1
101
+ if dontmaskfirsttiles && t < T - 1 # !U && T
107
102
lower! (blockq, opsv1, vectorized, ls, unrolled, tiled, U, t, nothing , store)
108
- # for u ∈ 0:U-1 # U && T
103
+ else # !U && T
104
+ lower! (blockq, opsv1, vectorized, ls, unrolled, tiled, U, t, mask, store)
105
+ end
106
+ if iszero (t) && ! store # U && !T
107
+ # for u ∈ 0:U-1
108
+ lower! (blockq, ops[2 ,1 ,prepost,n], vectorized, ls, unrolled, tiled, U, nothing , mask)
109
+ # end
110
+ end
111
+ if dontmaskfirsttiles && t < T - 1 # U && T
112
+ # for u ∈ 0:U-1
109
113
lower! (blockq, opsv2, vectorized, ls, unrolled, tiled, U, t, nothing , store)
110
114
# end
111
- else
112
- lower! (blockq, opsv1, vectorized, ls, unrolled, tiled, U, t, mask, store)
113
- # for u ∈ 0:U-1 # U && T
115
+ else # U && T
116
+ # for u ∈ 0:U-1
114
117
lower! (blockq, opsv2, vectorized, ls, unrolled, tiled, U, t, mask, store)
115
118
# end
116
119
end
117
120
end
118
121
nstores == 0 && break
119
122
end
123
+ else
124
+ # for u ∈ 0:U-1 # U && !T
125
+ lower! (blockq, ops[2 ,1 ,prepost,n], vectorized, ls, unrolled, tiled, U, nothing , mask)
126
+ # end
120
127
end
121
128
if n > 1 && prepost == 1
122
129
push! (blockq. args, lower_unrolled_dynamic (ls, us, n- 1 , ! isnothing (mask)))
0 commit comments