Skip to content

Commit afec697

Browse files
add target hook to capture special operand uniformity and update UA to use it
1 parent 860b485 commit afec697

File tree

12 files changed

+156
-18
lines changed

12 files changed

+156
-18
lines changed

llvm/include/llvm/ADT/GenericUniformityImpl.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
#include "llvm/ADT/SmallPtrSet.h"
5252
#include "llvm/ADT/SparseBitVector.h"
5353
#include "llvm/ADT/StringExtras.h"
54+
#include "llvm/ADT/Uniformity.h"
5455
#include "llvm/Support/raw_ostream.h"
5556

5657
#define DEBUG_TYPE "uniformity"
@@ -407,6 +408,11 @@ template <typename ContextT> class GenericUniformityAnalysisImpl {
407408
void recordTemporalDivergence(ConstValueRefT, const InstructionT *,
408409
const CycleT *);
409410

411+
bool isOperandUniform(const InstructionT &I, InstructionUniformity IU) const;
412+
413+
/// \brief Keep track of a target instruction that can be proven uniform.
414+
void addUniformInstruction(const InstructionT *I, InstructionUniformity IU);
415+
410416
protected:
411417
/// \brief Value/block pair representing a single phi input.
412418
struct PhiInput {
@@ -429,6 +435,11 @@ template <typename ContextT> class GenericUniformityAnalysisImpl {
429435
// Internal worklist for divergence propagation.
430436
std::vector<const InstructionT *> Worklist;
431437

438+
// Map of tracked instructions that can be proven uniform based on the
439+
// uniformity of their operands.
440+
llvm::DenseMap<const InstructionT *, InstructionUniformity>
441+
UniformInstruction;
442+
432443
/// \brief Mark \p Term as divergent and push all Instructions that become
433444
/// divergent as a result on the worklist.
434445
void analyzeControlDivergence(const InstructionT &Term);
@@ -793,6 +804,11 @@ void GenericUniformityAnalysisImpl<ContextT>::markDivergent(
793804
const InstructionT &I) {
794805
if (isAlwaysUniform(I))
795806
return;
807+
auto It = UniformInstruction.find(&I);
808+
if (It != UniformInstruction.end() && isOperandUniform(I, It->second)) {
809+
addUniformOverride(I);
810+
return;
811+
}
796812
bool Marked = false;
797813
if (I.isTerminator()) {
798814
Marked = DivergentTermBlocks.insert(I.getParent()).second;
@@ -824,6 +840,12 @@ void GenericUniformityAnalysisImpl<ContextT>::addUniformOverride(
824840
UniformOverrides.insert(&Instr);
825841
}
826842

843+
/// Record \p I in the UniformInstruction map together with the uniformity
/// kind \p IU reported for it.
template <typename ContextT>
844+
void GenericUniformityAnalysisImpl<ContextT>::addUniformInstruction(
845+
const InstructionT *I, InstructionUniformity IU) {
846+
// Map subscript assignment: an entry already stored for I is replaced.
UniformInstruction[I] = IU;
847+
}
848+
827849
// Mark as divergent all external uses of values defined in \p DefCycle.
828850
//
829851
// A value V defined by a block B inside \p DefCycle may be used outside the

llvm/include/llvm/ADT/Uniformity.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,10 @@ enum class InstructionUniformity {
2323
AlwaysUniform,
2424

2525
/// The result values can never be assumed to be uniform.
26-
NeverUniform
26+
NeverUniform,
27+
28+
/// The result value can be uniform if either of the first two operands is uniform.
29+
EitherOfFirstTwoOp
2730
};
2831

2932
} // namespace llvm

llvm/include/llvm/Analysis/TargetTransformInfo.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
#include "llvm/ADT/APInt.h"
2525
#include "llvm/ADT/ArrayRef.h"
26+
#include "llvm/ADT/Uniformity.h"
2627
#include "llvm/Analysis/IVDescriptors.h"
2728
#include "llvm/IR/FMF.h"
2829
#include "llvm/IR/InstrTypes.h"
@@ -1916,6 +1917,8 @@ class TargetTransformInfo {
19161917
const Function &F,
19171918
SmallVectorImpl<std::pair<StringRef, int64_t>> &LB) const;
19181919

1920+
InstructionUniformity getInstructionUniformity(const Instruction &I) const;
1921+
19191922
private:
19201923
std::unique_ptr<const TargetTransformInfoImplBase> TTIImpl;
19211924
};

llvm/include/llvm/Analysis/TargetTransformInfoImpl.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1147,6 +1147,11 @@ class TargetTransformInfoImplBase {
11471147
const Function &F,
11481148
SmallVectorImpl<std::pair<StringRef, int64_t>> &LB) const {}
11491149

1150+
/// Base implementation of the instruction-uniformity hook. It ignores \p I
/// and always reports InstructionUniformity::Default; the function is
/// virtual so that a derived implementation can return a different kind.
virtual InstructionUniformity
1151+
getInstructionUniformity(const Instruction &I) const {
1152+
return InstructionUniformity::Default;
1153+
}
1154+
11501155
protected:
11511156
// Obtain the minimum required size to hold the value (without the sign)
11521157
// In case of a vector it returns the min required size for one element.

llvm/lib/Analysis/TargetTransformInfo.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1476,6 +1476,11 @@ void TargetTransformInfo::collectKernelLaunchBounds(
14761476
return TTIImpl->collectKernelLaunchBounds(F, LB);
14771477
}
14781478

1479+
/// Query the uniformity kind of \p I. This wrapper does no work of its own:
/// it forwards the call to the implementation object held in TTIImpl and
/// returns whatever that implementation reports.
InstructionUniformity
1480+
TargetTransformInfo::getInstructionUniformity(const Instruction &I) const {
1481+
return TTIImpl->getInstructionUniformity(I);
1482+
}
1483+
14791484
TargetTransformInfoImplBase::~TargetTransformInfoImplBase() = default;
14801485

14811486
TargetIRAnalysis::TargetIRAnalysis() : TTICallback(&getDefaultTTI) {}

llvm/lib/Analysis/UniformityAnalysis.cpp

Lines changed: 29 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include "llvm/Analysis/UniformityAnalysis.h"
1010
#include "llvm/ADT/GenericUniformityImpl.h"
11+
#include "llvm/ADT/Uniformity.h"
1112
#include "llvm/Analysis/CycleAnalysis.h"
1213
#include "llvm/Analysis/TargetTransformInfo.h"
1314
#include "llvm/IR/Dominators.h"
@@ -29,25 +30,15 @@ bool llvm::GenericUniformityAnalysisImpl<SSAContext>::markDefsDivergent(
2930
return markDivergent(cast<Value>(&Instr));
3031
}
3132

32-
template <>
33-
bool llvm::GenericUniformityAnalysisImpl<SSAContext>::isDivergentUse(
34-
const Use &U) const {
35-
const auto *V = U.get();
36-
if (isDivergent(V))
37-
return true;
38-
if (const auto *DefInstr = dyn_cast<Instruction>(V)) {
39-
const auto *UseInstr = cast<Instruction>(U.getUser());
40-
return isTemporalDivergent(*UseInstr->getParent(), *DefInstr);
41-
}
42-
return false;
43-
}
44-
4533
template <> void llvm::GenericUniformityAnalysisImpl<SSAContext>::initialize() {
4634
for (auto &I : instructions(F)) {
4735
if (TTI->isSourceOfDivergence(&I))
4836
markDivergent(I);
4937
else if (TTI->isAlwaysUniform(&I))
5038
addUniformOverride(I);
39+
InstructionUniformity IU = TTI->getInstructionUniformity(I);
40+
if (IU != InstructionUniformity::Default)
41+
addUniformInstruction(&I, IU);
5142
}
5243
for (auto &Arg : F.args()) {
5344
if (TTI->isSourceOfDivergence(&Arg)) {
@@ -101,6 +92,31 @@ void llvm::GenericUniformityAnalysisImpl<
10192
}
10293
}
10394

95+
/// Decide whether the use \p U observes a divergent value.
///
/// Returns true when the used value is itself marked divergent, or when it is
/// defined by an instruction and isTemporalDivergent() reports divergence for
/// that definition relative to the block containing the user. Returns false
/// otherwise, including for used values that are not instructions.
template <>
96+
bool llvm::GenericUniformityAnalysisImpl<SSAContext>::isDivergentUse(
97+
const Use &U) const {
98+
const auto *V = U.get();
99+
// A value already marked divergent is divergent at every use.
if (isDivergent(V))
100+
return true;
101+
// Only instruction-defined values are checked for temporal divergence.
if (const auto *DefInstr = dyn_cast<Instruction>(V)) {
102+
// cast<> (unlike dyn_cast<>) assumes the user of U is an Instruction.
const auto *UseInstr = cast<Instruction>(U.getUser());
103+
return isTemporalDivergent(*UseInstr->getParent(), *DefInstr);
104+
}
105+
return false;
106+
}
107+
108+
/// Check whether the operand condition named by \p IU currently holds for
/// \p I.
///
/// For InstructionUniformity::EitherOfFirstTwoOp the result is true when at
/// least one of operand uses 0 and 1 is not a divergent use. Every other
/// kind yields false.
///
/// NOTE(review): operand uses 0 and 1 are read without checking the operand
/// count; presumably only instructions with at least two operands are ever
/// registered with this kind -- confirm against the callers.
template <>
109+
bool GenericUniformityAnalysisImpl<SSAContext>::isOperandUniform(
110+
const Instruction &I, InstructionUniformity IU) const {
111+
switch (IU) {
112+
case InstructionUniformity::EitherOfFirstTwoOp:
113+
// One non-divergent operand among the first two is sufficient.
return !isDivergentUse(I.getOperandUse(0)) ||
114+
!isDivergentUse(I.getOperandUse(1));
115+
default:
116+
// No operand-based rule is implemented for the remaining kinds.
return false;
117+
}
118+
}
119+
104120
// This ensures explicit instantiation of
105121
// GenericUniformityAnalysisImpl::ImplDeleter::operator()
106122
template class llvm::GenericUniformityInfo<SSAContext>;

llvm/lib/CodeGen/MachineUniformityAnalysis.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,17 @@ bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::isDivergentUse(
147147
return isTemporalDivergent(*UseInstr->getParent(), *DefInstr);
148148
}
149149

150+
/// Check whether the operand condition named by \p IU currently holds for
/// the machine instruction \p I.
///
/// For InstructionUniformity::EitherOfFirstTwoOp the result is true when at
/// least one of machine operands 0 and 1 is not a divergent use. Every other
/// kind yields false.
///
/// NOTE(review): this indexes the raw operand list at 0 and 1. On many
/// machine instructions operand 0 is the result def rather than a source
/// operand -- confirm these indices select the intended "first two operands".
/// NOTE(review): the operand count is not checked before indexing.
template <>
151+
bool GenericUniformityAnalysisImpl<MachineSSAContext>::isOperandUniform(
152+
const MachineInstr &I, InstructionUniformity IU) const {
153+
switch (IU) {
154+
case InstructionUniformity::EitherOfFirstTwoOp:
155+
// One non-divergent operand among the first two is sufficient.
return !isDivergentUse(I.getOperand(0)) || !isDivergentUse(I.getOperand(1));
156+
default:
157+
// No operand-based rule is implemented for the remaining kinds.
return false;
158+
}
159+
}
160+
150161
// This ensures explicit instantiation of
151162
// GenericUniformityAnalysisImpl::ImplDeleter::operator()
152163
template class llvm::GenericUniformityInfo<MachineSSAContext>;

llvm/lib/Target/AMDGPU/AMDGPUSearchableTables.td

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -317,8 +317,6 @@ def : SourceOfDivergence<int_amdgcn_live_mask>;
317317
def : SourceOfDivergence<int_amdgcn_ds_swizzle>;
318318
def : SourceOfDivergence<int_amdgcn_ds_ordered_add>;
319319
def : SourceOfDivergence<int_amdgcn_ds_ordered_swap>;
320-
def : SourceOfDivergence<int_amdgcn_permlane16>;
321-
def : SourceOfDivergence<int_amdgcn_permlanex16>;
322320
def : SourceOfDivergence<int_amdgcn_permlane16_var>;
323321
def : SourceOfDivergence<int_amdgcn_permlanex16_var>;
324322
def : SourceOfDivergence<int_amdgcn_mov_dpp>;

llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1422,3 +1422,17 @@ void GCNTTIImpl::collectKernelLaunchBounds(
14221422
LB.push_back({"amdgpu-waves-per-eu[0]", WavesPerEU.first});
14231423
LB.push_back({"amdgpu-waves-per-eu[1]", WavesPerEU.second});
14241424
}
1425+
1426+
/// GCN override of the instruction-uniformity hook.
///
/// Reports InstructionUniformity::EitherOfFirstTwoOp when \p I is a call to
/// the amdgcn_permlane16 or amdgcn_permlanex16 intrinsic, and
/// InstructionUniformity::Default for every other instruction.
InstructionUniformity
1427+
GCNTTIImpl::getInstructionUniformity(const Instruction &I) const {
1428+
// Only intrinsic calls are classified; anything else falls through.
if (const auto *II = dyn_cast<IntrinsicInst>(&I)) {
1429+
switch (II->getIntrinsicID()) {
1430+
case Intrinsic::amdgcn_permlane16:
1431+
case Intrinsic::amdgcn_permlanex16:
1432+
return InstructionUniformity::EitherOfFirstTwoOp;
1433+
default:
1434+
break;
1435+
}
1436+
}
1437+
return InstructionUniformity::Default;
1438+
}

llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,8 @@ class GCNTTIImpl final : public BasicTTIImplBase<GCNTTIImpl> {
290290
void collectKernelLaunchBounds(
291291
const Function &F,
292292
SmallVectorImpl<std::pair<StringRef, int64_t>> &LB) const override;
293+
InstructionUniformity
294+
getInstructionUniformity(const Instruction &I) const override;
293295
};
294296

295297
} // end namespace llvm

0 commit comments

Comments
 (0)