Skip to content

Commit 8cf0cbf

Browse files
committed
fixup! fixup! [AMDGPU] Eliminate likely-spurious execz checks
Make the heuristic more strict, add IR tests
1 parent 345269b commit 8cf0cbf

File tree

4 files changed

+627
-152
lines changed

4 files changed

+627
-152
lines changed

llvm/lib/Target/AMDGPU/AMDGPUAnnotateVaryingBranchWeights.cpp

Lines changed: 70 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,11 @@
2626
#include "llvm/Analysis/UniformityAnalysis.h"
2727
#include "llvm/CodeGen/TargetPassConfig.h"
2828
#include "llvm/IR/Analysis.h"
29+
#include "llvm/IR/InstrTypes.h"
2930
#include "llvm/IR/Instructions.h"
31+
#include "llvm/IR/IntrinsicInst.h"
32+
#include "llvm/IR/IntrinsicsAMDGPU.h"
33+
#include "llvm/IR/IntrinsicsR600.h"
3034
#include "llvm/IR/ProfDataUtils.h"
3135
#include "llvm/InitializePasses.h"
3236
#include "llvm/Support/Casting.h"
@@ -44,7 +48,7 @@ class AMDGPUAnnotateVaryingBranchWeightsImpl {
4448
AMDGPUAnnotateVaryingBranchWeightsImpl(const GCNSubtarget &ST,
4549
const TargetTransformInfo &GCNTTI,
4650
UniformityInfo &UA)
47-
: ST(ST), GCNTTI(GCNTTI), UA(UA) {
51+
: ST(ST), UA(UA) {
4852
// Determine weights that signal that a branch is very likely to be
4953
// predicted correctly, i.e., whose ratio exceeds
5054
// TTI.getPredictableBranchThreshold().
@@ -59,11 +63,13 @@ class AMDGPUAnnotateVaryingBranchWeightsImpl {
5963

6064
private:
6165
const GCNSubtarget &ST;
62-
const TargetTransformInfo &GCNTTI;
6366
const UniformityInfo &UA;
6467
uint32_t LikelyWeight;
6568
uint32_t UnlikelyWeight;
6669
ValueMap<const Value *, bool> LikelyVaryingCache;
70+
unsigned HighestLikelyVaryingDimension = 0;
71+
72+
bool isRelevantSourceOfDivergence(const Value *V) const;
6773

6874
/// Heuristically check if it is likely that a wavefront has dynamically
6975
/// varying values for V.
@@ -72,7 +78,7 @@ class AMDGPUAnnotateVaryingBranchWeightsImpl {
7278
/// Set branch weights that signal that the "true" successor of Term is the
7379
/// likely destination, if no prior weights are present.
7480
/// Return true if weights were set.
75-
bool setTrueSuccessorLikely(BranchInst *Term);
81+
bool setTrueSuccessorLikely(BranchInst *Term) const;
7682
};
7783

7884
class AMDGPUAnnotateVaryingBranchWeightsLegacy : public FunctionPass {
@@ -137,13 +143,43 @@ AMDGPUAnnotateVaryingBranchWeightsPass::run(Function &F,
137143
return PA;
138144
}
139145

146+
bool AMDGPUAnnotateVaryingBranchWeightsImpl::isRelevantSourceOfDivergence(
147+
const Value *V) const {
148+
auto *II = dyn_cast<IntrinsicInst>(V);
149+
if (!II)
150+
return false;
151+
152+
switch (II->getIntrinsicID()) {
153+
case Intrinsic::amdgcn_workitem_id_z:
154+
case Intrinsic::r600_read_tidig_z:
155+
return HighestLikelyVaryingDimension >= 2;
156+
case Intrinsic::amdgcn_workitem_id_y:
157+
case Intrinsic::r600_read_tidig_y:
158+
return HighestLikelyVaryingDimension >= 1;
159+
case Intrinsic::amdgcn_workitem_id_x:
160+
case Intrinsic::r600_read_tidig_x:
161+
case Intrinsic::amdgcn_mbcnt_hi:
162+
case Intrinsic::amdgcn_mbcnt_lo:
163+
return true;
164+
}
165+
166+
return false;
167+
}
168+
140169
bool AMDGPUAnnotateVaryingBranchWeightsImpl::isLikelyVarying(const Value *V) {
141170
// Check if V is a source of divergence or if it transitively uses one.
142-
if (GCNTTI.isSourceOfDivergence(V))
171+
if (isRelevantSourceOfDivergence(V))
143172
return true;
144173

145-
auto *U = dyn_cast<User>(V);
146-
if (!U)
174+
auto *I = dyn_cast<Instruction>(V);
175+
if (!I)
176+
return false;
177+
178+
// ExtractValueInst and IntrinsicInst enable looking through the
179+
// amdgcn_if/else intrinsics inserted by SIAnnotateControlFlow.
180+
// This condition excludes PHINodes, which prevents infinite recursion.
181+
if (!isa<BinaryOperator>(I) && !isa<UnaryOperator>(I) && !isa<CastInst>(I) &&
182+
!isa<CmpInst>(I) && !isa<ExtractValueInst>(I) && !isa<IntrinsicInst>(I))
147183
return false;
148184

149185
// Have we already checked V?
@@ -153,7 +189,7 @@ bool AMDGPUAnnotateVaryingBranchWeightsImpl::isLikelyVarying(const Value *V) {
153189

154190
// Does it use a likely varying Value?
155191
bool Result = false;
156-
for (const auto &Use : U->operands()) {
192+
for (const auto &Use : I->operands()) {
157193
Result |= isLikelyVarying(Use);
158194
if (Result)
159195
break;
@@ -164,7 +200,7 @@ bool AMDGPUAnnotateVaryingBranchWeightsImpl::isLikelyVarying(const Value *V) {
164200
}
165201

166202
bool AMDGPUAnnotateVaryingBranchWeightsImpl::setTrueSuccessorLikely(
167-
BranchInst *Term) {
203+
BranchInst *Term) const {
168204
assert(Term->isConditional());
169205

170206
// Don't overwrite existing branch weights.
@@ -177,9 +213,33 @@ bool AMDGPUAnnotateVaryingBranchWeightsImpl::setTrueSuccessorLikely(
177213
}
178214

179215
bool AMDGPUAnnotateVaryingBranchWeightsImpl::run(Function &F) {
216+
unsigned MinWGSize = ST.getFlatWorkGroupSizes(F).first;
217+
bool MustHaveMoreThanOneThread = MinWGSize > 1;
218+
219+
// reqd_work_group_size determines the size of the work group in every
220+
// dimension. If it is present, identify the dimensions where the workitem id
221+
// differs between the threads of the same wavefront. Otherwise assume that
222+
// only dimension 0, i.e., x, varies.
223+
//
224+
// TODO can/should we assume that workitems are grouped into waves like that?
225+
auto *Node = F.getMetadata("reqd_work_group_size");
226+
if (Node && Node->getNumOperands() == 3) {
227+
unsigned WavefrontSize = ST.getWavefrontSize();
228+
unsigned ThreadsSoFar = 1;
229+
unsigned Dim = 0;
230+
for (; Dim < 3; ++Dim) {
231+
ThreadsSoFar *=
232+
mdconst::extract<ConstantInt>(Node->getOperand(Dim))->getZExtValue();
233+
if (ThreadsSoFar >= WavefrontSize)
234+
break;
235+
}
236+
HighestLikelyVaryingDimension = Dim;
237+
LLVM_DEBUG(dbgs() << "Highest Likely Varying Dimension: " << Dim << '\n');
238+
MustHaveMoreThanOneThread |= ThreadsSoFar > 1;
239+
}
240+
180241
// If the workgroup has only a single thread, the condition cannot vary.
181-
const auto WGSizes = ST.getFlatWorkGroupSizes(F);
182-
if (WGSizes.first <= 1)
242+
if (!MustHaveMoreThanOneThread)
183243
return false;
184244

185245
bool Changed = false;

0 commit comments

Comments
 (0)