Skip to content
Open
28 changes: 27 additions & 1 deletion llvm/include/llvm/ADT/GenericUniformityImpl.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
//===- GenericUniformityImpl.h -----------------------*- C++ -*------------===//
//===- GenericUniformityImpl.h -----------------------*- C++ -*------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
Expand Down Expand Up @@ -51,6 +51,7 @@
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SparseBitVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/Uniformity.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "uniformity"
Expand Down Expand Up @@ -407,6 +408,13 @@ template <typename ContextT> class GenericUniformityAnalysisImpl {
void recordTemporalDivergence(ConstValueRefT, const InstructionT *,
const CycleT *);

/// Check if an instruction with Custom uniformity can be proven uniform
/// based on its operands. This queries the target-specific callback.
bool isCustomUniform(const InstructionT &I) const;

/// \brief keep track of instructions that require custom uniformity analysis.
void addUniformInstruction(const InstructionT *I, InstructionUniformity IU);

protected:
const ContextT &Context;
const FunctionT &F;
Expand All @@ -420,6 +428,10 @@ template <typename ContextT> class GenericUniformityAnalysisImpl {
// Internal worklist for divergence propagation.
std::vector<const InstructionT *> Worklist;

// Map containing tracked instruction that can be proven uniform based on its
// operand Uniformity.
DenseMap<const InstructionT *, InstructionUniformity> UniformInstruction;

/// \brief Mark \p Term as divergent and push all Instructions that become
/// divergent as a result on the worklist.
void analyzeControlDivergence(const InstructionT &Term);
Expand Down Expand Up @@ -785,6 +797,14 @@ void GenericUniformityAnalysisImpl<ContextT>::markDivergent(
const InstructionT &I) {
if (isAlwaysUniform(I))
return;
// Check if instruction requires custom uniformity analysis
auto It = UniformInstruction.find(&I);
if (It != UniformInstruction.end()) {
if (It->second == InstructionUniformity::Custom && isCustomUniform(I)) {
addUniformOverride(I);
return;
}
}
bool Marked = false;
if (I.isTerminator()) {
Marked = DivergentTermBlocks.insert(I.getParent()).second;
Expand Down Expand Up @@ -816,6 +836,12 @@ void GenericUniformityAnalysisImpl<ContextT>::addUniformOverride(
UniformOverrides.insert(&Instr);
}

template <typename ContextT>
void GenericUniformityAnalysisImpl<ContextT>::addUniformInstruction(
const InstructionT *I, InstructionUniformity IU) {
UniformInstruction[I] = IU;
}

// Mark as divergent all external uses of values defined in \p DefCycle.
//
// A value V defined by a block B inside \p DefCycle may be used outside the
Expand Down
7 changes: 6 additions & 1 deletion llvm/include/llvm/ADT/Uniformity.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,12 @@ enum class InstructionUniformity {
AlwaysUniform,

/// The result values can never be assumed to be uniform.
NeverUniform
NeverUniform,

/// If all operands are uniform, the result values are uniform. Otherwise,
/// the result values may be divergent, and a custom check may be used to
/// determine uniformity via a callback.
Custom
};

} // namespace llvm
Expand Down
15 changes: 15 additions & 0 deletions llvm/include/llvm/Analysis/TargetTransformInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/BitmaskEnum.h"
#include "llvm/ADT/Uniformity.h"
#include "llvm/Analysis/IVDescriptors.h"
#include "llvm/Analysis/InterestingMemoryOperand.h"
#include "llvm/IR/FMF.h"
Expand Down Expand Up @@ -2000,6 +2001,20 @@ class TargetTransformInfo {
/// target.
LLVM_ABI bool allowVectorElementIndexingUsingGEP() const;

InstructionUniformity getInstructionUniformity(const Value *V) const;

/// Determine if an instruction with some operands uniform can be proven
/// uniform. This is used for custom uniformity analysis where the target
/// can define complex rules that depend on which specific operands are
/// uniform.
///
/// \param I The instruction to check.
/// \param UniformArgs A bitvector indicating which operands are known to be
/// uniform (bit N corresponds to operand N).
/// \returns true if the instruction result can be proven uniform given the
/// uniform operands, false otherwise.
bool isUniform(const Instruction *I, const SmallBitVector &UniformArgs) const;

private:
std::unique_ptr<const TargetTransformInfoImplBase> TTIImpl;
};
Expand Down
15 changes: 15 additions & 0 deletions llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1154,6 +1154,21 @@ class TargetTransformInfoImplBase {

virtual bool allowVectorElementIndexingUsingGEP() const { return true; }

// New API for uniformity classification
// Targets should override this to provide target-specific uniformity analysis
// The default implementation returns Default (conservative behavior)
virtual InstructionUniformity getInstructionUniformity(const Value *V) const {
return InstructionUniformity::Default;
}

// Custom uniformity check for instructions marked as Custom
// Override this to provide complex uniformity rules based on which operands
// are uniform
virtual bool isUniform(const Instruction *I,
const SmallBitVector &UniformArgs) const {
return false; // Conservative: assume divergent
}

protected:
// Obtain the minimum required size to hold the value (without the sign)
// In case of a vector it returns the min required size for one element.
Expand Down
16 changes: 16 additions & 0 deletions llvm/include/llvm/CodeGen/TargetInstrInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -2359,6 +2359,22 @@ class LLVM_ABI TargetInstrInfo : public MCInstrInfo {
return InstructionUniformity::Default;
}

/// Determine if a machine instruction with some operands uniform can be
/// proven uniform. This is used for custom uniformity analysis where the
/// target can define complex rules that depend on which specific operands
/// are uniform.
///
/// \param MI The machine instruction to check.
/// \param UniformArgs A bitvector indicating which register operands are
/// known to be uniform (bit N corresponds to the Nth
/// register use operand).
/// \returns true if the instruction result can be proven uniform given the
/// uniform operands, false otherwise.
virtual bool isUniform(const MachineInstr &MI,
const SmallBitVector &UniformArgs) const {
return false; // Conservative: assume divergent
}

/// Returns true if the given \p MI defines a TargetIndex operand that can be
/// tracked by their offset, can have values, and can have debug info
/// associated with it. If so, sets \p Index and \p Offset of the target index
Expand Down
10 changes: 10 additions & 0 deletions llvm/lib/Analysis/TargetTransformInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1517,6 +1517,16 @@ bool TargetTransformInfo::allowVectorElementIndexingUsingGEP() const {
return TTIImpl->allowVectorElementIndexingUsingGEP();
}

InstructionUniformity
TargetTransformInfo::getInstructionUniformity(const Value *V) const {
return TTIImpl->getInstructionUniformity(V);
}

bool TargetTransformInfo::isUniform(const Instruction *I,
const SmallBitVector &UniformArgs) const {
return TTIImpl->isUniform(I, UniformArgs);
}

TargetTransformInfoImplBase::~TargetTransformInfoImplBase() = default;

TargetIRAnalysis::TargetIRAnalysis() : TTICallback(&getDefaultTTI) {}
Expand Down
33 changes: 30 additions & 3 deletions llvm/lib/Analysis/UniformityAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

#include "llvm/Analysis/UniformityAnalysis.h"
#include "llvm/ADT/GenericUniformityImpl.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/Uniformity.h"
#include "llvm/Analysis/CycleAnalysis.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/Dominators.h"
Expand All @@ -31,13 +33,25 @@ bool llvm::GenericUniformityAnalysisImpl<SSAContext>::markDefsDivergent(

template <> void llvm::GenericUniformityAnalysisImpl<SSAContext>::initialize() {
for (auto &I : instructions(F)) {
if (TTI->isSourceOfDivergence(&I))
InstructionUniformity IU = TTI->getInstructionUniformity(&I);
switch (IU) {
case InstructionUniformity::NeverUniform:
markDivergent(I);
else if (TTI->isAlwaysUniform(&I))
break;
case InstructionUniformity::AlwaysUniform:
addUniformOverride(I);
break;
case InstructionUniformity::Custom:
// Instructions requiring custom uniformity analysis based on operands
addUniformInstruction(&I, IU);
break;
case InstructionUniformity::Default:
break;
}
}
for (auto &Arg : F.args()) {
if (TTI->isSourceOfDivergence(&Arg)) {
if (TTI->getInstructionUniformity(&Arg) ==
InstructionUniformity::NeverUniform) {
markDivergent(&Arg);
}
}
Expand Down Expand Up @@ -101,6 +115,19 @@ bool llvm::GenericUniformityAnalysisImpl<SSAContext>::isDivergentUse(
return false;
}

template <>
bool GenericUniformityAnalysisImpl<SSAContext>::isCustomUniform(
const Instruction &I) const {
// Build bitvector of uniform operands
SmallBitVector UniformArgs(I.getNumOperands());
for (unsigned OpIdx = 0; OpIdx < I.getNumOperands(); ++OpIdx) {
UniformArgs[OpIdx] = !isDivergentUse(I.getOperandUse(OpIdx));
}

// Query target-specific uniformity callback
return TTI->isUniform(&I, UniformArgs);
}

// This ensures explicit instantiation of
// GenericUniformityAnalysisImpl::ImplDeleter::operator()
template class llvm::GenericUniformityInfo<SSAContext>;
Expand Down
41 changes: 35 additions & 6 deletions llvm/lib/CodeGen/MachineUniformityAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include "llvm/CodeGen/MachineUniformityAnalysis.h"
#include "llvm/ADT/GenericUniformityImpl.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/MachineCycleAnalysis.h"
#include "llvm/CodeGen/MachineDominators.h"
Expand Down Expand Up @@ -53,13 +54,19 @@ void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::initialize() {
for (const MachineBasicBlock &block : F) {
for (const MachineInstr &instr : block) {
auto uniformity = InstrInfo.getInstructionUniformity(instr);
if (uniformity == InstructionUniformity::AlwaysUniform) {
addUniformOverride(instr);
continue;
}

if (uniformity == InstructionUniformity::NeverUniform) {
switch (uniformity) {
case InstructionUniformity::NeverUniform:
markDivergent(instr);
break;
case InstructionUniformity::AlwaysUniform:
addUniformOverride(instr);
break;
case InstructionUniformity::Custom:
// Instructions requiring custom uniformity analysis based on operands
addUniformInstruction(&instr, uniformity);
break;
case InstructionUniformity::Default:
break;
}
}
}
Expand Down Expand Up @@ -148,6 +155,28 @@ bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::isDivergentUse(
return isTemporalDivergent(*UseInstr->getParent(), *DefInstr);
}

template <>
bool GenericUniformityAnalysisImpl<MachineSSAContext>::isCustomUniform(
const MachineInstr &MI) const {
const auto &InstrInfo = *F.getSubtarget().getInstrInfo();

// Build bitvector of uniform register use operands
SmallVector<const MachineOperand *, 4> RegUseOps;
for (const MachineOperand &MO : MI.uses()) {
if (MO.isReg() && MO.getReg().isVirtual()) {
RegUseOps.push_back(&MO);
}
}

SmallBitVector UniformArgs(RegUseOps.size());
for (unsigned i = 0; i < RegUseOps.size(); ++i) {
UniformArgs[i] = !isDivergentUse(*RegUseOps[i]);
}

// Query target-specific uniformity callback
return InstrInfo.isUniform(MI, UniformArgs);
}

// This ensures explicit instantiation of
// GenericUniformityAnalysisImpl::ImplDeleter::operator()
template class llvm::GenericUniformityInfo<MachineSSAContext>;
Expand Down
2 changes: 0 additions & 2 deletions llvm/lib/Target/AMDGPU/AMDGPUSearchableTables.td
Original file line number Diff line number Diff line change
Expand Up @@ -317,8 +317,6 @@ def : SourceOfDivergence<int_amdgcn_live_mask>;
def : SourceOfDivergence<int_amdgcn_ds_swizzle>;
def : SourceOfDivergence<int_amdgcn_ds_ordered_add>;
def : SourceOfDivergence<int_amdgcn_ds_ordered_swap>;
def : SourceOfDivergence<int_amdgcn_permlane16>;
def : SourceOfDivergence<int_amdgcn_permlanex16>;
def : SourceOfDivergence<int_amdgcn_permlane16_var>;
def : SourceOfDivergence<int_amdgcn_permlanex16_var>;
def : SourceOfDivergence<int_amdgcn_permlane_bcast>;
Expand Down
53 changes: 53 additions & 0 deletions llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "AMDGPUTargetMachine.h"
#include "MCTargetDesc/AMDGPUMCTargetDesc.h"
#include "SIModeRegisterDefaults.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/Analysis/InlineCost.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/ValueTracking.h"
Expand Down Expand Up @@ -1574,3 +1575,55 @@ unsigned GCNTTIImpl::getNumberOfParts(Type *Tp) const {
}
return BaseT::getNumberOfParts(Tp);
}

// New API that wraps the old isSourceOfDivergence and isAlwaysUniform APIs
// with additional support for new uniformity classifications
InstructionUniformity
GCNTTIImpl::getInstructionUniformity(const Value *V) const {
// Check for special cases requiring custom uniformity analysis
if (const IntrinsicInst *Intrinsic = dyn_cast<IntrinsicInst>(V)) {
switch (Intrinsic->getIntrinsicID()) {
case Intrinsic::amdgcn_permlane16:
case Intrinsic::amdgcn_permlanex16:
return InstructionUniformity::Custom;
default:
break;
}
}

// Delegate to old APIs for backward compatibility
if (isAlwaysUniform(V))
return InstructionUniformity::AlwaysUniform;

// Check if source of divergence
if (isSourceOfDivergence(V))
return InstructionUniformity::NeverUniform;

// Default behavior
return InstructionUniformity::Default;
}

bool GCNTTIImpl::isUniform(const Instruction *I,
const SmallBitVector &UniformArgs) const {
// Custom uniformity check for permlane16/permlanex16
if (const IntrinsicInst *Intrinsic = dyn_cast<IntrinsicInst>(I)) {
switch (Intrinsic->getIntrinsicID()) {
case Intrinsic::amdgcn_permlane16:
case Intrinsic::amdgcn_permlanex16:
// For permlane16/permlanex16:
// Operand 0: old value (ignored for uniformity)
// Operand 1: src0 (source value to permute)
// Operand 2: src1 (lane select within 16-lane group)
// Operand 3: src2 (which 16-lane group)
// Result is uniform if either src0 (op 1) or src1 (op 2) is uniform
Comment on lines +1616 to +1618
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As Nicolai pointed out this is completely wrong and permlane16 is the wrong example to use to demonstrate the new functionality.

if (UniformArgs.size() > 2) {
return UniformArgs[1] || UniformArgs[2];
}
return false;
default:
break;
}
}

return false;
}
5 changes: 5 additions & 0 deletions llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,11 @@ class GCNTTIImpl final : public BasicTTIImplBase<GCNTTIImpl> {
/// together under a single i32 value. Otherwise fall back to base
/// implementation.
unsigned getNumberOfParts(Type *Tp) const override;

InstructionUniformity getInstructionUniformity(const Value *V) const override;

bool isUniform(const Instruction *I,
const SmallBitVector &UniformArgs) const override;
};

} // end namespace llvm
Expand Down
Loading