Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,9 @@ class SPIRVLegalizePointerCast : public FunctionPass {
assert(VecTy->getElementType() == ArrTy->getElementType() &&
"Element types of array and vector must be the same.");

const DataLayout &DL = B.GetInsertBlock()->getModule()->getDataLayout();
uint64_t ElemSize = DL.getTypeAllocSize(ArrTy->getElementType());

for (unsigned i = 0; i < VecTy->getNumElements(); ++i) {
// Create a GEP to access the i-th element of the array.
SmallVector<Type *, 2> Types = {DstArrayPtr->getType(),
Expand All @@ -190,7 +193,8 @@ class SPIRVLegalizePointerCast : public FunctionPass {
buildAssignType(B, VecTy->getElementType(), Element);

Types = {Element->getType(), ElementPtr->getType()};
Args = {Element, ElementPtr, B.getInt16(2), B.getInt8(Alignment.value())};
Align NewAlign = commonAlignment(Alignment, i * ElemSize);
Args = {Element, ElementPtr, B.getInt16(2), B.getInt8(NewAlign.value())};
B.CreateIntrinsic(Intrinsic::spv_store, {Types}, {Args});
}
}
Expand Down
141 changes: 103 additions & 38 deletions llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,8 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
v3s1, v3s8, v3s16, v3s32, v3s64,
v4s1, v4s8, v4s16, v4s32, v4s64};

auto allScalars = {s1, s8, s16, s32, s64};

auto allScalarsAndVectors = {
s1, s8, s16, s32, s64, s128, v2s1, v2s8,
v2s16, v2s32, v2s64, v3s1, v3s8, v3s16, v3s32, v3s64,
Expand Down Expand Up @@ -173,10 +175,45 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
uint32_t MaxVectorSize = ST.isShader() ? 4 : 16;

for (auto Opc : getTypeFoldingSupportedOpcodes()) {
if (Opc != G_EXTRACT_VECTOR_ELT)
getActionDefinitionsBuilder(Opc).custom();
switch (Opc) {
case G_EXTRACT_VECTOR_ELT:
case G_UREM:
case G_SREM:
case G_UDIV:
case G_SDIV:
case G_FREM:
break;
default:
getActionDefinitionsBuilder(Opc)
.customFor(allScalars)
.customFor(allowedVectorTypes)
.moreElementsToNextPow2(0)
.fewerElementsIf(vectorElementCountIsGreaterThan(0, MaxVectorSize),
LegalizeMutations::changeElementCountTo(
0, ElementCount::getFixed(MaxVectorSize)))
.custom();
break;
}
}

getActionDefinitionsBuilder({G_UREM, G_SREM, G_SDIV, G_UDIV, G_FREM})
.customFor(allScalars)
.customFor(allowedVectorTypes)
.scalarizeIf(numElementsNotPow2(0), 0)
.fewerElementsIf(vectorElementCountIsGreaterThan(0, MaxVectorSize),
LegalizeMutations::changeElementCountTo(
0, ElementCount::getFixed(MaxVectorSize)))
.custom();

getActionDefinitionsBuilder({G_FMA, G_STRICT_FMA})
.legalFor(allScalars)
.legalFor(allowedVectorTypes)
.moreElementsToNextPow2(0)
.fewerElementsIf(vectorElementCountIsGreaterThan(0, MaxVectorSize),
LegalizeMutations::changeElementCountTo(
0, ElementCount::getFixed(MaxVectorSize)))
.alwaysLegal();

getActionDefinitionsBuilder(G_INTRINSIC_W_SIDE_EFFECTS).custom();

getActionDefinitionsBuilder(G_SHUFFLE_VECTOR)
Expand All @@ -194,6 +231,13 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
1, ElementCount::getFixed(MaxVectorSize)))
.custom();

getActionDefinitionsBuilder(G_INSERT_VECTOR_ELT)
.moreElementsToNextPow2(0)
.fewerElementsIf(vectorElementCountIsGreaterThan(0, MaxVectorSize),
LegalizeMutations::changeElementCountTo(
0, ElementCount::getFixed(MaxVectorSize)))
.custom();

// Illegal G_UNMERGE_VALUES instructions should be handled
// during the combine phase.
getActionDefinitionsBuilder(G_BUILD_VECTOR)
Expand All @@ -217,14 +261,13 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
.lowerIf(vectorElementCountIsGreaterThan(1, MaxVectorSize))
.custom();

