Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,16 @@ LLVM_ABI LegalityPredicate scalarWiderThan(unsigned TypeIdx, unsigned Size);
LLVM_ABI LegalityPredicate scalarOrEltNarrowerThan(unsigned TypeIdx,
unsigned Size);

/// True iff the specified type index is a vector with a number of elements
/// that's greater than the given size.
LLVM_ABI LegalityPredicate vectorElementCountIsGreaterThan(unsigned TypeIdx,
unsigned Size);

/// True iff the specified type index is a vector with a number of elements
/// that's less than or equal to the given size.
LLVM_ABI LegalityPredicate
vectorElementCountIsLessThanOrEqualTo(unsigned TypeIdx, unsigned Size);

/// True iff the specified type index is a scalar or a vector with an element
/// type that's wider than the given size.
LLVM_ABI LegalityPredicate scalarOrEltWiderThan(unsigned TypeIdx,
Expand Down
20 changes: 20 additions & 0 deletions llvm/lib/CodeGen/GlobalISel/LegalityPredicates.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,26 @@ LegalityPredicate LegalityPredicates::scalarOrEltNarrowerThan(unsigned TypeIdx,
};
}

LegalityPredicate
LegalityPredicates::vectorElementCountIsGreaterThan(unsigned TypeIdx,
unsigned Size) {

return [=](const LegalityQuery &Query) {
const LLT QueryTy = Query.Types[TypeIdx];
return QueryTy.isFixedVector() && QueryTy.getNumElements() > Size;
};
}

LegalityPredicate
LegalityPredicates::vectorElementCountIsLessThanOrEqualTo(unsigned TypeIdx,
unsigned Size) {

return [=](const LegalityQuery &Query) {
const LLT QueryTy = Query.Types[TypeIdx];
return QueryTy.isFixedVector() && QueryTy.getNumElements() <= Size;
};
}

LegalityPredicate LegalityPredicates::scalarOrEltWiderThan(unsigned TypeIdx,
unsigned Size) {
return [=](const LegalityQuery &Query) {
Expand Down
50 changes: 37 additions & 13 deletions llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1781,33 +1781,57 @@ bool SPIRVInstructionSelector::selectUnmergeValues(MachineInstr &I) const {
unsigned ArgI = I.getNumOperands() - 1;
Register SrcReg =
I.getOperand(ArgI).isReg() ? I.getOperand(ArgI).getReg() : Register(0);
SPIRVType *DefType =
SPIRVType *SrcType =
SrcReg.isValid() ? GR.getSPIRVTypeForVReg(SrcReg) : nullptr;
if (!DefType || DefType->getOpcode() != SPIRV::OpTypeVector)
if (!SrcType || SrcType->getOpcode() != SPIRV::OpTypeVector)
report_fatal_error(
"cannot select G_UNMERGE_VALUES with a non-vector argument");

SPIRVType *ScalarType =
GR.getSPIRVTypeForVReg(DefType->getOperand(1).getReg());
GR.getSPIRVTypeForVReg(SrcType->getOperand(1).getReg());
MachineBasicBlock &BB = *I.getParent();
bool Res = false;
unsigned CurrentIndex = 0;
for (unsigned i = 0; i < I.getNumDefs(); ++i) {
Register ResVReg = I.getOperand(i).getReg();
SPIRVType *ResType = GR.getSPIRVTypeForVReg(ResVReg);
if (!ResType) {
// There was no "assign type" actions, let's fix this now
ResType = ScalarType;
LLT ResLLT = MRI->getType(ResVReg);
assert(ResLLT.isValid());
if (ResLLT.isVector()) {
ResType = GR.getOrCreateSPIRVVectorType(
ScalarType, ResLLT.getNumElements(), I, TII);
} else {
ResType = ScalarType;
}
MRI->setRegClass(ResVReg, GR.getRegClass(ResType));
MRI->setType(ResVReg, LLT::scalar(GR.getScalarOrVectorBitWidth(ResType)));
GR.assignSPIRVTypeToVReg(ResType, ResVReg, *GR.CurMF);
}
auto MIB =
BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpCompositeExtract))
.addDef(ResVReg)
.addUse(GR.getSPIRVTypeID(ResType))
.addUse(SrcReg)
.addImm(static_cast<int64_t>(i));
Res |= MIB.constrainAllUses(TII, TRI, RBI);

if (ResType->getOpcode() == SPIRV::OpTypeVector) {
Register UndefReg = GR.getOrCreateUndef(I, SrcType, TII);
auto MIB =
BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpVectorShuffle))
.addDef(ResVReg)
.addUse(GR.getSPIRVTypeID(ResType))
.addUse(SrcReg)
.addUse(UndefReg);
unsigned NumElements = GR.getScalarOrVectorComponentCount(ResType);
for (unsigned j = 0; j < NumElements; ++j) {
MIB.addImm(CurrentIndex + j);
}
CurrentIndex += NumElements;
Res |= MIB.constrainAllUses(TII, TRI, RBI);
} else {
auto MIB =
BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpCompositeExtract))
.addDef(ResVReg)
.addUse(GR.getSPIRVTypeID(ResType))
.addUse(SrcReg)
.addImm(CurrentIndex);
CurrentIndex++;
Res |= MIB.constrainAllUses(TII, TRI, RBI);
}
}
return Res;
}
Expand Down
186 changes: 175 additions & 11 deletions llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,22 @@
#include "SPIRV.h"
#include "SPIRVGlobalRegistry.h"
#include "SPIRVSubtarget.h"
#include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
#include "llvm/CodeGen/GlobalISel/LegalizerHelper.h"
#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
#include "llvm/CodeGen/MachineInstr.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/TargetOpcodes.h"
#include "llvm/IR/IntrinsicsSPIRV.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/MathExtras.h"

