Skip to content
Open
23 changes: 22 additions & 1 deletion llvm/include/llvm/ADT/GenericUniformityImpl.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
//===- GenericUniformityImpl.h -----------------------*- C++ -*------------===//
//===- GenericUniformityImpl.h -----------------------*- C++ -*------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
Expand Down Expand Up @@ -51,6 +51,7 @@
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SparseBitVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/Uniformity.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "uniformity"
Expand Down Expand Up @@ -407,6 +408,11 @@ template <typename ContextT> class GenericUniformityAnalysisImpl {
void recordTemporalDivergence(ConstValueRefT, const InstructionT *,
const CycleT *);

bool isOperandUniform(const InstructionT &I, InstructionUniformity IU) const;

/// \brief keep track of target instruction that can be proven uniform.
void addUniformInstruction(const InstructionT *I, InstructionUniformity IU);

protected:
const ContextT &Context;
const FunctionT &F;
Expand All @@ -420,6 +426,10 @@ template <typename ContextT> class GenericUniformityAnalysisImpl {
// Internal worklist for divergence propagation.
std::vector<const InstructionT *> Worklist;

// Map containing tracked instruction that can be proven uniform based on its
// operand Uniformity.
DenseMap<const InstructionT *, InstructionUniformity> UniformInstruction;

/// \brief Mark \p Term as divergent and push all Instructions that become
/// divergent as a result on the worklist.
void analyzeControlDivergence(const InstructionT &Term);
Expand Down Expand Up @@ -785,6 +795,11 @@ void GenericUniformityAnalysisImpl<ContextT>::markDivergent(
const InstructionT &I) {
if (isAlwaysUniform(I))
return;
auto It = UniformInstruction.find(&I);
if (It != UniformInstruction.end() && isOperandUniform(I, It->second)) {
addUniformOverride(I);
return;
}
bool Marked = false;
if (I.isTerminator()) {
Marked = DivergentTermBlocks.insert(I.getParent()).second;
Expand Down Expand Up @@ -816,6 +831,12 @@ void GenericUniformityAnalysisImpl<ContextT>::addUniformOverride(
UniformOverrides.insert(&Instr);
}

template <typename ContextT>
void GenericUniformityAnalysisImpl<ContextT>::addUniformInstruction(
const InstructionT *I, InstructionUniformity IU) {
UniformInstruction[I] = IU;
}

// Mark as divergent all external uses of values defined in \p DefCycle.
//
// A value V defined by a block B inside \p DefCycle may be used outside the
Expand Down
6 changes: 5 additions & 1 deletion llvm/include/llvm/ADT/Uniformity.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,11 @@ enum class InstructionUniformity {
AlwaysUniform,

/// The result values can never be assumed to be uniform.
NeverUniform
NeverUniform,

/// Result value can be uniform if any of the first two use operand are
/// uniform.
AnyOfFirstTwoUseOp
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That enum value seems like a really bad precedent. It's so arbitrary.

};

} // namespace llvm
Expand Down
3 changes: 3 additions & 0 deletions llvm/include/llvm/Analysis/TargetTransformInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/BitmaskEnum.h"
#include "llvm/ADT/Uniformity.h"
#include "llvm/Analysis/IVDescriptors.h"
#include "llvm/Analysis/InterestingMemoryOperand.h"
#include "llvm/IR/FMF.h"
Expand Down Expand Up @@ -2000,6 +2001,8 @@ class TargetTransformInfo {
/// target.
LLVM_ABI bool allowVectorElementIndexingUsingGEP() const;

InstructionUniformity getInstructionUniformity(const Value *V) const;

private:
std::unique_ptr<const TargetTransformInfoImplBase> TTIImpl;
};
Expand Down
7 changes: 7 additions & 0 deletions llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1154,6 +1154,13 @@ class TargetTransformInfoImplBase {

virtual bool allowVectorElementIndexingUsingGEP() const { return true; }

// New API for uniformity classification
// Targets should override this to provide target-specific uniformity analysis
// The default implementation returns Default (conservative behavior)
virtual InstructionUniformity getInstructionUniformity(const Value *V) const {
return InstructionUniformity::Default;
}

protected:
// Obtain the minimum required size to hold the value (without the sign)
// In case of a vector it returns the min required size for one element.
Expand Down
5 changes: 5 additions & 0 deletions llvm/lib/Analysis/TargetTransformInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1517,6 +1517,11 @@ bool TargetTransformInfo::allowVectorElementIndexingUsingGEP() const {
return TTIImpl->allowVectorElementIndexingUsingGEP();
}

InstructionUniformity
TargetTransformInfo::getInstructionUniformity(const Value *V) const {
return TTIImpl->getInstructionUniformity(V);
}

TargetTransformInfoImplBase::~TargetTransformInfoImplBase() = default;

TargetIRAnalysis::TargetIRAnalysis() : TTICallback(&getDefaultTTI) {}
Expand Down
33 changes: 30 additions & 3 deletions llvm/lib/Analysis/UniformityAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include "llvm/Analysis/UniformityAnalysis.h"
#include "llvm/ADT/GenericUniformityImpl.h"
#include "llvm/ADT/Uniformity.h"
#include "llvm/Analysis/CycleAnalysis.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/Dominators.h"
Expand All @@ -31,13 +32,24 @@ bool llvm::GenericUniformityAnalysisImpl<SSAContext>::markDefsDivergent(

template <> void llvm::GenericUniformityAnalysisImpl<SSAContext>::initialize() {
for (auto &I : instructions(F)) {
if (TTI->isSourceOfDivergence(&I))
InstructionUniformity IU = TTI->getInstructionUniformity(&I);
switch (IU) {
case InstructionUniformity::NeverUniform:
markDivergent(I);
else if (TTI->isAlwaysUniform(&I))
break;
case InstructionUniformity::AlwaysUniform:
addUniformOverride(I);
break;
case InstructionUniformity::Default:
break;
default:
addUniformInstruction(&I, IU);
break;
}
}
for (auto &Arg : F.args()) {
if (TTI->isSourceOfDivergence(&Arg)) {
if (TTI->getInstructionUniformity(&Arg) ==
InstructionUniformity::NeverUniform) {
markDivergent(&Arg);
}
}
Expand Down Expand Up @@ -101,6 +113,21 @@ bool llvm::GenericUniformityAnalysisImpl<SSAContext>::isDivergentUse(
return false;
}

template <>
bool GenericUniformityAnalysisImpl<SSAContext>::isOperandUniform(
const Instruction &I, InstructionUniformity IU) const {
switch (IU) {
case InstructionUniformity::AnyOfFirstTwoUseOp:
// For permlane16/permlanex16: <old> <src0> <src1> <src2> <fi>
// <bound_control> Check if either src0 (operand 1) or src1 (operand 2 -
// lane select) is uniform
return !isDivergentUse(I.getOperandUse(1)) ||
!isDivergentUse(I.getOperandUse(2));
default:
return false;
}
}

// This ensures explicit instantiation of
// GenericUniformityAnalysisImpl::ImplDeleter::operator()
template class llvm::GenericUniformityInfo<SSAContext>;
Expand Down
55 changes: 49 additions & 6 deletions llvm/lib/CodeGen/MachineUniformityAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,18 @@ void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::initialize() {
for (const MachineBasicBlock &block : F) {
for (const MachineInstr &instr : block) {
auto uniformity = InstrInfo.getInstructionUniformity(instr);
if (uniformity == InstructionUniformity::AlwaysUniform) {
addUniformOverride(instr);
continue;
}

if (uniformity == InstructionUniformity::NeverUniform) {
switch (uniformity) {
case InstructionUniformity::NeverUniform:
markDivergent(instr);
break;
case InstructionUniformity::AlwaysUniform:
addUniformOverride(instr);
break;
case InstructionUniformity::Default:
break;
default:
addUniformInstruction(&instr, uniformity);
break;
}
}
}
Expand Down Expand Up @@ -148,6 +153,44 @@ bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::isDivergentUse(
return isTemporalDivergent(*UseInstr->getParent(), *DefInstr);
}

template <>
bool GenericUniformityAnalysisImpl<MachineSSAContext>::isOperandUniform(
const MachineInstr &MI, InstructionUniformity IU) const {
switch (IU) {
// For permlane16/permlanex16, check if either src or lane select is uniform
// These instructions have mixed immediate and register operands:
// Operand 1 is src0 (the source value to permute)
// Operand 3 is src1 (lane select - which lane within the 16 to read from)
// Result is uniform if EITHER the source OR lane select is uniform
case InstructionUniformity::AnyOfFirstTwoUseOp: {
// Check if any of the first two register use operands is uniform
// Result is uniform if ANY of these operands is uniform
const MachineOperand *FirstRegOp = nullptr;
const MachineOperand *SecondRegOp = nullptr;

// Find the first two register use operands
for (const MachineOperand &MO : MI.uses()) {
if (MO.isReg() && MO.getReg().isVirtual()) {
if (!FirstRegOp)
FirstRegOp = &MO;
else if (!SecondRegOp) {
SecondRegOp = &MO;
break;
}
}
}

if (!FirstRegOp || !SecondRegOp)
return false;

// Return true if either operand is uniform
return !isDivergentUse(*FirstRegOp) || !isDivergentUse(*SecondRegOp);
}
default:
return false;
}
}

// This ensures explicit instantiation of
// GenericUniformityAnalysisImpl::ImplDeleter::operator()
template class llvm::GenericUniformityInfo<MachineSSAContext>;
Expand Down
2 changes: 0 additions & 2 deletions llvm/lib/Target/AMDGPU/AMDGPUSearchableTables.td
Original file line number Diff line number Diff line change
Expand Up @@ -317,8 +317,6 @@ def : SourceOfDivergence<int_amdgcn_live_mask>;
def : SourceOfDivergence<int_amdgcn_ds_swizzle>;
def : SourceOfDivergence<int_amdgcn_ds_ordered_add>;
def : SourceOfDivergence<int_amdgcn_ds_ordered_swap>;
def : SourceOfDivergence<int_amdgcn_permlane16>;
def : SourceOfDivergence<int_amdgcn_permlanex16>;
def : SourceOfDivergence<int_amdgcn_permlane16_var>;
def : SourceOfDivergence<int_amdgcn_permlanex16_var>;
def : SourceOfDivergence<int_amdgcn_permlane_bcast>;
Expand Down
29 changes: 29 additions & 0 deletions llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1574,3 +1574,32 @@ unsigned GCNTTIImpl::getNumberOfParts(Type *Tp) const {
}
return BaseT::getNumberOfParts(Tp);
}

// New API that wraps the old isSourceOfDivergence and isAlwaysUniform APIs
// with additional support for new uniformity classifications
InstructionUniformity
GCNTTIImpl::getInstructionUniformity(const Value *V) const {
// Check for new special cases first (permlane16/permlanex16)
// These need operand-dependent uniformity analysis
if (const IntrinsicInst *Intrinsic = dyn_cast<IntrinsicInst>(V)) {
switch (Intrinsic->getIntrinsicID()) {
case Intrinsic::amdgcn_permlane16:
case Intrinsic::amdgcn_permlanex16:
// Result value can be uniform if either of first two operands are uniform
return InstructionUniformity::AnyOfFirstTwoUseOp;
default:
break;
}
}

// Delegate to old APIs for backward compatibility
if (isAlwaysUniform(V))
return InstructionUniformity::AlwaysUniform;

// Check if source of divergence
if (isSourceOfDivergence(V))
return InstructionUniformity::NeverUniform;

// Default behavior
return InstructionUniformity::Default;
}
2 changes: 2 additions & 0 deletions llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,8 @@ class GCNTTIImpl final : public BasicTTIImplBase<GCNTTIImpl> {
/// together under a single i32 value. Otherwise fall back to base
/// implementation.
unsigned getNumberOfParts(Type *Tp) const override;

InstructionUniformity getInstructionUniformity(const Value *V) const override;
};

} // end namespace llvm
Expand Down
7 changes: 7 additions & 0 deletions llvm/lib/Target/AMDGPU/SIInstrInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10574,6 +10574,13 @@ SIInstrInfo::getInstructionUniformity(const MachineInstr &MI) const {
return InstructionUniformity::NeverUniform;

unsigned opcode = MI.getOpcode();

// Special handling for permlane16/permlanex16 - uniformity depends on
// operands
if (opcode == AMDGPU::V_PERMLANE16_B32_e64 ||
opcode == AMDGPU::V_PERMLANEX16_B32_e64)
return InstructionUniformity::AnyOfFirstTwoUseOp;

if (opcode == AMDGPU::V_READLANE_B32 ||
opcode == AMDGPU::V_READFIRSTLANE_B32 ||
opcode == AMDGPU::SI_RESTORE_S32_FROM_VGPR)
Expand Down
13 changes: 13 additions & 0 deletions llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -635,3 +635,16 @@ void NVPTXTTIImpl::collectKernelLaunchBounds(
if (MaxNTID.size() > 2)
LB.push_back({"maxntidz", MaxNTID[2]});
}

// New API that wraps the old isSourceOfDivergence API
// NVPTX doesn't have isAlwaysUniform, so we only delegate to
// isSourceOfDivergence
InstructionUniformity
NVPTXTTIImpl::getInstructionUniformity(const Value *V) const {
// Delegate to old API for backward compatibility
if (isSourceOfDivergence(V))
return InstructionUniformity::NeverUniform;

// Default behavior
return InstructionUniformity::Default;
}
1 change: 1 addition & 0 deletions llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ class NVPTXTTIImpl final : public BasicTTIImplBase<NVPTXTTIImpl> {
// Self-referential globals are not supported.
return false;
}
InstructionUniformity getInstructionUniformity(const Value *V) const override;
};

} // end namespace llvm
Expand Down
Loading