-
Notifications
You must be signed in to change notification settings - Fork 15.1k
[SPIRV] Use a worklist in the post-legalizer #165027
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
|
@llvm/pr-subscribers-backend-spir-v Author: Steven Perron (s-perron) ChangesThis commit refactors the SPIRV post-legalizer to use a worklist to process The new implementation adds all new instructions that require a SPIR-V type This change makes the post-legalizer more robust and fixes potential ordering Existing tests cover existing functionality. More tests will be added as Part of #153091 Full diff: https://github.com/llvm/llvm-project/pull/165027.diff 1 Files Affected:
diff --git a/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
index d17528dd882bf..d11168b70aea8 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
@@ -17,7 +17,8 @@
#include "SPIRV.h"
#include "SPIRVSubtarget.h"
#include "SPIRVUtils.h"
-#include "llvm/IR/Attributes.h"
+#include "llvm/IR/IntrinsicsSPIRV.h"
+#include "llvm/Support/Debug.h"
#include <stack>
#define DEBUG_TYPE "spirv-postlegalizer"
@@ -45,6 +46,11 @@ extern void processInstr(MachineInstr &MI, MachineIRBuilder &MIB,
static bool mayBeInserted(unsigned Opcode) {
switch (Opcode) {
+ case TargetOpcode::G_CONSTANT:
+ case TargetOpcode::G_UNMERGE_VALUES:
+ case TargetOpcode::G_EXTRACT_VECTOR_ELT:
+ case TargetOpcode::G_INTRINSIC:
+ case TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS:
case TargetOpcode::G_SMAX:
case TargetOpcode::G_UMAX:
case TargetOpcode::G_SMIN:
@@ -53,69 +59,344 @@ static bool mayBeInserted(unsigned Opcode) {
case TargetOpcode::G_FMINIMUM:
case TargetOpcode::G_FMAXNUM:
case TargetOpcode::G_FMAXIMUM:
+ case TargetOpcode::G_IMPLICIT_DEF:
+ case TargetOpcode::G_BUILD_VECTOR:
+ case TargetOpcode::G_ICMP:
+ case TargetOpcode::G_ANYEXT:
return true;
default:
return isTypeFoldingSupported(Opcode);
}
}
-static void processNewInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
- MachineIRBuilder MIB) {
+static SPIRVType *deduceTypeForGConstant(MachineInstr *I, MachineFunction &MF,
+ SPIRVGlobalRegistry *GR,
+ MachineIRBuilder &MIB,
+ Register ResVReg) {
MachineRegisterInfo &MRI = MF.getRegInfo();
+ const LLT &Ty = MRI.getType(ResVReg);
+ unsigned BitWidth = Ty.getScalarSizeInBits();
+ return GR->getOrCreateSPIRVIntegerType(BitWidth, MIB);
+}
- for (MachineBasicBlock &MBB : MF) {
- for (MachineInstr &I : MBB) {
- const unsigned Opcode = I.getOpcode();
- if (Opcode == TargetOpcode::G_UNMERGE_VALUES) {
- unsigned ArgI = I.getNumOperands() - 1;
- Register SrcReg = I.getOperand(ArgI).isReg()
- ? I.getOperand(ArgI).getReg()
- : Register(0);
- SPIRVType *DefType =
- SrcReg.isValid() ? GR->getSPIRVTypeForVReg(SrcReg) : nullptr;
- if (!DefType || DefType->getOpcode() != SPIRV::OpTypeVector)
- report_fatal_error(
- "cannot select G_UNMERGE_VALUES with a non-vector argument");
- SPIRVType *ScalarType =
- GR->getSPIRVTypeForVReg(DefType->getOperand(1).getReg());
- for (unsigned i = 0; i < I.getNumDefs(); ++i) {
- Register ResVReg = I.getOperand(i).getReg();
- SPIRVType *ResType = GR->getSPIRVTypeForVReg(ResVReg);
- if (!ResType) {
- // There was no "assign type" actions, let's fix this now
+static bool deduceAndAssignTypeForGUnmerge(MachineInstr *I, MachineFunction &MF,
+ SPIRVGlobalRegistry *GR) {
+ MachineRegisterInfo &MRI = MF.getRegInfo();
+ Register SrcReg = I->getOperand(I->getNumOperands() - 1).getReg();
+ if (SPIRVType *DefType = GR->getSPIRVTypeForVReg(SrcReg)) {
+ if (DefType->getOpcode() == SPIRV::OpTypeVector) {
+ SPIRVType *ScalarType =
+ GR->getSPIRVTypeForVReg(DefType->getOperand(1).getReg());
+ for (unsigned i = 0; i < I->getNumDefs(); ++i) {
+ Register DefReg = I->getOperand(i).getReg();
+ if (!GR->getSPIRVTypeForVReg(DefReg)) {
+ LLT DefLLT = MRI.getType(DefReg);
+ SPIRVType *ResType;
+ if (DefLLT.isVector()) {
+ const SPIRVInstrInfo *TII =
+ MF.getSubtarget<SPIRVSubtarget>().getInstrInfo();
+ ResType = GR->getOrCreateSPIRVVectorType(
+ ScalarType, DefLLT.getNumElements(), *I, *TII);
+ } else {
ResType = ScalarType;
- setRegClassType(ResVReg, ResType, GR, &MRI, *GR->CurMF, true);
}
+ setRegClassType(DefReg, ResType, GR, &MRI, MF);
}
- } else if (mayBeInserted(Opcode) && I.getNumDefs() == 1 &&
- I.getNumOperands() > 1 && I.getOperand(1).isReg()) {
- // Legalizer may have added a new instructions and introduced new
- // registers, we must decorate them as if they were introduced in a
- // non-automatic way
- Register ResVReg = I.getOperand(0).getReg();
- // Check if the register defined by the instruction is newly generated
- // or already processed
- // Check if we have type defined for operands of the new instruction
- bool IsKnownReg = MRI.getRegClassOrNull(ResVReg);
- SPIRVType *ResVType = GR->getSPIRVTypeForVReg(
- IsKnownReg ? ResVReg : I.getOperand(1).getReg());
- if (!ResVType)
- continue;
- // Set type & class
- if (!IsKnownReg)
- setRegClassType(ResVReg, ResVType, GR, &MRI, *GR->CurMF, true);
- // If this is a simple operation that is to be reduced by TableGen
- // definition we must apply some of pre-legalizer rules here
- if (isTypeFoldingSupported(Opcode)) {
- processInstr(I, MIB, MRI, GR, GR->getSPIRVTypeForVReg(ResVReg));
- if (IsKnownReg && MRI.hasOneUse(ResVReg)) {
- MachineInstr &UseMI = *MRI.use_instr_begin(ResVReg);
- if (UseMI.getOpcode() == SPIRV::ASSIGN_TYPE)
- continue;
- }
- insertAssignInstr(ResVReg, nullptr, ResVType, GR, MIB, MRI);
+ }
+ return true;
+ }
+ }
+ return false;
+}
+
+static SPIRVType *deduceTypeForGExtractVectorElt(MachineInstr *I,
+ MachineFunction &MF,
+ SPIRVGlobalRegistry *GR,
+ Register ResVReg) {
+ MachineRegisterInfo &MRI = MF.getRegInfo();
+ Register VecReg = I->getOperand(1).getReg();
+ if (SPIRVType *VecType = GR->getSPIRVTypeForVReg(VecReg)) {
+ assert(VecType->getOpcode() == SPIRV::OpTypeVector);
+ return GR->getScalarOrVectorComponentType(VecType);
+ }
+
+ // If not handled yet, then check if it is used in a G_BUILD_VECTOR.
+ // If so get the type from there.
+ for (const auto &Use : MRI.use_nodbg_instructions(ResVReg)) {
+ if (Use.getOpcode() == TargetOpcode::G_BUILD_VECTOR) {
+ Register BuildVecResReg = Use.getOperand(0).getReg();
+ if (SPIRVType *BuildVecType = GR->getSPIRVTypeForVReg(BuildVecResReg))
+ return GR->getScalarOrVectorComponentType(BuildVecType);
+ }
+ }
+ return nullptr;
+}
+
+static SPIRVType *deduceTypeForGBuildVector(MachineInstr *I,
+ MachineFunction &MF,
+ SPIRVGlobalRegistry *GR,
+ MachineIRBuilder &MIB,
+ Register ResVReg) {
+ MachineRegisterInfo &MRI = MF.getRegInfo();
+ // First check if any of the operands have a type.
+ for (unsigned i = 1; i < I->getNumOperands(); ++i) {
+ if (SPIRVType *OpType =
+ GR->getSPIRVTypeForVReg(I->getOperand(i).getReg())) {
+ const LLT &ResLLT = MRI.getType(ResVReg);
+ return GR->getOrCreateSPIRVVectorType(OpType, ResLLT.getNumElements(),
+ MIB, false);
+ }
+ }
+ // If that did not work, then check the uses.
+ for (const auto &Use : MRI.use_nodbg_instructions(ResVReg)) {
+ if (Use.getOpcode() == TargetOpcode::G_EXTRACT_VECTOR_ELT) {
+ Register ExtractResReg = Use.getOperand(0).getReg();
+ if (SPIRVType *ScalarType = GR->getSPIRVTypeForVReg(ExtractResReg)) {
+ const LLT &ResLLT = MRI.getType(ResVReg);
+ return GR->getOrCreateSPIRVVectorType(
+ ScalarType, ResLLT.getNumElements(), MIB, false);
+ }
+ }
+ }
+ return nullptr;
+}
+
+static SPIRVType *deduceTypeForGImplicitDef(MachineInstr *I,
+ MachineFunction &MF,
+ SPIRVGlobalRegistry *GR,
+ Register ResVReg) {
+ MachineRegisterInfo &MRI = MF.getRegInfo();
+ for (const auto &Use : MRI.use_nodbg_instructions(ResVReg)) {
+ const unsigned UseOpc = Use.getOpcode();
+ assert(UseOpc == TargetOpcode::G_BUILD_VECTOR ||
+ UseOpc == TargetOpcode::G_SHUFFLE_VECTOR);
+ // It's possible that the use instruction has not been processed yet.
+ // We should look at the operands of the use to determine the type.
+ for (unsigned i = 1; i < Use.getNumOperands(); ++i) {
+ if (auto *Type = GR->getSPIRVTypeForVReg(Use.getOperand(i).getReg()))
+ return Type;
+ }
+ }
+ return nullptr;
+}
+
+static SPIRVType *deduceTypeForGIntrinsic(MachineInstr *I, MachineFunction &MF,
+ SPIRVGlobalRegistry *GR,
+ MachineIRBuilder &MIB,
+ Register ResVReg) {
+ MachineRegisterInfo &MRI = MF.getRegInfo();
+ if (!isSpvIntrinsic(*I, Intrinsic::spv_bitcast))
+ return nullptr;
+
+ for (const auto &Use : MRI.use_nodbg_instructions(ResVReg)) {
+ const unsigned UseOpc = Use.getOpcode();
+ assert(UseOpc == TargetOpcode::G_EXTRACT_VECTOR_ELT ||
+ UseOpc == TargetOpcode::G_SHUFFLE_VECTOR);
+ Register UseResultReg = Use.getOperand(0).getReg();
+ if (SPIRVType *UseResType = GR->getSPIRVTypeForVReg(UseResultReg)) {
+ SPIRVType *ScalarType = GR->getScalarOrVectorComponentType(UseResType);
+ const LLT &BitcastLLT = MRI.getType(ResVReg);
+ if (BitcastLLT.isVector())
+ return GR->getOrCreateSPIRVVectorType(
+ ScalarType, BitcastLLT.getNumElements(), MIB, false);
+ return ScalarType;
+ }
+ }
+ return nullptr;
+}
+
+static SPIRVType *deduceTypeForGAnyExt(MachineInstr *I, MachineFunction &MF,
+ SPIRVGlobalRegistry *GR,
+ MachineIRBuilder &MIB,
+ Register ResVReg) {
+ // The result type of G_ANYEXT cannot be inferred from its operand.
+ // We use the result register's LLT to determine the correct integer type.
+ const LLT &ResLLT = MIB.getMRI()->getType(ResVReg);
+ if (!ResLLT.isScalar())
+ return nullptr;
+ return GR->getOrCreateSPIRVIntegerType(ResLLT.getSizeInBits(), MIB);
+}
+
+static SPIRVType *deduceTypeForDefault(MachineInstr *I, MachineFunction &MF,
+ SPIRVGlobalRegistry *GR) {
+ if (I->getNumDefs() != 1 || I->getNumOperands() <= 1 ||
+ !I->getOperand(1).isReg())
+ return nullptr;
+
+ SPIRVType *OpType = GR->getSPIRVTypeForVReg(I->getOperand(1).getReg());
+ if (!OpType)
+ return nullptr;
+ return OpType;
+}
+
+static bool deduceAndAssignSpirvType(MachineInstr *I, MachineFunction &MF,
+ SPIRVGlobalRegistry *GR,
+ MachineIRBuilder &MIB) {
+ LLVM_DEBUG(dbgs() << "Processing instruction: " << *I);
+ MachineRegisterInfo &MRI = MF.getRegInfo();
+ const unsigned Opcode = I->getOpcode();
+ Register ResVReg = I->getOperand(0).getReg();
+ SPIRVType *ResType = nullptr;
+
+ switch (Opcode) {
+ case TargetOpcode::G_CONSTANT: {
+ ResType = deduceTypeForGConstant(I, MF, GR, MIB, ResVReg);
+ break;
+ }
+ case TargetOpcode::G_UNMERGE_VALUES: {
+ // This one is special as it defines multiple registers.
+ if (deduceAndAssignTypeForGUnmerge(I, MF, GR))
+ return true;
+ break;
+ }
+ case TargetOpcode::G_EXTRACT_VECTOR_ELT: {
+ ResType = deduceTypeForGExtractVectorElt(I, MF, GR, ResVReg);
+ break;
+ }
+ case TargetOpcode::G_BUILD_VECTOR: {
+ ResType = deduceTypeForGBuildVector(I, MF, GR, MIB, ResVReg);
+ break;
+ }
+ case TargetOpcode::G_ANYEXT: {
+ ResType = deduceTypeForGAnyExt(I, MF, GR, MIB, ResVReg);
+ break;
+ }
+ case TargetOpcode::G_IMPLICIT_DEF: {
+ ResType = deduceTypeForGImplicitDef(I, MF, GR, ResVReg);
+ break;
+ }
+ case TargetOpcode::G_INTRINSIC:
+ case TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS: {
+ ResType = deduceTypeForGIntrinsic(I, MF, GR, MIB, ResVReg);
+ break;
+ }
+ default:
+ ResType = deduceTypeForDefault(I, MF, GR);
+ break;
+ }
+
+ if (ResType) {
+ LLVM_DEBUG(dbgs() << "Assigned type to " << *I << ": " << *ResType << "\n");
+ GR->assignSPIRVTypeToVReg(ResType, ResVReg, MF);
+
+ if (!MRI.getRegClassOrNull(ResVReg)) {
+ LLVM_DEBUG(dbgs() << "Updating the register class.\n");
+ setRegClassType(ResVReg, ResType, GR, &MRI, *GR->CurMF, true);
+ }
+ return true;
+ }
+ return false;
+}
+
+static bool requiresSpirvType(MachineInstr &I, SPIRVGlobalRegistry *GR,
+ MachineRegisterInfo &MRI) {
+ LLVM_DEBUG(dbgs() << "Checking if instruction requires a SPIR-V type: "
+ << I;);
+ if (I.getNumDefs() == 0) {
+ LLVM_DEBUG(dbgs() << "Instruction does not have a definition.\n");
+ return false;
+ }
+ if (!mayBeInserted(I.getOpcode())) {
+ LLVM_DEBUG(dbgs() << "Instruction may not be inserted.\n");
+ return false;
+ }
+
+ Register ResultRegister = I.defs().begin()->getReg();
+ if (GR->getSPIRVTypeForVReg(ResultRegister)) {
+ LLVM_DEBUG(dbgs() << "Instruction already has a SPIR-V type.\n");
+ if (!MRI.getRegClassOrNull(ResultRegister)) {
+ LLVM_DEBUG(dbgs() << "Updating the register class.\n");
+ setRegClassType(ResultRegister, GR->getSPIRVTypeForVReg(ResultRegister),
+ GR, &MRI, *GR->CurMF, true);
+ }
+ return false;
+ }
+
+ return true;
+}
+
+static void registerSpirvTypeForNewInstructions(MachineFunction &MF,
+ SPIRVGlobalRegistry *GR,
+ MachineIRBuilder MIB) {
+ MachineRegisterInfo &MRI = MF.getRegInfo();
+ SmallVector<MachineInstr *, 8> Worklist;
+ for (MachineBasicBlock &MBB : MF) {
+ for (MachineInstr &I : MBB) {
+ if (requiresSpirvType(I, GR, MRI)) {
+ Worklist.push_back(&I);
+ }
+ }
+ }
+
+ if (Worklist.empty()) {
+ LLVM_DEBUG(dbgs() << "Initial worklist is empty.\n");
+ return;
+ }
+
+ LLVM_DEBUG(dbgs() << "Initial worklist:\n";
+ for (auto *I : Worklist) { I->dump(); });
+
+ bool Changed = true;
+ while (Changed) {
+ Changed = false;
+ SmallVector<MachineInstr *, 8> NextWorklist;
+
+ for (MachineInstr *I : Worklist) {
+ if (deduceAndAssignSpirvType(I, MF, GR, MIB)) {
+ Changed = true;
+ } else {
+ NextWorklist.push_back(I);
+ }
+ }
+ Worklist = NextWorklist;
+ LLVM_DEBUG(dbgs() << "Worklist size: " << Worklist.size() << "\n");
+ }
+
+ if (!Worklist.empty()) {
+ LLVM_DEBUG(dbgs() << "Remaining worklist:\n";
+ for (auto *I : Worklist) { I->dump(); });
+ assert(Worklist.empty() && "Worklist is not empty");
+ }
+}
+
+static void ensureAssignTypeForTypeFolding(MachineFunction &MF,
+ SPIRVGlobalRegistry *GR,
+ MachineIRBuilder MIB) {
+ LLVM_DEBUG(dbgs() << "Entering ensureAssignTypeForTypeFolding for function "
+ << MF.getName() << "\n");
+ MachineRegisterInfo &MRI = MF.getRegInfo();
+ for (MachineBasicBlock &MBB : MF) {
+ for (MachineInstr &MI : MBB) {
+ if (!isTypeFoldingSupported(MI.getOpcode()))
+ continue;
+ if (MI.getNumOperands() == 1 || !MI.getOperand(1).isReg())
+ continue;
+
+ LLVM_DEBUG(dbgs() << "Processing instruction: " << MI);
+
+ // Check uses of MI to see if it already has an use in SPIRV::ASSIGN_TYPE
+ bool HasAssignType = false;
+ Register ResultRegister = MI.defs().begin()->getReg();
+ // All uses of Result register
+ for (MachineInstr &UseInstr :
+ MRI.use_nodbg_instructions(ResultRegister)) {
+ if (UseInstr.getOpcode() == SPIRV::ASSIGN_TYPE) {
+ HasAssignType = true;
+ LLVM_DEBUG(dbgs() << " Instruction already has an ASSIGN_TYPE use: "
+ << UseInstr);
+ break;
}
}
+
+ if (!HasAssignType) {
+ Register ResultRegister = MI.defs().begin()->getReg();
+ SPIRVType *ResultType = GR->getSPIRVTypeForVReg(ResultRegister);
+ LLVM_DEBUG(
+ dbgs() << " Adding ASSIGN_TYPE for ResultRegister: "
+ << printReg(ResultRegister, MRI.getTargetRegisterInfo())
+ << " with type: " << *ResultType);
+ insertAssignInstr(ResultRegister, nullptr, ResultType, GR, MIB, MRI);
+ }
}
}
}
@@ -156,9 +437,8 @@ bool SPIRVPostLegalizer::runOnMachineFunction(MachineFunction &MF) {
SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
GR->setCurrentFunc(MF);
MachineIRBuilder MIB(MF);
-
- processNewInstrs(MF, GR, MIB);
-
+ registerSpirvTypeForNewInstructions(MF, GR, MIB);
+ ensureAssignTypeForTypeFolding(MF, GR, MIB);
return true;
}
|
c16bc1e to
f2f29a5
Compare
This commit refactors the SPIRV post-legalizer to use a worklist to process new instructions. Previously, the post-legalizer would iterate through all instructions and try to assign types. This could fail if a new instruction depended on another new instruction that had not been processed yet. The new implementation adds all new instructions that require a SPIR-V type to a worklist. It then iteratively processes the worklist until it is empty. This ensures that all dependencies are met before an instruction is processed. This change makes the post-legalizer more robust and fixes potential ordering issues with newly generated instructions. Existing tests cover existing functionality. More tests will be added as the legalizer is modified. Part of llvm#153091
c248de2 to
4f7f51c
Compare
|
@Keenuts @luciechoi I believe this is in good shape for a review. |
Keenuts
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks! Code looks easy to follow, just some comments
This commit refactors the SPIRV post-legalizer to use a worklist to process
new instructions. Previously, the post-legalizer would iterate through all
instructions and try to assign types. This could fail if a new instruction
depended on another new instruction that had not been processed yet.
The new implementation adds all new instructions that require a SPIR-V type
to a worklist. It then iteratively processes the worklist until it is empty.
This ensures that all dependencies are met before an instruction is
processed.
This change makes the post-legalizer more robust and fixes potential ordering
issues with newly generated instructions.
Existing tests cover existing functionality. More tests will be added as
the legalizer is modified.
Part of #153091