// If the result is still illegal, the combiner should be able to remove it.
getActionDefinitionsBuilder(G_CONCAT_VECTORS)
.legalIf(vectorElementCountIsLessThanOrEqualTo(0, MaxVectorSize))
.moreElementsToNextPow2(0)
.lowerIf(vectorElementCountIsGreaterThan(0, MaxVectorSize))
.alwaysLegal();
.legalForCartesianProduct(allowedVectorTypes, allowedVectorTypes)
.moreElementsToNextPow2(0);

getActionDefinitionsBuilder(G_SPLAT_VECTOR)
.legalIf(vectorElementCountIsLessThanOrEqualTo(0, MaxVectorSize))
.legalFor(allowedVectorTypes)
.moreElementsToNextPow2(0)
.fewerElementsIf(vectorElementCountIsGreaterThan(0, MaxVectorSize),
LegalizeMutations::changeElementSizeTo(0, MaxVectorSize))
Expand Down Expand Up @@ -273,9 +316,6 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
.legalFor(allIntScalarsAndVectors)
.legalIf(extendedScalarsAndVectors);

getActionDefinitionsBuilder({G_FMA, G_STRICT_FMA})
.legalFor(allFloatScalarsAndVectors);

getActionDefinitionsBuilder(G_STRICT_FLDEXP)
.legalForCartesianProduct(allFloatScalarsAndVectors, allIntScalars);

Expand Down Expand Up @@ -461,6 +501,23 @@ static bool legalizeExtractVectorElt(LegalizerHelper &Helper, MachineInstr &MI,
return true;
}

// Custom legalization for an insert-vector-element instruction: emits an
// spv_insertelt intrinsic that defines the same result register and erases
// the original instruction.
//
// Operand 0 of MI is the result register; operands 1, 2 and 3 (the source
// vector, the inserted value and the index) are forwarded, in that order, as
// uses of the new intrinsic. GR is not used by this function.
//
// Always returns true.
static bool legalizeInsertVectorElt(LegalizerHelper &Helper, MachineInstr &MI,
                                    SPIRVGlobalRegistry *GR) {
  Register DstReg = MI.getOperand(0).getReg();

  auto InsertElt = Helper.MIRBuilder.buildIntrinsic(
      Intrinsic::spv_insertelt, ArrayRef<Register>{DstReg});

  // Forward source vector, value and index unchanged.
  for (unsigned OpIdx = 1; OpIdx <= 3; ++OpIdx)
    InsertElt.addUse(MI.getOperand(OpIdx).getReg());

  MI.eraseFromParent();
  return true;
}

