Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 83 additions & 34 deletions llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,8 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
v3s1, v3s8, v3s16, v3s32, v3s64,
v4s1, v4s8, v4s16, v4s32, v4s64};

auto allScalars = {s1, s8, s16, s32};

auto allScalarsAndVectors = {
s1, s8, s16, s32, s64, v2s1, v2s8, v2s16, v2s32, v2s64,
v3s1, v3s8, v3s16, v3s32, v3s64, v4s1, v4s8, v4s16, v4s32, v4s64,
Expand Down Expand Up @@ -172,9 +174,25 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {

for (auto Opc : getTypeFoldingSupportedOpcodes()) {
if (Opc != G_EXTRACT_VECTOR_ELT)
getActionDefinitionsBuilder(Opc).custom();
getActionDefinitionsBuilder(Opc)
.customFor(allScalars)
.customFor(allowedVectorTypes)
.moreElementsToNextPow2(0)
.fewerElementsIf(vectorElementCountIsGreaterThan(0, MaxVectorSize),
LegalizeMutations::changeElementCountTo(
0, ElementCount::getFixed(MaxVectorSize)))
.custom();
}

getActionDefinitionsBuilder(TargetOpcode::G_FMA)
.legalFor(allScalars)
.legalFor(allowedVectorTypes)
.moreElementsToNextPow2(0)
.fewerElementsIf(vectorElementCountIsGreaterThan(0, MaxVectorSize),
LegalizeMutations::changeElementCountTo(
0, ElementCount::getFixed(MaxVectorSize)))
.alwaysLegal();

getActionDefinitionsBuilder(G_INTRINSIC_W_SIDE_EFFECTS).custom();

getActionDefinitionsBuilder(G_SHUFFLE_VECTOR)
Expand All @@ -192,6 +210,13 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
1, ElementCount::getFixed(MaxVectorSize)))
.custom();

getActionDefinitionsBuilder(G_INSERT_VECTOR_ELT)
.moreElementsToNextPow2(0)
.fewerElementsIf(vectorElementCountIsGreaterThan(0, MaxVectorSize),
LegalizeMutations::changeElementCountTo(
0, ElementCount::getFixed(MaxVectorSize)))
.custom();

// Illegal G_UNMERGE_VALUES instructions should be handled
// during the combine phase.
getActionDefinitionsBuilder(G_BUILD_VECTOR)
Expand All @@ -215,14 +240,13 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
.lowerIf(vectorElementCountIsGreaterThan(1, MaxVectorSize))
.custom();

// If the result is still illegal, the combiner should be able to remove it.
getActionDefinitionsBuilder(G_CONCAT_VECTORS)
.legalIf(vectorElementCountIsLessThanOrEqualTo(0, MaxVectorSize))
.moreElementsToNextPow2(0)
.lowerIf(vectorElementCountIsGreaterThan(0, MaxVectorSize))
.alwaysLegal();
.legalForCartesianProduct(allowedVectorTypes, allowedVectorTypes)
.moreElementsToNextPow2(0);

getActionDefinitionsBuilder(G_SPLAT_VECTOR)
.legalIf(vectorElementCountIsLessThanOrEqualTo(0, MaxVectorSize))
.legalFor(allowedVectorTypes)
.moreElementsToNextPow2(0)
.fewerElementsIf(vectorElementCountIsGreaterThan(0, MaxVectorSize),
LegalizeMutations::changeElementSizeTo(0, MaxVectorSize))
Expand Down Expand Up @@ -458,6 +482,23 @@ static bool legalizeExtractVectorElt(LegalizerHelper &Helper, MachineInstr &MI,
return true;
}

static bool legalizeInsertVectorElt(LegalizerHelper &Helper, MachineInstr &MI,
SPIRVGlobalRegistry *GR) {
MachineIRBuilder &MIRBuilder = Helper.MIRBuilder;
Register DstReg = MI.getOperand(0).getReg();
Register SrcReg = MI.getOperand(1).getReg();
Register ValReg = MI.getOperand(2).getReg();
Register IdxReg = MI.getOperand(3).getReg();

MIRBuilder
.buildIntrinsic(Intrinsic::spv_insertelt, ArrayRef<Register>{DstReg})
.addUse(SrcReg)
.addUse(ValReg)
.addUse(IdxReg);
MI.eraseFromParent();
return true;
}