using namespace llvm;
using namespace llvm::LegalizeActions;
using namespace llvm::LegalityPredicates;

#define DEBUG_TYPE "spirv-legalizer"

LegalityPredicate typeOfExtendedScalars(unsigned TypeIdx, bool IsExtendedInts) {
return [IsExtendedInts, TypeIdx](const LegalityQuery &Query) {
const LLT Ty = Query.Types[TypeIdx];
Expand Down Expand Up @@ -101,6 +107,10 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
v4s64, v8s1, v8s8, v8s16, v8s32, v8s64, v16s1,
v16s8, v16s16, v16s32, v16s64};

auto allShaderVectors = {v2s1, v2s8, v2s16, v2s32, v2s64,
v3s1, v3s8, v3s16, v3s32, v3s64,
v4s1, v4s8, v4s16, v4s32, v4s64};

auto allScalarsAndVectors = {
s1, s8, s16, s32, s64, v2s1, v2s8, v2s16, v2s32, v2s64,
v3s1, v3s8, v3s16, v3s32, v3s64, v4s1, v4s8, v4s16, v4s32, v4s64,
Expand All @@ -126,6 +136,8 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {

auto allPtrs = {p0, p1, p2, p3, p4, p5, p6, p7, p8, p10, p11, p12};

auto &allowedVectorTypes = ST.isShader() ? allShaderVectors : allVectors;

bool IsExtendedInts =
ST.canUseExtension(
SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers) ||
Expand All @@ -148,14 +160,64 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
return IsExtendedInts && Ty.isValid();
};

for (auto Opc : getTypeFoldingSupportedOpcodes())
getActionDefinitionsBuilder(Opc).custom();
uint32_t MaxVectorSize = ST.isShader() ? 4 : 16;

getActionDefinitionsBuilder(G_GLOBAL_VALUE).alwaysLegal();
for (auto Opc : getTypeFoldingSupportedOpcodes()) {
if (Opc != G_EXTRACT_VECTOR_ELT)
getActionDefinitionsBuilder(Opc).custom();
}

// TODO: add proper rules for vectors legalization.
getActionDefinitionsBuilder(
{G_BUILD_VECTOR, G_SHUFFLE_VECTOR, G_SPLAT_VECTOR})
getActionDefinitionsBuilder(G_INTRINSIC_W_SIDE_EFFECTS).custom();

getActionDefinitionsBuilder(G_SHUFFLE_VECTOR)
.legalForCartesianProduct(allowedVectorTypes, allowedVectorTypes)
.moreElementsToNextPow2(0)
.lowerIf(vectorElementCountIsGreaterThan(0, MaxVectorSize))
.moreElementsToNextPow2(1)
.lowerIf(vectorElementCountIsGreaterThan(1, MaxVectorSize))
.alwaysLegal();

getActionDefinitionsBuilder(G_EXTRACT_VECTOR_ELT)
.moreElementsToNextPow2(1)
.fewerElementsIf(vectorElementCountIsGreaterThan(1, MaxVectorSize),
LegalizeMutations::changeElementCountTo(
1, ElementCount::getFixed(MaxVectorSize)))
.custom();

// Illegal G_UNMERGE_VALUES instructions should be handled
// during the combine phase.
getActionDefinitionsBuilder(G_BUILD_VECTOR)
.legalIf(vectorElementCountIsLessThanOrEqualTo(0, MaxVectorSize))
.fewerElementsIf(vectorElementCountIsGreaterThan(0, MaxVectorSize),
LegalizeMutations::changeElementCountTo(
0, ElementCount::getFixed(MaxVectorSize)));

// When entering the legalizer, there should be no G_BITCAST instructions.
// They should all be calls to the `spv_bitcast` intrinsic. The call to
// the intrinsic will be converted to a G_BITCAST during legalization if
// the vectors are not legal. After using the rules to legalize a G_BITCAST,
// we turn it back into a call to the intrinsic with a custom rule to avoid
// potential machine verifier failures.
getActionDefinitionsBuilder(G_BITCAST)
.moreElementsToNextPow2(0)
.moreElementsToNextPow2(1)
.fewerElementsIf(vectorElementCountIsGreaterThan(0, MaxVectorSize),
LegalizeMutations::changeElementCountTo(
0, ElementCount::getFixed(MaxVectorSize)))
.lowerIf(vectorElementCountIsGreaterThan(1, MaxVectorSize))
.custom();

getActionDefinitionsBuilder(G_CONCAT_VECTORS)
.legalIf(vectorElementCountIsLessThanOrEqualTo(0, MaxVectorSize))
.moreElementsToNextPow2(0)
.lowerIf(vectorElementCountIsGreaterThan(0, MaxVectorSize))
.alwaysLegal();

getActionDefinitionsBuilder(G_SPLAT_VECTOR)
.legalIf(vectorElementCountIsLessThanOrEqualTo(0, MaxVectorSize))
.moreElementsToNextPow2(0)
.fewerElementsIf(vectorElementCountIsGreaterThan(0, MaxVectorSize),
LegalizeMutations::changeElementSizeTo(0, MaxVectorSize))
.alwaysLegal();

// Vector Reduction Operations
Expand All @@ -164,17 +226,18 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
G_VECREDUCE_ADD, G_VECREDUCE_MUL, G_VECREDUCE_FMUL, G_VECREDUCE_FMIN,
G_VECREDUCE_FMAX, G_VECREDUCE_FMINIMUM, G_VECREDUCE_FMAXIMUM,
G_VECREDUCE_OR, G_VECREDUCE_AND, G_VECREDUCE_XOR})
.legalFor(allVectors)
.legalFor(allowedVectorTypes)
.scalarize(1)
.lower();