static Register convertPtrToInt(Register Reg, LLT ConvTy, SPIRVType *SpvType,
LegalizerHelper &Helper,
MachineRegisterInfo &MRI,
Expand All @@ -486,6 +543,8 @@ bool SPIRVLegalizerInfo::legalizeCustom(
return legalizeBitcast(Helper, MI);
case TargetOpcode::G_EXTRACT_VECTOR_ELT:
return legalizeExtractVectorElt(Helper, MI, GR);
case TargetOpcode::G_INSERT_VECTOR_ELT:
return legalizeInsertVectorElt(Helper, MI, GR);
case TargetOpcode::G_INTRINSIC:
case TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS:
return legalizeIntrinsic(Helper, MI);
Expand Down Expand Up @@ -515,6 +574,15 @@ bool SPIRVLegalizerInfo::legalizeCustom(
}
}

// Reports whether a value of type Ty has a vector shape that must be
// legalized.
//
// Non-vector types never need it. A vector needs it when its element count
// exceeds the maximum vector size (4 when ST.isShader() is true, 16
// otherwise), or when the count is greater than 4 and not a power of two.
static bool needsVectorLegalization(const LLT &Ty, const SPIRVSubtarget &ST) {
  if (Ty.isVector()) {
    const unsigned Count = Ty.getNumElements();
    const unsigned Limit = ST.isShader() ? 4 : 16;
    if (Count > Limit)
      return true;
    // Counts up to 4 are accepted as-is; larger ones must be powers of two.
    return Count > 4 && !isPowerOf2_32(Count);
  }
  return false;
}

bool SPIRVLegalizerInfo::legalizeIntrinsic(LegalizerHelper &Helper,
MachineInstr &MI) const {
LLVM_DEBUG(dbgs() << "legalizeIntrinsic: " << MI);
Expand All @@ -531,41 +599,38 @@ bool SPIRVLegalizerInfo::legalizeIntrinsic(LegalizerHelper &Helper,
LLT DstTy = MRI.getType(DstReg);
LLT SrcTy = MRI.getType(SrcReg);

int32_t MaxVectorSize = ST.isShader() ? 4 : 16;

bool DstNeedsLegalization = false;
bool SrcNeedsLegalization = false;

if (DstTy.isVector()) {
if (DstTy.getNumElements() > 4 &&
!isPowerOf2_32(DstTy.getNumElements())) {
DstNeedsLegalization = true;
}

if (DstTy.getNumElements() > MaxVectorSize) {
DstNeedsLegalization = true;
}
}

if (SrcTy.isVector()) {
if (SrcTy.getNumElements() > 4 &&
!isPowerOf2_32(SrcTy.getNumElements())) {
SrcNeedsLegalization = true;
}

if (SrcTy.getNumElements() > MaxVectorSize) {
SrcNeedsLegalization = true;
}
}

// If an spv_bitcast needs to be legalized, we convert it to G_BITCAST to
// allow using the generic legalization rules.
if (DstNeedsLegalization || SrcNeedsLegalization) {
if (needsVectorLegalization(DstTy, ST) ||
needsVectorLegalization(SrcTy, ST)) {
LLVM_DEBUG(dbgs() << "Replacing with a G_BITCAST\n");
MIRBuilder.buildBitcast(DstReg, SrcReg);
MI.eraseFromParent();
}
return true;
} else if (IntrinsicID == Intrinsic::spv_insertelt) {
Register DstReg = MI.getOperand(0).getReg();
LLT DstTy = MRI.getType(DstReg);

if (needsVectorLegalization(DstTy, ST)) {
Register SrcReg = MI.getOperand(2).getReg();
Register ValReg = MI.getOperand(3).getReg();
Register IdxReg = MI.getOperand(4).getReg();
MIRBuilder.buildInsertVectorElement(DstReg, SrcReg, ValReg, IdxReg);
MI.eraseFromParent();
}
return true;
} else if (IntrinsicID == Intrinsic::spv_extractelt) {
Register SrcReg = MI.getOperand(2).getReg();
LLT SrcTy = MRI.getType(SrcReg);

if (needsVectorLegalization(SrcTy, ST)) {
Register DstReg = MI.getOperand(0).getReg();
Register IdxReg = MI.getOperand(3).getReg();
MIRBuilder.buildExtractVectorElement(DstReg, SrcReg, IdxReg);
MI.eraseFromParent();
}
return true;
}
return true;
}
Expand Down
79 changes: 57 additions & 22 deletions llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "SPIRV.h"
#include "SPIRVSubtarget.h"
#include "SPIRVUtils.h"
#include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
#include "llvm/IR/IntrinsicsSPIRV.h"
#include "llvm/Support/Debug.h"
#include <stack>
Expand Down Expand Up @@ -66,8 +67,9 @@ static bool deduceAndAssignTypeForGUnmerge(MachineInstr *I, MachineFunction &MF,
for (unsigned i = 0; i < I->getNumDefs() && !ScalarType; ++i) {
for (const auto &Use :
MRI.use_nodbg_instructions(I->getOperand(i).getReg())) {
assert(Use.getOpcode() == TargetOpcode::G_BUILD_VECTOR &&
"Expected use of G_UNMERGE_VALUES to be a G_BUILD_VECTOR");
if (Use.getOpcode() != TargetOpcode::G_BUILD_VECTOR)
continue;

if (auto *VecType =
GR->getSPIRVTypeForVReg(Use.getOperand(0).getReg())) {
ScalarType = GR->getScalarOrVectorComponentType(VecType);
Expand Down Expand Up @@ -133,10 +135,10 @@ static SPIRVType *deduceTypeFromOperandRange(MachineInstr *I,
return ResType;
}

static SPIRVType *deduceTypeForResultRegister(MachineInstr *Use,
Register UseRegister,
SPIRVGlobalRegistry *GR,
MachineIRBuilder &MIB) {
static SPIRVType *deduceTypeFromResultRegister(MachineInstr *Use,
Register UseRegister,
SPIRVGlobalRegistry *GR,
MachineIRBuilder &MIB) {
for (const MachineOperand &MO : Use->defs()) {
if (!MO.isReg())
continue;
Expand All @@ -159,16 +161,44 @@ static SPIRVType *deduceTypeFromUses(Register Reg, MachineFunction &MF,
MachineRegisterInfo &MRI = MF.getRegInfo();
for (MachineInstr &Use : MRI.use_nodbg_instructions(Reg)) {
SPIRVType *ResType = nullptr;
LLVM_DEBUG(dbgs() << "Looking at use " << Use);
switch (Use.getOpcode()) {
case TargetOpcode::G_BUILD_VECTOR:
case TargetOpcode::G_EXTRACT_VECTOR_ELT:
case TargetOpcode::G_UNMERGE_VALUES:
LLVM_DEBUG(dbgs() << "Looking at use " << Use << "\n");
ResType = deduceTypeForResultRegister(&Use, Reg, GR, MIB);
case TargetOpcode::G_ADD:
case TargetOpcode::G_SUB:
case TargetOpcode::G_MUL:
case TargetOpcode::G_SDIV:
case TargetOpcode::G_UDIV:
case TargetOpcode::G_SREM:
case TargetOpcode::G_UREM:
case TargetOpcode::G_FADD:
case TargetOpcode::G_FSUB:
case TargetOpcode::G_FMUL:
case TargetOpcode::G_FDIV:
case TargetOpcode::G_FREM:
case TargetOpcode::G_FMA:
case TargetOpcode::G_STRICT_FMA:
ResType = deduceTypeFromResultRegister(&Use, Reg, GR, MIB);
break;
case TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS:
case TargetOpcode::G_INTRINSIC: {
auto IntrinsicID = cast<GIntrinsic>(Use).getIntrinsicID();
if (IntrinsicID == Intrinsic::spv_insertelt) {
if (Reg == Use.getOperand(2).getReg())
ResType = deduceTypeFromResultRegister(&Use, Reg, GR, MIB);
} else if (IntrinsicID == Intrinsic::spv_extractelt) {
if (Reg == Use.getOperand(2).getReg())
ResType = deduceTypeFromResultRegister(&Use, Reg, GR, MIB);
}
break;
}
}
if (ResType)
if (ResType) {
LLVM_DEBUG(dbgs() << "Deduced type from use " << *ResType);
return ResType;
}
}
return nullptr;
}
Expand Down Expand Up @@ -296,20 +326,25 @@ static void registerSpirvTypeForNewInstructions(MachineFunction &MF,

for (auto *I : Worklist) {
MachineIRBuilder MIB(*I);
Register ResVReg = I->getOperand(0).getReg();
const LLT &ResLLT = MRI.getType(ResVReg);
SPIRVType *ResType = nullptr;
if (ResLLT.isVector()) {
SPIRVType *CompType = GR->getOrCreateSPIRVIntegerType(
ResLLT.getElementType().getSizeInBits(), MIB);
ResType = GR->getOrCreateSPIRVVectorType(
CompType, ResLLT.getNumElements(), MIB, false);
} else {
ResType = GR->getOrCreateSPIRVIntegerType(ResLLT.getSizeInBits(), MIB);
for (unsigned Idx = 0; Idx < I->getNumDefs(); ++Idx) {
Register ResVReg = I->getOperand(Idx).getReg();
if (GR->getSPIRVTypeForVReg(ResVReg))
continue;
const LLT &ResLLT = MRI.getType(ResVReg);
SPIRVType *ResType = nullptr;
if (ResLLT.isVector()) {
SPIRVType *CompType = GR->getOrCreateSPIRVIntegerType(
ResLLT.getElementType().getSizeInBits(), MIB);
ResType = GR->getOrCreateSPIRVVectorType(
CompType, ResLLT.getNumElements(), MIB, false);
} else {
ResType = GR->getOrCreateSPIRVIntegerType(ResLLT.getSizeInBits(), MIB);
}
LLVM_DEBUG(dbgs() << "Could not determine type for " << ResVReg
<< ", defaulting to " << *ResType << "\n");

setRegClassType(ResVReg, ResType, GR, &MRI, MF, true);
}
LLVM_DEBUG(dbgs() << "Could not determine type for " << *I
<< ", defaulting to " << *ResType << "\n");
setRegClassType(ResVReg, ResType, GR, &MRI, MF, true);
}
}

Expand Down
Loading