@@ -19,6 +19,23 @@ function load_constrained(op, u₁loop, u₂loop, forprefetch = false)
19
19
dependsonu₂ && push! (unrolleddeps, u₂loop)
20
20
any (opp -> isload (opp) && all (in (loopdependencies (opp)), unrolleddeps), parents (op))
21
21
end
22
"""
    check_if_remfirst(ls, ua)

Return `true` when an unrolled loop that `isstaticloop` reports as static is being
lowered with an unroll factor different from the one stored in
`ls.unrollspecification[]`; return `false` otherwise.

The `u₁` loop compares the stored `u₁` against `ua`'s `u₁`; the `u₂` loop compares the
stored `u₂` against `ua`'s `u₂max`.

NOTE(review): the name suggests this detects a "remainder first" situation — confirm
against the callers before relying on that reading.
"""
function check_if_remfirst(ls, ua)
    stored = ls.unrollspecification[]
    @unpack u₁, u₁loopsym, u₂loopsym, u₂max = ua
    # Both loops are looked up before either comparison is made.
    loop₁ = getloop(ls, u₁loopsym)
    loop₂ = getloop(ls, u₂loopsym)
    u₁differs = isstaticloop(loop₁) && (stored.u₁ != u₁)
    u₁differs && return true
    return isstaticloop(loop₂) && (stored.u₂ != u₂max)
end
35
"""
    sub_fmas(ls::LoopSet, op::Operation, ua::UnrollArgs)

Return `true` only when neither `load_constrained(op, u₁loopsym, u₂loopsym)` nor
`check_if_remfirst(ls, ua)` holds.

`lower_compute!` uses this predicate to decide whether a `*_fast` fused multiply-add
instruction is replaced by its specific variant.
"""
function sub_fmas(ls::LoopSet, op::Operation, ua::UnrollArgs)
    # Only the loop symbols are needed here; `check_if_remfirst` reads everything else
    # it needs directly from `ua`.
    @unpack u₁loopsym, u₂loopsym = ua
    # `check_if_remfirst` is consulted only when the op is not load constrained.
    load_constrained(op, u₁loopsym, u₂loopsym) && return false
    return !check_if_remfirst(ls, ua)
end
22
39
23
40
"""
    FalseCollection

Field-less type whose instances answer `false` for every index.
"""
struct FalseCollection end

"""
    getindex(::FalseCollection, inds...)

Return `false` regardless of how many indices are supplied or what they are.
"""
function Base.getindex(::FalseCollection, inds...)
    return false
end
@@ -106,7 +123,7 @@ function add_loopvalue!(instrcall::Expr, loopval, ua::UnrollArgs, u::Int)
106
123
end
107
124
108
125
function lower_compute! (
109
- q:: Expr , op:: Operation , ua:: UnrollArgs , mask:: Union{Nothing,Symbol,Unsigned} = nothing ,
126
+ q:: Expr , op:: Operation , ls :: LoopSet , ua:: UnrollArgs , mask:: Union{Nothing,Symbol,Unsigned} = nothing ,
110
127
)
111
128
@unpack u₁, u₁loopsym, u₂loopsym, vectorized, suffix = ua
112
129
var = name (op)
@@ -176,16 +193,16 @@ function lower_compute!(
176
193
instrfid = findfirst (isequal (instr. instr), (:vfmadd_fast , :vfnmadd_fast , :vfmsub_fast , :vfnmsub_fast ))
177
194
# want to instcombine when parent load's deps are superset
178
195
# also make sure opp is unrolled
179
- if instrfid != = nothing && (opunrolled && u₁ > 1 ) && ! load_constrained (op, u₁loopsym, u₂loopsym )
180
- specific_fmas = Base. libllvm_version > v " 11.0.0" ? (:vfmadd , :vfnmadd , :vfmsub , :vfnmsub ) : (:vfmadd231 , :vfnmadd231 , :vfmsub231 , :vfnmsub231 )
196
+ if ! isnothing ( instrfid) && (opunrolled && u₁ > 1 ) && sub_fmas (ls, op, ua )
197
+ specific_fmas = Base. libllvm_version >= v " 11.0.0" ? (:vfmadd , :vfnmadd , :vfmsub , :vfnmsub ) : (:vfmadd231 , :vfnmadd231 , :vfmsub231 , :vfnmsub231 )
181
198
# specific_fmas = (:vfmadd231, :vfnmadd231, :vfmsub231, :vfnmsub231)
182
199
instr = Instruction (specific_fmas[instrfid])
183
200
end
184
201
end
185
202
# @show instr.instr
186
203
reduceddeps = reduceddependencies (op)
187
204
vecinreduceddeps = isreduct && vectorized ∈ reduceddeps
188
- maskreduct = mask != = nothing && vecinreduceddeps # any(opp -> opp.variable === var, parents_op)
205
+ maskreduct = ! isnothing (mask) && vecinreduceddeps # any(opp -> opp.variable === var, parents_op)
189
206
# if vecinreduceddeps && vectorized ∉ loopdependencies(op) # screen parent opps for those needing a reduction to scalar
190
207
# # parents_op = reduce_vectorized_parents!(q, op, parents_op, U, u₁loopsym, u₂loopsym, vectorized, suffix)
191
208
# isreducingidentity!(q, op, parents_op, U, u₁loopsym, u₂loopsym, vectorized, suffix) && return
0 commit comments