getActionDefinitionsBuilder({G_VECREDUCE_SEQ_FADD, G_VECREDUCE_SEQ_FMUL})
.scalarize(2)
.lower();

// Merge/Unmerge
// TODO: add proper legalization rules.
getActionDefinitionsBuilder(G_UNMERGE_VALUES).alwaysLegal();
// Illegal G_UNMERGE_VALUES instructions should be handled
// during the combine phase.
getActionDefinitionsBuilder(G_UNMERGE_VALUES)
.legalIf(vectorElementCountIsLessThanOrEqualTo(1, MaxVectorSize));

getActionDefinitionsBuilder({G_MEMCPY, G_MEMMOVE})
.legalIf(all(typeInSet(0, allPtrs), typeInSet(1, allPtrs)));
Expand Down Expand Up @@ -228,7 +291,14 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
all(typeInSet(0, allPtrsScalarsAndVectors),
typeInSet(1, allPtrsScalarsAndVectors)));

getActionDefinitionsBuilder({G_IMPLICIT_DEF, G_FREEZE}).alwaysLegal();
getActionDefinitionsBuilder({G_IMPLICIT_DEF, G_FREEZE})
.legalFor({s1})
.legalFor(allFloatAndIntScalarsAndPtrs)
.legalFor(allowedVectorTypes)
.moreElementsToNextPow2(0)
.fewerElementsIf(vectorElementCountIsGreaterThan(0, MaxVectorSize),
LegalizeMutations::changeElementCountTo(
0, ElementCount::getFixed(MaxVectorSize)));

getActionDefinitionsBuilder({G_STACKSAVE, G_STACKRESTORE}).alwaysLegal();

Expand Down Expand Up @@ -287,6 +357,8 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
// Pointer-handling.
getActionDefinitionsBuilder(G_FRAME_INDEX).legalFor({p0});

getActionDefinitionsBuilder(G_GLOBAL_VALUE).legalFor(allPtrs);

// Control-flow. In some cases (e.g. constants) s1 may be promoted to s32.
getActionDefinitionsBuilder(G_BRCOND).legalFor({s1, s32});

Expand Down Expand Up @@ -353,6 +425,21 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
verify(*ST.getInstrInfo());
}