static Register convertPtrToInt(Register Reg, LLT ConvTy, SPIRVType *SpvType,
LegalizerHelper &Helper,
MachineRegisterInfo &MRI,
Expand All @@ -483,6 +524,8 @@ bool SPIRVLegalizerInfo::legalizeCustom(
return legalizeBitcast(Helper, MI);
case TargetOpcode::G_EXTRACT_VECTOR_ELT:
return legalizeExtractVectorElt(Helper, MI, GR);
case TargetOpcode::G_INSERT_VECTOR_ELT:
return legalizeInsertVectorElt(Helper, MI, GR);
case TargetOpcode::G_INTRINSIC:
case TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS:
return legalizeIntrinsic(Helper, MI);
Expand Down Expand Up @@ -512,6 +555,15 @@ bool SPIRVLegalizerInfo::legalizeCustom(
}
}

static bool needsVectorLegalization(const LLT &Ty, const SPIRVSubtarget &ST) {
if (!Ty.isVector())
return false;
unsigned NumElements = Ty.getNumElements();
unsigned MaxVectorSize = ST.isShader() ? 4 : 16;
return (NumElements > 4 && !isPowerOf2_32(NumElements)) ||
NumElements > MaxVectorSize;
}

bool SPIRVLegalizerInfo::legalizeIntrinsic(LegalizerHelper &Helper,
MachineInstr &MI) const {
LLVM_DEBUG(dbgs() << "legalizeIntrinsic: " << MI);
Expand All @@ -528,41 +580,38 @@ bool SPIRVLegalizerInfo::legalizeIntrinsic(LegalizerHelper &Helper,
LLT DstTy = MRI.getType(DstReg);
LLT SrcTy = MRI.getType(SrcReg);

int32_t MaxVectorSize = ST.isShader() ? 4 : 16;

bool DstNeedsLegalization = false;
bool SrcNeedsLegalization = false;

if (DstTy.isVector()) {
if (DstTy.getNumElements() > 4 &&
!isPowerOf2_32(DstTy.getNumElements())) {
DstNeedsLegalization = true;
}

if (DstTy.getNumElements() > MaxVectorSize) {
DstNeedsLegalization = true;
}
}

if (SrcTy.isVector()) {
if (SrcTy.getNumElements() > 4 &&
!isPowerOf2_32(SrcTy.getNumElements())) {
SrcNeedsLegalization = true;
}

if (SrcTy.getNumElements() > MaxVectorSize) {
SrcNeedsLegalization = true;
}
}

// If an spv_bitcast needs to be legalized, we convert it to G_BITCAST to
// allow using the generic legalization rules.
if (DstNeedsLegalization || SrcNeedsLegalization) {
if (needsVectorLegalization(DstTy, ST) ||
needsVectorLegalization(SrcTy, ST)) {
LLVM_DEBUG(dbgs() << "Replacing with a G_BITCAST\n");
MIRBuilder.buildBitcast(DstReg, SrcReg);
MI.eraseFromParent();
}
return true;
} else if (IntrinsicID == Intrinsic::spv_insertelt) {
Register DstReg = MI.getOperand(0).getReg();
LLT DstTy = MRI.getType(DstReg);

if (needsVectorLegalization(DstTy, ST)) {
Register SrcReg = MI.getOperand(2).getReg();
Register ValReg = MI.getOperand(3).getReg();
Register IdxReg = MI.getOperand(4).getReg();
MIRBuilder.buildInsertVectorElement(DstReg, SrcReg, ValReg, IdxReg);
MI.eraseFromParent();
}
return true;
} else if (IntrinsicID == Intrinsic::spv_extractelt) {
Register SrcReg = MI.getOperand(2).getReg();
LLT SrcTy = MRI.getType(SrcReg);

if (needsVectorLegalization(SrcTy, ST)) {
Register DstReg = MI.getOperand(0).getReg();
Register IdxReg = MI.getOperand(3).getReg();
MIRBuilder.buildExtractVectorElement(DstReg, SrcReg, IdxReg);
MI.eraseFromParent();
}
return true;
}
return true;
}
Expand Down
47 changes: 38 additions & 9 deletions llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "SPIRV.h"
#include "SPIRVSubtarget.h"
#include "SPIRVUtils.h"
#include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
#include "llvm/IR/IntrinsicsSPIRV.h"
#include "llvm/Support/Debug.h"
#include <stack>
Expand Down Expand Up @@ -66,8 +67,9 @@ static bool deduceAndAssignTypeForGUnmerge(MachineInstr *I, MachineFunction &MF,
for (unsigned i = 0; i < I->getNumDefs() && !ScalarType; ++i) {
for (const auto &Use :
MRI.use_nodbg_instructions(I->getOperand(i).getReg())) {
assert(Use.getOpcode() == TargetOpcode::G_BUILD_VECTOR &&
"Expected use of G_UNMERGE_VALUES to be a G_BUILD_VECTOR");
if (Use.getOpcode() != TargetOpcode::G_BUILD_VECTOR)
continue;

if (auto *VecType =
GR->getSPIRVTypeForVReg(Use.getOperand(0).getReg())) {
ScalarType = GR->getScalarOrVectorComponentType(VecType);
Expand Down Expand Up @@ -133,10 +135,10 @@ static SPIRVType *deduceTypeFromOperandRange(MachineInstr *I,
return ResType;
}

static SPIRVType *deduceTypeForResultRegister(MachineInstr *Use,
Register UseRegister,
SPIRVGlobalRegistry *GR,
MachineIRBuilder &MIB) {
static SPIRVType *deduceTypeFromResultRegister(MachineInstr *Use,
Register UseRegister,
SPIRVGlobalRegistry *GR,
MachineIRBuilder &MIB) {
for (const MachineOperand &MO : Use->defs()) {
if (!MO.isReg())
continue;
Expand All @@ -159,16 +161,43 @@ static SPIRVType *deduceTypeFromUses(Register Reg, MachineFunction &MF,
MachineRegisterInfo &MRI = MF.getRegInfo();
for (MachineInstr &Use : MRI.use_nodbg_instructions(Reg)) {
SPIRVType *ResType = nullptr;
LLVM_DEBUG(dbgs() << "Looking at use " << Use);
switch (Use.getOpcode()) {
case TargetOpcode::G_BUILD_VECTOR:
case TargetOpcode::G_EXTRACT_VECTOR_ELT:
case TargetOpcode::G_UNMERGE_VALUES:
LLVM_DEBUG(dbgs() << "Looking at use " << Use << "\n");
ResType = deduceTypeForResultRegister(&Use, Reg, GR, MIB);
case TargetOpcode::G_ADD:
case TargetOpcode::G_SUB:
case TargetOpcode::G_MUL:
case TargetOpcode::G_SDIV:
case TargetOpcode::G_UDIV:
case TargetOpcode::G_SREM:
case TargetOpcode::G_UREM:
case TargetOpcode::G_FADD:
case TargetOpcode::G_FSUB:
case TargetOpcode::G_FMUL:
case TargetOpcode::G_FDIV:
case TargetOpcode::G_FREM:
case TargetOpcode::G_FMA:
ResType = deduceTypeFromResultRegister(&Use, Reg, GR, MIB);
break;
case TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS:
case TargetOpcode::G_INTRINSIC: {
auto IntrinsicID = cast<GIntrinsic>(Use).getIntrinsicID();
if (IntrinsicID == Intrinsic::spv_insertelt) {
if (Reg == Use.getOperand(2).getReg())
ResType = deduceTypeFromResultRegister(&Use, Reg, GR, MIB);
} else if (IntrinsicID == Intrinsic::spv_extractelt) {
if (Reg == Use.getOperand(2).getReg())
ResType = deduceTypeFromResultRegister(&Use, Reg, GR, MIB);
}
break;
}
if (ResType)
}
if (ResType) {
LLVM_DEBUG(dbgs() << "Deduced type from use " << *ResType);
return ResType;
}
}
return nullptr;
}
Expand Down
Loading
Loading