Skip to content

Commit 909fd3c

Browse files
committed
Merge branch 'post-legalizer-worklist' into legalize-long-vectors-final
2 parents e2fa040 + 547301b commit 909fd3c

File tree

1 file changed

+108
-30
lines changed

1 file changed

+108
-30
lines changed

llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp

Lines changed: 108 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ static bool mayBeInserted(unsigned Opcode) {
6262
case TargetOpcode::G_IMPLICIT_DEF:
6363
case TargetOpcode::G_BUILD_VECTOR:
6464
case TargetOpcode::G_ICMP:
65+
case TargetOpcode::G_SHUFFLE_VECTOR:
6566
case TargetOpcode::G_ANYEXT:
6667
return true;
6768
default:
@@ -83,30 +84,47 @@ static bool deduceAndAssignTypeForGUnmerge(MachineInstr *I, MachineFunction &MF,
8384
SPIRVGlobalRegistry *GR) {
8485
MachineRegisterInfo &MRI = MF.getRegInfo();
8586
Register SrcReg = I->getOperand(I->getNumOperands() - 1).getReg();
87+
SPIRVType *ScalarType = nullptr;
8688
if (SPIRVType *DefType = GR->getSPIRVTypeForVReg(SrcReg)) {
87-
if (DefType->getOpcode() == SPIRV::OpTypeVector) {
88-
SPIRVType *ScalarType =
89-
GR->getSPIRVTypeForVReg(DefType->getOperand(1).getReg());
90-
for (unsigned i = 0; i < I->getNumDefs(); ++i) {
91-
Register DefReg = I->getOperand(i).getReg();
92-
if (!GR->getSPIRVTypeForVReg(DefReg)) {
93-
LLT DefLLT = MRI.getType(DefReg);
94-
SPIRVType *ResType;
95-
if (DefLLT.isVector()) {
96-
const SPIRVInstrInfo *TII =
97-
MF.getSubtarget<SPIRVSubtarget>().getInstrInfo();
98-
ResType = GR->getOrCreateSPIRVVectorType(
99-
ScalarType, DefLLT.getNumElements(), *I, *TII);
100-
} else {
101-
ResType = ScalarType;
102-
}
103-
setRegClassType(DefReg, ResType, GR, &MRI, MF);
89+
assert(DefType->getOpcode() == SPIRV::OpTypeVector);
90+
ScalarType = GR->getSPIRVTypeForVReg(DefType->getOperand(1).getReg());
91+
}
92+
93+
if (!ScalarType) {
94+
// If we could not deduce the type from the source, try to deduce it from
95+
// the uses of the results.
96+
for (unsigned i = 0; i < I->getNumDefs() && !ScalarType; ++i) {
97+
for (const auto &Use :
98+
MRI.use_nodbg_instructions(I->getOperand(i).getReg())) {
99+
assert(Use.getOpcode() == TargetOpcode::G_BUILD_VECTOR &&
100+
"Expected use of G_UNMERGE_VALUES to be a G_BUILD_VECTOR");
101+
if (auto *VecType =
102+
GR->getSPIRVTypeForVReg(Use.getOperand(0).getReg())) {
103+
ScalarType = GR->getScalarOrVectorComponentType(VecType);
104+
break;
104105
}
105106
}
106-
return true;
107107
}
108108
}
109-
return false;
109+
110+
if (!ScalarType)
111+
return false;
112+
113+
for (unsigned i = 0; i < I->getNumDefs(); ++i) {
114+
Register DefReg = I->getOperand(i).getReg();
115+
if (GR->getSPIRVTypeForVReg(DefReg))
116+
continue;
117+
118+
LLT DefLLT = MRI.getType(DefReg);
119+
SPIRVType *ResType =
120+
DefLLT.isVector()
121+
? GR->getOrCreateSPIRVVectorType(
122+
ScalarType, DefLLT.getNumElements(), *I,
123+
*MF.getSubtarget<SPIRVSubtarget>().getInstrInfo())
124+
: ScalarType;
125+
setRegClassType(DefReg, ResType, GR, &MRI, MF);
126+
}
127+
return true;
110128
}
111129

112130
static SPIRVType *deduceTypeForGExtractVectorElt(MachineInstr *I,
@@ -167,20 +185,61 @@ static SPIRVType *deduceTypeForGBuildVector(MachineInstr *I,
167185
return nullptr;
168186
}
169187

188+
static SPIRVType *deduceTypeForGShuffleVector(MachineInstr *I,
189+
MachineFunction &MF,
190+
SPIRVGlobalRegistry *GR,
191+
MachineIRBuilder &MIB,
192+
Register ResVReg) {
193+
MachineRegisterInfo &MRI = MF.getRegInfo();
194+
const LLT &ResLLT = MRI.getType(ResVReg);
195+
assert(ResLLT.isVector() && "G_SHUFFLE_VECTOR result must be a vector");
196+
197+
// The result element type should be the same as the input vector element
198+
// types.
199+
for (unsigned i = 1; i <= 2; ++i) {
200+
Register VReg = I->getOperand(i).getReg();
201+
if (auto *VType = GR->getSPIRVTypeForVReg(VReg)) {
202+
if (auto *ScalarType = GR->getScalarOrVectorComponentType(VType))
203+
return GR->getOrCreateSPIRVVectorType(
204+
ScalarType, ResLLT.getNumElements(), MIB, false);
205+
}
206+
}
207+
return nullptr;
208+
}
209+
170210
static SPIRVType *deduceTypeForGImplicitDef(MachineInstr *I,
171211
MachineFunction &MF,
172212
SPIRVGlobalRegistry *GR,
173213
Register ResVReg) {
174214
MachineRegisterInfo &MRI = MF.getRegInfo();
175-
for (const auto &Use : MRI.use_nodbg_instructions(ResVReg)) {
176-
const unsigned UseOpc = Use.getOpcode();
177-
assert(UseOpc == TargetOpcode::G_BUILD_VECTOR ||
178-
UseOpc == TargetOpcode::G_SHUFFLE_VECTOR);
179-
// It's possible that the use instruction has not been processed yet.
180-
// We should look at the operands of the use to determine the type.
181-
for (unsigned i = 1; i < Use.getNumOperands(); ++i) {
182-
if (auto *Type = GR->getSPIRVTypeForVReg(Use.getOperand(i).getReg()))
183-
return Type;
215+
for (const MachineInstr &Use : MRI.use_nodbg_instructions(ResVReg)) {
216+
SPIRVType *ScalarType = nullptr;
217+
switch (Use.getOpcode()) {
218+
case TargetOpcode::G_BUILD_VECTOR:
219+
case TargetOpcode::G_UNMERGE_VALUES:
220+
// It's possible that the use instruction has not been processed yet.
221+
// We should look at the operands of the use to determine the type.
222+
for (unsigned i = 1; i < Use.getNumOperands(); ++i) {
223+
if (SPIRVType *OpType =
224+
GR->getSPIRVTypeForVReg(Use.getOperand(i).getReg()))
225+
ScalarType = GR->getScalarOrVectorComponentType(OpType);
226+
}
227+
break;
228+
case TargetOpcode::G_SHUFFLE_VECTOR:
229+
// For G_SHUFFLE_VECTOR, only look at the vector input operands.
230+
if (auto *Type = GR->getSPIRVTypeForVReg(Use.getOperand(1).getReg()))
231+
ScalarType = GR->getScalarOrVectorComponentType(Type);
232+
if (auto *Type = GR->getSPIRVTypeForVReg(Use.getOperand(2).getReg()))
233+
ScalarType = GR->getScalarOrVectorComponentType(Type);
234+
break;
235+
}
236+
if (ScalarType) {
237+
const LLT &ResLLT = MRI.getType(ResVReg);
238+
if (!ResLLT.isVector())
239+
return ScalarType;
240+
return GR->getOrCreateSPIRVVectorType(
241+
ScalarType, ResLLT.getNumElements(), *I,
242+
*MF.getSubtarget<SPIRVSubtarget>().getInstrInfo());
184243
}
185244
}
186245
return nullptr;
@@ -198,7 +257,8 @@ static SPIRVType *deduceTypeForGIntrinsic(MachineInstr *I, MachineFunction &MF,
198257
const unsigned UseOpc = Use.getOpcode();
199258
assert(UseOpc == TargetOpcode::G_EXTRACT_VECTOR_ELT ||
200259
UseOpc == TargetOpcode::G_SHUFFLE_VECTOR ||
201-
UseOpc == TargetOpcode::G_BUILD_VECTOR);
260+
UseOpc == TargetOpcode::G_BUILD_VECTOR ||
261+
UseOpc == TargetOpcode::G_UNMERGE_VALUES);
202262
Register UseResultReg = Use.getOperand(0).getReg();
203263
if (SPIRVType *UseResType = GR->getSPIRVTypeForVReg(UseResultReg)) {
204264
SPIRVType *ScalarType = GR->getScalarOrVectorComponentType(UseResType);
@@ -264,6 +324,10 @@ static bool deduceAndAssignSpirvType(MachineInstr *I, MachineFunction &MF,
264324
ResType = deduceTypeForGBuildVector(I, MF, GR, MIB, ResVReg);
265325
break;
266326
}
327+
case TargetOpcode::G_SHUFFLE_VECTOR: {
328+
ResType = deduceTypeForGShuffleVector(I, MF, GR, MIB, ResVReg);
329+
break;
330+
}
267331
case TargetOpcode::G_ANYEXT: {
268332
ResType = deduceTypeForGAnyExt(I, MF, GR, MIB, ResVReg);
269333
break;
@@ -362,7 +426,21 @@ static void registerSpirvTypeForNewInstructions(MachineFunction &MF,
362426
if (!Worklist.empty()) {
363427
LLVM_DEBUG(dbgs() << "Remaining worklist:\n";
364428
for (auto *I : Worklist) { I->dump(); });
365-
assert(Worklist.empty() && "Worklist is not empty");
429+
for (auto *I : Worklist) {
430+
MachineIRBuilder MIB(*I);
431+
Register ResVReg = I->getOperand(0).getReg();
432+
const LLT &ResLLT = MRI.getType(ResVReg);
433+
SPIRVType *ResType = nullptr;
434+
if (ResLLT.isVector()) {
435+
SPIRVType *CompType = GR->getOrCreateSPIRVIntegerType(
436+
ResLLT.getElementType().getSizeInBits(), MIB);
437+
ResType = GR->getOrCreateSPIRVVectorType(
438+
CompType, ResLLT.getNumElements(), MIB, false);
439+
} else {
440+
ResType = GR->getOrCreateSPIRVIntegerType(ResLLT.getSizeInBits(), MIB);
441+
}
442+
setRegClassType(ResVReg, ResType, GR, &MRI, MF, true);
443+
}
366444
}
367445
}
368446

0 commit comments

Comments
 (0)