static bool legalizeExtractVectorElt(LegalizerHelper &Helper, MachineInstr &MI,
SPIRVGlobalRegistry *GR) {
MachineIRBuilder &MIRBuilder = Helper.MIRBuilder;
Register DstReg = MI.getOperand(0).getReg();
Register SrcReg = MI.getOperand(1).getReg();
Register IdxReg = MI.getOperand(2).getReg();

MIRBuilder
.buildIntrinsic(Intrinsic::spv_extractelt, ArrayRef<Register>{DstReg})
.addUse(SrcReg)
.addUse(IdxReg);
MI.eraseFromParent();
return true;
}

static Register convertPtrToInt(Register Reg, LLT ConvTy, SPIRVType *SpvType,
LegalizerHelper &Helper,
MachineRegisterInfo &MRI,
Expand All @@ -374,6 +461,13 @@ bool SPIRVLegalizerInfo::legalizeCustom(
default:
// TODO: implement legalization for other opcodes.
return true;
case TargetOpcode::G_BITCAST:
return legalizeBitcast(Helper, MI);
case TargetOpcode::G_EXTRACT_VECTOR_ELT:
return legalizeExtractVectorElt(Helper, MI, GR);
case TargetOpcode::G_INTRINSIC:
case TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS:
return legalizeIntrinsic(Helper, MI);
case TargetOpcode::G_IS_FPCLASS:
return legalizeIsFPClass(Helper, MI, LocObserver);
case TargetOpcode::G_ICMP: {
Expand All @@ -400,6 +494,76 @@ bool SPIRVLegalizerInfo::legalizeCustom(
}
}

bool SPIRVLegalizerInfo::legalizeIntrinsic(LegalizerHelper &Helper,
MachineInstr &MI) const {
LLVM_DEBUG(dbgs() << "legalizeIntrinsic: " << MI);

MachineIRBuilder &MIRBuilder = Helper.MIRBuilder;
MachineRegisterInfo &MRI = *MIRBuilder.getMRI();
const SPIRVSubtarget &ST = MI.getMF()->getSubtarget<SPIRVSubtarget>();

auto IntrinsicID = cast<GIntrinsic>(MI).getIntrinsicID();
if (IntrinsicID == Intrinsic::spv_bitcast) {
LLVM_DEBUG(dbgs() << "Found a bitcast instruction\n");
Register DstReg = MI.getOperand(0).getReg();
Register SrcReg = MI.getOperand(2).getReg();
LLT DstTy = MRI.getType(DstReg);
LLT SrcTy = MRI.getType(SrcReg);

int32_t MaxVectorSize = ST.isShader() ? 4 : 16;

bool DstNeedsLegalization = false;
bool SrcNeedsLegalization = false;

if (DstTy.isVector()) {
if (DstTy.getNumElements() > 4 &&
!isPowerOf2_32(DstTy.getNumElements())) {
DstNeedsLegalization = true;
}

if (DstTy.getNumElements() > MaxVectorSize) {
DstNeedsLegalization = true;
}
}

if (SrcTy.isVector()) {
if (SrcTy.getNumElements() > 4 &&
!isPowerOf2_32(SrcTy.getNumElements())) {
SrcNeedsLegalization = true;
}

if (SrcTy.getNumElements() > MaxVectorSize) {
SrcNeedsLegalization = true;
}
}

// If an spv_bitcast needs to be legalized, we convert it to G_BITCAST to
// allow using the generic legalization rules.
if (DstNeedsLegalization || SrcNeedsLegalization) {
LLVM_DEBUG(dbgs() << "Replacing with a G_BITCAST\n");
MIRBuilder.buildBitcast(DstReg, SrcReg);
MI.eraseFromParent();
}
return true;
}
return true;
}

bool SPIRVLegalizerInfo::legalizeBitcast(LegalizerHelper &Helper,
MachineInstr &MI) const {
// Once the G_BITCAST is using vectors that are allowed, we turn it back into
// an spv_bitcast to avoid verifier problems when the register types are the
// same for the source and the result. Note that the SPIR-V types associated
// with the bitcast can be different even if the register types are the same.
MachineIRBuilder &MIRBuilder = Helper.MIRBuilder;
Register DstReg = MI.getOperand(0).getReg();
Register SrcReg = MI.getOperand(1).getReg();
SmallVector<Register, 1> DstRegs = {DstReg};
MIRBuilder.buildIntrinsic(Intrinsic::spv_bitcast, DstRegs).addUse(SrcReg);
MI.eraseFromParent();
return true;
}

// Note this code was copied from LegalizerHelper::lowerISFPCLASS and adjusted
// to ensure that all instructions created during the lowering have SPIR-V types
// assigned to them.
Expand Down
Loading
Loading