Skip to content

Commit bbbb8fa

Browse files
rework assign_type, type/const instruction selection
1 parent 0253795 commit bbbb8fa

24 files changed

+577
-399
lines changed

llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -535,7 +535,7 @@ static Register buildBuiltinVariableLoad(
535535
/// assign SPIRVType to both registers. If SpirvTy is provided, use it as
536536
/// SPIRVType in ASSIGN_TYPE, otherwise create it from \p Ty. Defined in
537537
/// SPIRVPreLegalizer.cpp.
538-
extern Register insertAssignInstr(Register Reg, Type *Ty, SPIRVType *SpirvTy,
538+
extern void insertAssignInstr(Register Reg, Type *Ty, SPIRVType *SpirvTy,
539539
SPIRVGlobalRegistry *GR,
540540
MachineIRBuilder &MIB,
541541
MachineRegisterInfo &MRI);

llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp

Lines changed: 68 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -250,24 +250,32 @@ Register SPIRVGlobalRegistry::getOrCreateConstFP(APFloat Val, MachineInstr &I,
250250
SPIRVType *SpvType,
251251
const SPIRVInstrInfo &TII,
252252
bool ZeroAsNull) {
253-
unsigned BitWidth = getScalarOrVectorBitWidth(SpvType);
254253
LLVMContext &Ctx = CurMF->getFunction().getContext();
255254
auto *const CF = ConstantFP::get(Ctx, Val);
256-
Register Res = find(CF, CurMF);
257-
if (Res.isValid())
258-
return Res;
255+
const MachineInstr *MI = findMI(CF, CurMF);
256+
if (MI && (MI->getOpcode() == SPIRV::OpConstantNull ||
257+
MI->getOpcode() == SPIRV::OpConstantF))
258+
return MI->getOperand(0).getReg();
259+
return createConstFP(CF, I, SpvType, TII, ZeroAsNull);
260+
}
259261

262+
Register SPIRVGlobalRegistry::createConstFP(const ConstantFP *CF,
263+
MachineInstr &I, SPIRVType *SpvType,
264+
const SPIRVInstrInfo &TII,
265+
bool ZeroAsNull) {
266+
unsigned BitWidth = getScalarOrVectorBitWidth(SpvType);
260267
LLT LLTy = LLT::scalar(BitWidth);
261-
Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
268+
Register Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
262269
CurMF->getRegInfo().setRegClass(Res, &SPIRV::fIDRegClass);
263270
assignFloatTypeToVReg(BitWidth, Res, I, TII);
264271

265-
MachineIRBuilder MIRBuilder(I);
272+
MachineInstr *DepMI = const_cast<MachineInstr *>(SpvType);
273+
MachineIRBuilder MIRBuilder(*DepMI->getParent(), DepMI->getIterator());
266274
SPIRVType *NewType =
267275
createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
268276
MachineInstrBuilder MIB;
269277
// In OpenCL OpConstantNull - Scalar floating point: +0.0 (all bits 0)
270-
if (Val.isPosZero() && ZeroAsNull) {
278+
if (CF->getValue().isPosZero() && ZeroAsNull) {
271279
MIB = MIRBuilder.buildInstr(SPIRV::OpConstantNull)
272280
.addDef(Res)
273281
.addUse(getSPIRVTypeID(SpvType));
@@ -294,24 +302,35 @@ Register SPIRVGlobalRegistry::getOrCreateConstInt(uint64_t Val, MachineInstr &I,
294302
const SPIRVInstrInfo &TII,
295303
bool ZeroAsNull) {
296304
const IntegerType *Ty = cast<IntegerType>(getTypeForSPIRVType(SpvType));
297-
unsigned BitWidth = getScalarOrVectorBitWidth(SpvType);
298305
auto *const CI = ConstantInt::get(const_cast<IntegerType *>(Ty), Val);
299-
Register Res = find(CI, CurMF);
300-
if (Res.isValid())
301-
return Res;
306+
const MachineInstr *MI = findMI(CI, CurMF);
307+
if (MI && (MI->getOpcode() == SPIRV::OpConstantNull ||
308+
MI->getOpcode() == SPIRV::OpConstantI))
309+
return MI->getOperand(0).getReg();
310+
return createConstInt(CI, I, SpvType, TII, ZeroAsNull);
311+
}
312+
313+
Register SPIRVGlobalRegistry::createConstInt(const ConstantInt *CI,
314+
MachineInstr &I,
315+
SPIRVType *SpvType,
316+
const SPIRVInstrInfo &TII,
317+
bool ZeroAsNull) {
318+
unsigned BitWidth = getScalarOrVectorBitWidth(SpvType);
302319
LLT LLTy = LLT::scalar(BitWidth);
303-
Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
320+
Register Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
304321
CurMF->getRegInfo().setRegClass(Res, &SPIRV::iIDRegClass);
305322
assignIntTypeToVReg(BitWidth, Res, I, TII);
306-
MachineIRBuilder MIRBuilder(I);
323+
324+
MachineInstr *DepMI = const_cast<MachineInstr *>(SpvType);
325+
MachineIRBuilder MIRBuilder(*DepMI->getParent(), DepMI->getIterator());
307326
SPIRVType *NewType =
308327
createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
309328
MachineInstrBuilder MIB;
310-
if (Val || !ZeroAsNull) {
329+
if (!CI->isZero() || !ZeroAsNull) {
311330
MIB = MIRBuilder.buildInstr(SPIRV::OpConstantI)
312331
.addDef(Res)
313332
.addUse(getSPIRVTypeID(SpvType));
314-
addNumImm(APInt(BitWidth, Val), MIB);
333+
addNumImm(APInt(BitWidth, CI->getZExtValue()), MIB);
315334
} else {
316335
MIB = MIRBuilder.buildInstr(SPIRV::OpConstantNull)
317336
.addDef(Res)
@@ -441,7 +460,8 @@ Register SPIRVGlobalRegistry::getOrCreateCompositeOrNull(
441460
CurMF->getRegInfo().setRegClass(Res, getRegClass(SpvType));
442461
assignSPIRVTypeToVReg(SpvType, Res, *CurMF);
443462

444-
MachineIRBuilder MIRBuilder(I);
463+
MachineInstr *DepMI = const_cast<MachineInstr *>(SpvType);
464+
MachineIRBuilder MIRBuilder(*DepMI->getParent(), DepMI->getIterator());
445465
const MachineInstr *NewMI =
446466
createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
447467
MachineInstrBuilder MIB;
@@ -751,11 +771,7 @@ static std::string buildSpirvTypeName(const SPIRVType *Type,
751771
MachineRegisterInfo *MRI = MIRBuilder.getMRI();
752772
Register ElementTypeReg = Type->getOperand(1).getReg();
753773
auto *ElementType = MRI->getUniqueVRegDef(ElementTypeReg);
754-
const SPIRVType *TypeInst = MRI->getVRegDef(Type->getOperand(2).getReg());
755-
assert(TypeInst->getOpcode() != SPIRV::OpConstantI);
756-
MachineInstr *ImmInst = MRI->getVRegDef(TypeInst->getOperand(1).getReg());
757-
assert(ImmInst->getOpcode() == TargetOpcode::G_CONSTANT);
758-
uint32_t ArraySize = ImmInst->getOperand(1).getCImm()->getZExtValue();
774+
uint32_t ArraySize = getArrayComponentCount(MRI, Type);
759775
return (buildSpirvTypeName(ElementType, MIRBuilder) + Twine("[") +
760776
Twine(ArraySize) + Twine("]"))
761777
.str();
@@ -1274,9 +1290,9 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeImage(
12741290
uint32_t Depth, uint32_t Arrayed, uint32_t Multisampled, uint32_t Sampled,
12751291
SPIRV::ImageFormat::ImageFormat ImageFormat,
12761292
SPIRV::AccessQualifier::AccessQualifier AccessQual) {
1277-
auto Key = SPIRV::make_descr_image(SPIRVToLLVMType.lookup(SampledType), Dim,
1278-
Depth, Arrayed, Multisampled, Sampled,
1279-
ImageFormat, AccessQual);
1293+
auto Key = SPIRV::irhandle_image(SPIRVToLLVMType.lookup(SampledType), Dim,
1294+
Depth, Arrayed, Multisampled, Sampled,
1295+
ImageFormat, AccessQual);
12801296
if (const MachineInstr *MI = findMI(Key, &MIRBuilder.getMF()))
12811297
return MI;
12821298
const MachineInstr *NewMI =
@@ -1301,7 +1317,7 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeImage(
13011317

13021318
SPIRVType *
13031319
SPIRVGlobalRegistry::getOrCreateOpTypeSampler(MachineIRBuilder &MIRBuilder) {
1304-
auto Key = SPIRV::make_descr_sampler();
1320+
auto Key = SPIRV::irhandle_sampler();
13051321
const MachineFunction *MF = &MIRBuilder.getMF();
13061322
if (const MachineInstr *MI = findMI(Key, MF))
13071323
return MI;
@@ -1317,7 +1333,7 @@ SPIRVGlobalRegistry::getOrCreateOpTypeSampler(MachineIRBuilder &MIRBuilder) {
13171333
SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypePipe(
13181334
MachineIRBuilder &MIRBuilder,
13191335
SPIRV::AccessQualifier::AccessQualifier AccessQual) {
1320-
auto Key = SPIRV::make_descr_pipe(AccessQual);
1336+
auto Key = SPIRV::irhandle_pipe(AccessQual);
13211337
if (const MachineInstr *MI = findMI(Key, &MIRBuilder.getMF()))
13221338
return MI;
13231339
const MachineInstr *NewMI =
@@ -1332,7 +1348,7 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypePipe(
13321348

13331349
SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeDeviceEvent(
13341350
MachineIRBuilder &MIRBuilder) {
1335-
auto Key = SPIRV::make_descr_event();
1351+
auto Key = SPIRV::irhandle_event();
13361352
if (const MachineInstr *MI = findMI(Key, &MIRBuilder.getMF()))
13371353
return MI;
13381354
const MachineInstr *NewMI =
@@ -1346,7 +1362,7 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeDeviceEvent(
13461362

13471363
SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeSampledImage(
13481364
SPIRVType *ImageType, MachineIRBuilder &MIRBuilder) {
1349-
auto Key = SPIRV::make_descr_sampled_image(
1365+
auto Key = SPIRV::irhandle_sampled_image(
13501366
SPIRVToLLVMType.lookup(MIRBuilder.getMF().getRegInfo().getVRegDef(
13511367
ImageType->getOperand(1).getReg())),
13521368
ImageType);
@@ -1465,10 +1481,12 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType(unsigned BitWidth,
14651481
Type *Ty) {
14661482
if (const MachineInstr *MI = findMI(Ty, CurMF))
14671483
return MI;
1468-
MachineIRBuilder MIRBuilder(I);
1484+
MachineBasicBlock &DepMBB = I.getMF()->front();
1485+
MachineIRBuilder MIRBuilder(DepMBB, DepMBB.getFirstNonPHI());
14691486
const MachineInstr *NewMI =
14701487
createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
1471-
return BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRVOPcode))
1488+
return BuildMI(MIRBuilder.getMBB(), *MIRBuilder.getInsertPt(),
1489+
MIRBuilder.getDL(), TII.get(SPIRVOPcode))
14721490
.addDef(createTypeVReg(CurMF->getRegInfo()))
14731491
.addImm(BitWidth)
14741492
.addImm(0);
@@ -1522,11 +1540,12 @@ SPIRVGlobalRegistry::getOrCreateSPIRVBoolType(MachineInstr &I,
15221540
Type *Ty = IntegerType::get(CurMF->getFunction().getContext(), 1);
15231541
if (const MachineInstr *MI = findMI(Ty, CurMF))
15241542
return MI;
1525-
MachineIRBuilder MIRBuilder(I);
1543+
MachineBasicBlock &DepMBB = I.getMF()->front();
1544+
MachineIRBuilder MIRBuilder(DepMBB, DepMBB.getFirstNonPHI());
15261545
const MachineInstr *NewMI =
15271546
createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
1528-
return BuildMI(*I.getParent(), I, I.getDebugLoc(),
1529-
TII.get(SPIRV::OpTypeBool))
1547+
return BuildMI(MIRBuilder.getMBB(), *MIRBuilder.getInsertPt(),
1548+
MIRBuilder.getDL(), TII.get(SPIRV::OpTypeBool))
15301549
.addDef(createTypeVReg(CurMF->getRegInfo()));
15311550
});
15321551
add(Ty, NewMI);
@@ -1549,11 +1568,12 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVVectorType(
15491568
const_cast<Type *>(getTypeForSPIRVType(BaseType)), NumElements);
15501569
if (const MachineInstr *MI = findMI(Ty, CurMF))
15511570
return MI;
1552-
MachineIRBuilder MIRBuilder(I);
1571+
MachineInstr *DepMI = const_cast<MachineInstr *>(BaseType);
1572+
MachineIRBuilder MIRBuilder(*DepMI->getParent(), DepMI->getIterator());
15531573
const MachineInstr *NewMI =
15541574
createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
1555-
return BuildMI(*I.getParent(), I, I.getDebugLoc(),
1556-
TII.get(SPIRV::OpTypeVector))
1575+
return BuildMI(MIRBuilder.getMBB(), *MIRBuilder.getInsertPt(),
1576+
MIRBuilder.getDL(), TII.get(SPIRV::OpTypeVector))
15571577
.addDef(createTypeVReg(CurMF->getRegInfo()))
15581578
.addUse(getSPIRVTypeID(BaseType))
15591579
.addImm(NumElements);
@@ -1571,11 +1591,12 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVArrayType(
15711591
return MI;
15721592
SPIRVType *SpvTypeInt32 = getOrCreateSPIRVIntegerType(32, I, TII);
15731593
Register Len = getOrCreateConstInt(NumElements, I, SpvTypeInt32, TII);
1574-
MachineIRBuilder MIRBuilder(I);
1594+
MachineBasicBlock &DepMBB = I.getMF()->front();
1595+
MachineIRBuilder MIRBuilder(DepMBB, getInsertPtValidEnd(&DepMBB));
15751596
const MachineInstr *NewMI =
15761597
createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
1577-
return BuildMI(*I.getParent(), I, I.getDebugLoc(),
1578-
TII.get(SPIRV::OpTypeArray))
1598+
return BuildMI(MIRBuilder.getMBB(), *MIRBuilder.getInsertPt(),
1599+
MIRBuilder.getDL(), TII.get(SPIRV::OpTypeArray))
15791600
.addDef(createTypeVReg(CurMF->getRegInfo()))
15801601
.addUse(getSPIRVTypeID(BaseType))
15811602
.addUse(Len);
@@ -1609,7 +1630,8 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType(
16091630
SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType(
16101631
SPIRVType *BaseType, MachineInstr &I, const SPIRVInstrInfo &,
16111632
SPIRV::StorageClass::StorageClass SC) {
1612-
MachineIRBuilder MIRBuilder(I);
1633+
MachineInstr *DepMI = const_cast<MachineInstr *>(BaseType);
1634+
MachineIRBuilder MIRBuilder(*DepMI->getParent(), DepMI->getIterator());
16131635
return getOrCreateSPIRVPointerType(BaseType, MIRBuilder, SC);
16141636
}
16151637

@@ -1627,13 +1649,14 @@ Register SPIRVGlobalRegistry::getOrCreateUndef(MachineInstr &I,
16271649
CurMF->getRegInfo().setRegClass(Res, &SPIRV::iIDRegClass);
16281650
assignSPIRVTypeToVReg(SpvType, Res, *CurMF);
16291651

1630-
MachineIRBuilder MIRBuilder(I);
1652+
MachineInstr *DepMI = const_cast<MachineInstr *>(SpvType);
1653+
MachineIRBuilder MIRBuilder(*DepMI->getParent(), DepMI->getIterator());
16311654
const MachineInstr *NewMI =
16321655
createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
1633-
auto MIB =
1634-
BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRV::OpUndef))
1635-
.addDef(Res)
1636-
.addUse(getSPIRVTypeID(SpvType));
1656+
auto MIB = BuildMI(MIRBuilder.getMBB(), *MIRBuilder.getInsertPt(),
1657+
MIRBuilder.getDL(), TII.get(SPIRV::OpUndef))
1658+
.addDef(Res)
1659+
.addUse(getSPIRVTypeID(SpvType));
16371660
const auto &ST = CurMF->getSubtarget();
16381661
constrainSelectedInstRegOperands(*MIB, *ST.getInstrInfo(),
16391662
*ST.getRegisterInfo(),

llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -473,9 +473,15 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
473473
Register getOrCreateConstInt(uint64_t Val, MachineInstr &I,
474474
SPIRVType *SpvType, const SPIRVInstrInfo &TII,
475475
bool ZeroAsNull = true);
476+
Register createConstInt(const ConstantInt *CI, MachineInstr &I,
477+
SPIRVType *SpvType, const SPIRVInstrInfo &TII,
478+
bool ZeroAsNull);
476479
Register getOrCreateConstFP(APFloat Val, MachineInstr &I, SPIRVType *SpvType,
477480
const SPIRVInstrInfo &TII,
478481
bool ZeroAsNull = true);
482+
Register createConstFP(const ConstantFP *CF, MachineInstr &I,
483+
SPIRVType *SpvType, const SPIRVInstrInfo &TII,
484+
bool ZeroAsNull);
479485
Register buildConstantFP(APFloat Val, MachineIRBuilder &MIRBuilder,
480486
SPIRVType *SpvType = nullptr);
481487

0 commit comments

Comments
 (0)