Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,9 @@ class SPIRVLegalizePointerCast : public FunctionPass {
assert(VecTy->getElementType() == ArrTy->getElementType() &&
"Element types of array and vector must be the same.");

const DataLayout &DL = B.GetInsertBlock()->getModule()->getDataLayout();
uint64_t ElemSize = DL.getTypeAllocSize(ArrTy->getElementType());

for (unsigned i = 0; i < VecTy->getNumElements(); ++i) {
// Create a GEP to access the i-th element of the array.
SmallVector<Type *, 2> Types = {DstArrayPtr->getType(),
Expand All @@ -190,7 +193,8 @@ class SPIRVLegalizePointerCast : public FunctionPass {
buildAssignType(B, VecTy->getElementType(), Element);

Types = {Element->getType(), ElementPtr->getType()};
Args = {Element, ElementPtr, B.getInt16(2), B.getInt8(Alignment.value())};
Align NewAlign = commonAlignment(Alignment, i * ElemSize);
Args = {Element, ElementPtr, B.getInt16(2), B.getInt8(NewAlign.value())};
B.CreateIntrinsic(Intrinsic::spv_store, {Types}, {Args});
}
}
Expand Down
117 changes: 83 additions & 34 deletions llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,8 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
v3s1, v3s8, v3s16, v3s32, v3s64,
v4s1, v4s8, v4s16, v4s32, v4s64};

auto allScalars = {s1, s8, s16, s32, s64};

auto allScalarsAndVectors = {
s1, s8, s16, s32, s64, v2s1, v2s8, v2s16, v2s32, v2s64,
v3s1, v3s8, v3s16, v3s32, v3s64, v4s1, v4s8, v4s16, v4s32, v4s64,
Expand Down Expand Up @@ -172,9 +174,25 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {

for (auto Opc : getTypeFoldingSupportedOpcodes()) {
if (Opc != G_EXTRACT_VECTOR_ELT)
getActionDefinitionsBuilder(Opc).custom();
getActionDefinitionsBuilder(Opc)
.customFor(allScalars)
.customFor(allowedVectorTypes)
.moreElementsToNextPow2(0)
.fewerElementsIf(vectorElementCountIsGreaterThan(0, MaxVectorSize),
LegalizeMutations::changeElementCountTo(
0, ElementCount::getFixed(MaxVectorSize)))
.custom();
}

getActionDefinitionsBuilder(TargetOpcode::G_FMA)
.legalFor(allScalars)
.legalFor(allowedVectorTypes)
.moreElementsToNextPow2(0)
.fewerElementsIf(vectorElementCountIsGreaterThan(0, MaxVectorSize),
LegalizeMutations::changeElementCountTo(
0, ElementCount::getFixed(MaxVectorSize)))
.alwaysLegal();

getActionDefinitionsBuilder(G_INTRINSIC_W_SIDE_EFFECTS).custom();

getActionDefinitionsBuilder(G_SHUFFLE_VECTOR)
Expand All @@ -192,6 +210,13 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
1, ElementCount::getFixed(MaxVectorSize)))
.custom();

getActionDefinitionsBuilder(G_INSERT_VECTOR_ELT)
.moreElementsToNextPow2(0)
.fewerElementsIf(vectorElementCountIsGreaterThan(0, MaxVectorSize),
LegalizeMutations::changeElementCountTo(
0, ElementCount::getFixed(MaxVectorSize)))
.custom();

// Illegal G_UNMERGE_VALUES instructions should be handled
// during the combine phase.
getActionDefinitionsBuilder(G_BUILD_VECTOR)
Expand All @@ -215,14 +240,13 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
.lowerIf(vectorElementCountIsGreaterThan(1, MaxVectorSize))
.custom();

// If the result is still illegal, the combiner should be able to remove it.
getActionDefinitionsBuilder(G_CONCAT_VECTORS)
.legalIf(vectorElementCountIsLessThanOrEqualTo(0, MaxVectorSize))
.moreElementsToNextPow2(0)
.lowerIf(vectorElementCountIsGreaterThan(0, MaxVectorSize))
.alwaysLegal();
.legalForCartesianProduct(allowedVectorTypes, allowedVectorTypes)
.moreElementsToNextPow2(0);

getActionDefinitionsBuilder(G_SPLAT_VECTOR)
.legalIf(vectorElementCountIsLessThanOrEqualTo(0, MaxVectorSize))
.legalFor(allowedVectorTypes)
.moreElementsToNextPow2(0)
.fewerElementsIf(vectorElementCountIsGreaterThan(0, MaxVectorSize),
LegalizeMutations::changeElementSizeTo(0, MaxVectorSize))
Expand Down Expand Up @@ -458,6 +482,23 @@ static bool legalizeExtractVectorElt(LegalizerHelper &Helper, MachineInstr &MI,
return true;
}

/// Custom legalization for G_INSERT_VECTOR_ELT.
///
/// Replaces \p MI with a call to the spv_insertelt intrinsic that defines the
/// same destination register and takes the source vector, the value being
/// inserted, and the index as its uses, in that order. The original
/// instruction is erased afterwards.
///
/// \param Helper supplies the MachineIRBuilder used to emit the intrinsic.
/// \param MI     the instruction being legalized; operands 0..3 are read as
///               destination, source vector, inserted value, and index.
/// \param GR     not referenced by this function.
/// \returns true unconditionally.
static bool legalizeInsertVectorElt(LegalizerHelper &Helper, MachineInstr &MI,
                                    SPIRVGlobalRegistry *GR) {
  // Capture every operand up front: MI is erased at the end of the function.
  const Register Result = MI.getOperand(0).getReg();
  const Register Vector = MI.getOperand(1).getReg();
  const Register Scalar = MI.getOperand(2).getReg();
  const Register Index = MI.getOperand(3).getReg();

  auto Call = Helper.MIRBuilder.buildIntrinsic(Intrinsic::spv_insertelt,
                                               ArrayRef<Register>{Result});
  Call.addUse(Vector);
  Call.addUse(Scalar);
  Call.addUse(Index);

  MI.eraseFromParent();
  return true;
}

static Register convertPtrToInt(Register Reg, LLT ConvTy, SPIRVType *SpvType,
LegalizerHelper &Helper,
MachineRegisterInfo &MRI,
Expand All @@ -483,6 +524,8 @@ bool SPIRVLegalizerInfo::legalizeCustom(
return legalizeBitcast(Helper, MI);
case TargetOpcode::G_EXTRACT_VECTOR_ELT:
return legalizeExtractVectorElt(Helper, MI, GR);
case TargetOpcode::G_INSERT_VECTOR_ELT:
return legalizeInsertVectorElt(Helper, MI, GR);
case TargetOpcode::G_INTRINSIC:
case TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS:
return legalizeIntrinsic(Helper, MI);
Expand Down Expand Up @@ -512,6 +555,15 @@ bool SPIRVLegalizerInfo::legalizeCustom(
}
}

/// Returns true when \p Ty is a vector whose element count must be changed
/// before it can be used.
///
/// A vector qualifies when either:
///   - it has more elements than the cap chosen from \p ST (4 when
///     ST.isShader() is true, 16 otherwise), or
///   - it has more than 4 elements and that count is not a power of two.
/// Non-vector types never qualify.
static bool needsVectorLegalization(const LLT &Ty, const SPIRVSubtarget &ST) {
  if (Ty.isVector()) {
    const unsigned Count = Ty.getNumElements();
    const unsigned Limit = ST.isShader() ? 4 : 16;

    // Over the subtarget's cap: always needs legalization.
    if (Count > Limit)
      return true;

    // Within the cap: only wide, non-power-of-two counts need it.
    return Count > 4 && !isPowerOf2_32(Count);
  }
  return false;
}

bool SPIRVLegalizerInfo::legalizeIntrinsic(LegalizerHelper &Helper,
MachineInstr &MI) const {
LLVM_DEBUG(dbgs() << "legalizeIntrinsic: " << MI);
Expand All @@ -528,41 +580,38 @@ bool SPIRVLegalizerInfo::legalizeIntrinsic(LegalizerHelper &Helper,
LLT DstTy = MRI.getType(DstReg);
LLT SrcTy = MRI.getType(SrcReg);

int32_t MaxVectorSize = ST.isShader() ? 4 : 16;

bool DstNeedsLegalization = false;
bool SrcNeedsLegalization = false;

if (DstTy.isVector()) {
if (DstTy.getNumElements() > 4 &&
!isPowerOf2_32(DstTy.getNumElements())) {
DstNeedsLegalization = true;
}

if (DstTy.getNumElements() > MaxVectorSize) {
DstNeedsLegalization = true;
}
}

if (SrcTy.isVector()) {
if (SrcTy.getNumElements() > 4 &&
!isPowerOf2_32(SrcTy.getNumElements())) {
SrcNeedsLegalization = true;
}

if (SrcTy.getNumElements() > MaxVectorSize) {
SrcNeedsLegalization = true;
}
}

// If an spv_bitcast needs to be legalized, we convert it to G_BITCAST to
// allow using the generic legalization rules.
if (DstNeedsLegalization || SrcNeedsLegalization) {
if (needsVectorLegalization(DstTy, ST) ||
needsVectorLegalization(SrcTy, ST)) {
LLVM_DEBUG(dbgs() << "Replacing with a G_BITCAST\n");
MIRBuilder.buildBitcast(DstReg, SrcReg);
MI.eraseFromParent();
}
return true;
} else if (IntrinsicID == Intrinsic::spv_insertelt) {
Register DstReg = MI.getOperand(0).getReg();
LLT DstTy = MRI.getType(DstReg);

if (needsVectorLegalization(DstTy, ST)) {
Register SrcReg = MI.getOperand(2).getReg();
Register ValReg = MI.getOperand(3).getReg();
Register IdxReg = MI.getOperand(4).getReg();
MIRBuilder.buildInsertVectorElement(DstReg, SrcReg, ValReg, IdxReg);
MI.eraseFromParent();
}
return true;
} else if (IntrinsicID == Intrinsic::spv_extractelt) {
Register SrcReg = MI.getOperand(2).getReg();
LLT SrcTy = MRI.getType(SrcReg);

if (needsVectorLegalization(SrcTy, ST)) {
Register DstReg = MI.getOperand(0).getReg();
Register IdxReg = MI.getOperand(3).getReg();
MIRBuilder.buildExtractVectorElement(DstReg, SrcReg, IdxReg);
MI.eraseFromParent();
}
return true;
}
return true;
}
Expand Down
78 changes: 56 additions & 22 deletions llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "SPIRV.h"
#include "SPIRVSubtarget.h"
#include "SPIRVUtils.h"
#include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
#include "llvm/IR/IntrinsicsSPIRV.h"
#include "llvm/Support/Debug.h"
#include <stack>
Expand Down Expand Up @@ -66,8 +67,9 @@ static bool deduceAndAssignTypeForGUnmerge(MachineInstr *I, MachineFunction &MF,
for (unsigned i = 0; i < I->getNumDefs() && !ScalarType; ++i) {
for (const auto &Use :
MRI.use_nodbg_instructions(I->getOperand(i).getReg())) {
assert(Use.getOpcode() == TargetOpcode::G_BUILD_VECTOR &&
"Expected use of G_UNMERGE_VALUES to be a G_BUILD_VECTOR");
if (Use.getOpcode() != TargetOpcode::G_BUILD_VECTOR)
continue;

if (auto *VecType =
GR->getSPIRVTypeForVReg(Use.getOperand(0).getReg())) {
ScalarType = GR->getScalarOrVectorComponentType(VecType);
Expand Down Expand Up @@ -133,10 +135,10 @@ static SPIRVType *deduceTypeFromOperandRange(MachineInstr *I,
return ResType;
}

static SPIRVType *deduceTypeForResultRegister(MachineInstr *Use,
Register UseRegister,
SPIRVGlobalRegistry *GR,
MachineIRBuilder &MIB) {
static SPIRVType *deduceTypeFromResultRegister(MachineInstr *Use,
Register UseRegister,
SPIRVGlobalRegistry *GR,
MachineIRBuilder &MIB) {
for (const MachineOperand &MO : Use->defs()) {
if (!MO.isReg())
continue;
Expand All @@ -159,16 +161,43 @@ static SPIRVType *deduceTypeFromUses(Register Reg, MachineFunction &MF,
MachineRegisterInfo &MRI = MF.getRegInfo();
for (MachineInstr &Use : MRI.use_nodbg_instructions(Reg)) {
SPIRVType *ResType = nullptr;
LLVM_DEBUG(dbgs() << "Looking at use " << Use);
switch (Use.getOpcode()) {
case TargetOpcode::G_BUILD_VECTOR:
case TargetOpcode::G_EXTRACT_VECTOR_ELT:
case TargetOpcode::G_UNMERGE_VALUES:
LLVM_DEBUG(dbgs() << "Looking at use " << Use << "\n");
ResType = deduceTypeForResultRegister(&Use, Reg, GR, MIB);
case TargetOpcode::G_ADD:
case TargetOpcode::G_SUB:
case TargetOpcode::G_MUL:
case TargetOpcode::G_SDIV:
case TargetOpcode::G_UDIV:
case TargetOpcode::G_SREM:
case TargetOpcode::G_UREM:
case TargetOpcode::G_FADD:
case TargetOpcode::G_FSUB:
case TargetOpcode::G_FMUL:
case TargetOpcode::G_FDIV:
case TargetOpcode::G_FREM:
case TargetOpcode::G_FMA:
ResType = deduceTypeFromResultRegister(&Use, Reg, GR, MIB);
break;
case TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS:
case TargetOpcode::G_INTRINSIC: {
auto IntrinsicID = cast<GIntrinsic>(Use).getIntrinsicID();
if (IntrinsicID == Intrinsic::spv_insertelt) {
if (Reg == Use.getOperand(2).getReg())
ResType = deduceTypeFromResultRegister(&Use, Reg, GR, MIB);
} else if (IntrinsicID == Intrinsic::spv_extractelt) {
if (Reg == Use.getOperand(2).getReg())
ResType = deduceTypeFromResultRegister(&Use, Reg, GR, MIB);
}
break;
}
}
if (ResType)
if (ResType) {
LLVM_DEBUG(dbgs() << "Deduced type from use " << *ResType);
return ResType;
}
}
return nullptr;
}
Expand Down Expand Up @@ -296,20 +325,25 @@ static void registerSpirvTypeForNewInstructions(MachineFunction &MF,

for (auto *I : Worklist) {
MachineIRBuilder MIB(*I);
Register ResVReg = I->getOperand(0).getReg();
const LLT &ResLLT = MRI.getType(ResVReg);
SPIRVType *ResType = nullptr;
if (ResLLT.isVector()) {
SPIRVType *CompType = GR->getOrCreateSPIRVIntegerType(
ResLLT.getElementType().getSizeInBits(), MIB);
ResType = GR->getOrCreateSPIRVVectorType(
CompType, ResLLT.getNumElements(), MIB, false);
} else {
ResType = GR->getOrCreateSPIRVIntegerType(ResLLT.getSizeInBits(), MIB);
for (unsigned Idx = 0; Idx < I->getNumDefs(); ++Idx) {
Register ResVReg = I->getOperand(Idx).getReg();
if (GR->getSPIRVTypeForVReg(ResVReg))
continue;
const LLT &ResLLT = MRI.getType(ResVReg);
SPIRVType *ResType = nullptr;
if (ResLLT.isVector()) {
SPIRVType *CompType = GR->getOrCreateSPIRVIntegerType(
ResLLT.getElementType().getSizeInBits(), MIB);
ResType = GR->getOrCreateSPIRVVectorType(
CompType, ResLLT.getNumElements(), MIB, false);
} else {
ResType = GR->getOrCreateSPIRVIntegerType(ResLLT.getSizeInBits(), MIB);
}
LLVM_DEBUG(dbgs() << "Could not determine type for " << ResVReg
<< ", defaulting to " << *ResType << "\n");

setRegClassType(ResVReg, ResType, GR, &MRI, MF, true);
}
LLVM_DEBUG(dbgs() << "Could not determine type for " << *I
<< ", defaulting to " << *ResType << "\n");
setRegClassType(ResVReg, ResType, GR, &MRI, MF, true);
}
}

Expand Down
Loading