diff --git a/llvm/include/llvm/ADT/GenericUniformityImpl.h b/llvm/include/llvm/ADT/GenericUniformityImpl.h index d10355fff1bea..1318f285717dd 100644 --- a/llvm/include/llvm/ADT/GenericUniformityImpl.h +++ b/llvm/include/llvm/ADT/GenericUniformityImpl.h @@ -51,6 +51,7 @@ #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SparseBitVector.h" #include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/Uniformity.h" #include "llvm/Support/raw_ostream.h" #define DEBUG_TYPE "uniformity" @@ -407,6 +408,11 @@ template class GenericUniformityAnalysisImpl { void recordTemporalDivergence(ConstValueRefT, const InstructionT *, const CycleT *); + bool isAnyOperandUniform(const InstructionT &Instr) const; + + /// \brief keep track of special target intrinsics that can be proven uniform. + void addSpecialUniformIntrinsic(const InstructionT &Instr); + protected: /// \brief Value/block pair representing a single phi input. struct PhiInput { @@ -429,6 +435,8 @@ template class GenericUniformityAnalysisImpl { // Internal worklist for divergence propagation. std::vector Worklist; + // Special intrinsics list which can be proven uniform. + llvm::SmallPtrSet SpecialUniformIntrinsics; /// \brief Mark \p Term as divergent and push all Instructions that become /// divergent as a result on the worklist. void analyzeControlDivergence(const InstructionT &Term); @@ -824,6 +832,12 @@ void GenericUniformityAnalysisImpl::addUniformOverride( UniformOverrides.insert(&Instr); } +template +void GenericUniformityAnalysisImpl::addSpecialUniformIntrinsic( + const InstructionT &Instr) { + SpecialUniformIntrinsics.insert(&Instr); +} + // Mark as divergent all external uses of values defined in \p DefCycle. // // A value V defined by a block B inside \p DefCycle may be used outside the diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h index 022530dc846ea..c5e0b56d5f91c 100644 --- a/llvm/include/llvm/Analysis/TargetTransformInfo.h +++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h @@ -23,6 +23,7 @@ #include "llvm/ADT/APInt.h" #include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/Uniformity.h" #include "llvm/Analysis/IVDescriptors.h" #include "llvm/IR/FMF.h" #include "llvm/IR/InstrTypes.h" @@ -1916,6 +1917,8 @@ class TargetTransformInfo { const Function &F, SmallVectorImpl> &LB) const; + bool isSpecialUniformIntrinsic(const Instruction &I) const; + private: std::unique_ptr TTIImpl; }; diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h index 990252b1e5743..b1d780162ac9e 100644 --- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h +++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h @@ -1147,6 +1147,10 @@ class TargetTransformInfoImplBase { const Function &F, SmallVectorImpl> &LB) const {} + virtual bool isSpecialUniformIntrinsic(const Instruction &I) const { + return false; + } + protected: // Obtain the minimum required size to hold the value (without the sign) // In case of a vector it returns the min required size for one element. diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp index 8548afea72964..15083a6c50e02 100644 --- a/llvm/lib/Analysis/TargetTransformInfo.cpp +++ b/llvm/lib/Analysis/TargetTransformInfo.cpp @@ -1476,6 +1476,11 @@ void TargetTransformInfo::collectKernelLaunchBounds( return TTIImpl->collectKernelLaunchBounds(F, LB); } +bool TargetTransformInfo::isSpecialUniformIntrinsic( + const Instruction &I) const { + return TTIImpl->isSpecialUniformIntrinsic(I); +} + TargetTransformInfoImplBase::~TargetTransformInfoImplBase() = default; TargetIRAnalysis::TargetIRAnalysis() : TTICallback(&getDefaultTTI) {} diff --git a/llvm/lib/Analysis/UniformityAnalysis.cpp b/llvm/lib/Analysis/UniformityAnalysis.cpp index 2101fdfacfc8f..b71e82b73130f 100644 --- a/llvm/lib/Analysis/UniformityAnalysis.cpp +++ b/llvm/lib/Analysis/UniformityAnalysis.cpp @@ -29,12 +29,39 @@ bool llvm::GenericUniformityAnalysisImpl::markDefsDivergent( return markDivergent(cast(&Instr)); } +template <> +bool llvm::GenericUniformityAnalysisImpl::isDivergentUse( + const Use &U) const { + const auto *V = U.get(); + if (isDivergent(V)) + return true; + if (const auto *DefInstr = dyn_cast(V)) { + const auto *UseInstr = cast(U.getUser()); + return isTemporalDivergent(*UseInstr->getParent(), *DefInstr); + } + return false; +} + +template <> +bool llvm::GenericUniformityAnalysisImpl::isAnyOperandUniform( + const Instruction &I) const { + for (unsigned i = 0, e = I.getNumOperands(); i != e; ++i) { + if (!isa(I.getOperand(i)) && !isa(I.getOperand(i))) + continue; + if (!isDivergentUse(I.getOperandUse(i))) + return true; + } + return false; +} + template <> void llvm::GenericUniformityAnalysisImpl::initialize() { for (auto &I : instructions(F)) { if (TTI->isSourceOfDivergence(&I)) markDivergent(I); else if (TTI->isAlwaysUniform(&I)) addUniformOverride(I); + else if (TTI->isSpecialUniformIntrinsic(I)) + addSpecialUniformIntrinsic(I); } for (auto &Arg : F.args()) { if (TTI->isSourceOfDivergence(&Arg)) { @@ -48,6 +75,11 @@ void llvm::GenericUniformityAnalysisImpl::pushUsers( const Value *V) { for (const auto *User : V->users()) { if (const auto *UserInstr = dyn_cast(User)) { + if (SpecialUniformIntrinsics.count(UserInstr) && + isAnyOperandUniform(*UserInstr)) { + addUniformOverride(*UserInstr); + continue; + } markDivergent(*UserInstr); } } @@ -88,19 +120,6 @@ void llvm::GenericUniformityAnalysisImpl< } } -template <> -bool llvm::GenericUniformityAnalysisImpl::isDivergentUse( - const Use &U) const { - const auto *V = U.get(); - if (isDivergent(V)) - return true; - if (const auto *DefInstr = dyn_cast(V)) { - const auto *UseInstr = cast(U.getUser()); - return isTemporalDivergent(*UseInstr->getParent(), *DefInstr); - } - return false; -} - // This ensures explicit instantiation of // GenericUniformityAnalysisImpl::ImplDeleter::operator() template class llvm::GenericUniformityInfo; diff --git a/llvm/lib/Target/AMDGPU/AMDGPUSearchableTables.td b/llvm/lib/Target/AMDGPU/AMDGPUSearchableTables.td index 3b62dcf3c92cd..da3776761ab34 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPUSearchableTables.td +++ b/llvm/lib/Target/AMDGPU/AMDGPUSearchableTables.td @@ -317,8 +317,6 @@ def : SourceOfDivergence; def : SourceOfDivergence; def : SourceOfDivergence; def : SourceOfDivergence; -def : SourceOfDivergence; -def : SourceOfDivergence; def : SourceOfDivergence; def : SourceOfDivergence; def : SourceOfDivergence; diff --git a/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.cpp b/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.cpp index 204d3df546bbf..9acf536f86eba 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.cpp @@ -1422,3 +1422,16 @@ void GCNTTIImpl::collectKernelLaunchBounds( LB.push_back({"amdgpu-waves-per-eu[0]", WavesPerEU.first}); LB.push_back({"amdgpu-waves-per-eu[1]", WavesPerEU.second}); } + +bool GCNTTIImpl::isSpecialUniformIntrinsic(const Instruction &I) const { + if (const auto *II = dyn_cast(&I)) { + switch (II->getIntrinsicID()) { + case Intrinsic::amdgcn_permlane16: + case Intrinsic::amdgcn_permlanex16: + return true; + default: + return false; + } + } + return false; +} diff --git a/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.h b/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.h index f6f7bd4bfcf5b..ed169efbeb047 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.h +++ b/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.h @@ -290,6 +290,7 @@ class GCNTTIImpl final : public BasicTTIImplBase { void collectKernelLaunchBounds( const Function &F, SmallVectorImpl> &LB) const override; + bool isSpecialUniformIntrinsic(const Instruction &I) const override; }; } // end namespace llvm diff --git a/llvm/test/Analysis/UniformityAnalysis/AMDGPU/intrinsics.ll b/llvm/test/Analysis/UniformityAnalysis/AMDGPU/intrinsics.ll index bb840023daf5d..f209c996c7692 100644 --- a/llvm/test/Analysis/UniformityAnalysis/AMDGPU/intrinsics.ll +++ b/llvm/test/Analysis/UniformityAnalysis/AMDGPU/intrinsics.ll @@ -7,14 +7,14 @@ define amdgpu_kernel void @ds_swizzle(ptr addrspace(1) %out, i32 %src) #0 { ret void } -; CHECK: DIVERGENT: %v = call i32 @llvm.amdgcn.permlane16.i32(i32 %src0, i32 %src0, i32 %src1, i32 %src2, i1 false, i1 false) #0 +; CHECK: ALL VALUES UNIFORM define amdgpu_kernel void @v_permlane16_b32(ptr addrspace(1) %out, i32 %src0, i32 %src1, i32 %src2) #0 { %v = call i32 @llvm.amdgcn.permlane16.i32(i32 %src0, i32 %src0, i32 %src1, i32 %src2, i1 false, i1 false) #0 store i32 %v, ptr addrspace(1) %out ret void } -; CHECK: DIVERGENT: %v = call i32 @llvm.amdgcn.permlanex16.i32(i32 %src0, i32 %src0, i32 %src1, i32 %src2, i1 false, i1 false) #0 +; CHECK: ALL VALUES UNIFORM define amdgpu_kernel void @v_permlanex16_b32(ptr addrspace(1) %out, i32 %src0, i32 %src1, i32 %src2) #0 { %v = call i32 @llvm.amdgcn.permlanex16.i32(i32 %src0, i32 %src0, i32 %src1, i32 %src2, i1 false, i1 false) #0 store i32 %v, ptr addrspace(1) %out diff --git a/llvm/test/Analysis/UniformityAnalysis/AMDGPU/uniform_intrinsic.ll b/llvm/test/Analysis/UniformityAnalysis/AMDGPU/uniform_intrinsic.ll new file mode 100644 index 0000000000000..e5eb5ebebf897 --- /dev/null +++ b/llvm/test/Analysis/UniformityAnalysis/AMDGPU/uniform_intrinsic.ll @@ -0,0 +1,19 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5 +; RUN: opt -mtriple amdgcn-- -passes='print' -disable-output %s 2>&1 | FileCheck %s + +; CHECK: ALL VALUES UNIFORM +define amdgpu_kernel void @v_permlane16_b32(ptr addrspace(1) %out, i32 %src0, i32 %src1, i32 %src2) #0 { + %v = call i32 @llvm.amdgcn.permlane16.i32(i32 %src0, i32 %src0, i32 %src1, i32 %src2, i1 false, i1 false) #0 + store i32 %v, ptr addrspace(1) %out + ret void +} + +; CHECK: ALL VALUES UNIFORM +define amdgpu_kernel void @v_permlanex16_b32(ptr addrspace(1) %out, i32 %src0, i32 %src1, i32 %src2) #0 { + %v = call i32 @llvm.amdgcn.permlanex16.i32(i32 %src0, i32 %src0, i32 %src1, i32 %src2, i1 false, i1 false) #0 + store i32 %v, ptr addrspace(1) %out + ret void +} + +;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line: +; CHECK: {{.*}}