2525#include " llvm/IR/Type.h"
2626#include " llvm/Support/Casting.h"
2727#include < cassert>
28+ #include < functional>
2829
2930using namespace llvm ;
3031SPIRVGlobalRegistry::SPIRVGlobalRegistry (unsigned PointerSize)
@@ -83,8 +84,11 @@ inline Register createTypeVReg(MachineIRBuilder &MIRBuilder) {
8384}
8485
8586SPIRVType *SPIRVGlobalRegistry::getOpTypeBool (MachineIRBuilder &MIRBuilder) {
86- return MIRBuilder.buildInstr (SPIRV::OpTypeBool)
87- .addDef (createTypeVReg (MIRBuilder));
87+
88+ return createOpType (MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
89+ return MIRBuilder.buildInstr (SPIRV::OpTypeBool)
90+ .addDef (createTypeVReg (MIRBuilder));
91+ });
8892}
8993
9094unsigned SPIRVGlobalRegistry::adjustOpTypeIntWidth (unsigned Width) const {
@@ -118,24 +122,53 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeInt(unsigned Width,
118122 MIRBuilder.buildInstr (SPIRV::OpCapability)
119123 .addImm (SPIRV::Capability::ArbitraryPrecisionIntegersINTEL);
120124 }
121- auto MIB = MIRBuilder.buildInstr (SPIRV::OpTypeInt)
122- .addDef (createTypeVReg (MIRBuilder))
123- .addImm (Width)
124- .addImm (IsSigned ? 1 : 0 );
125- return MIB;
125+ return createOpType (MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
126+ return MIRBuilder.buildInstr (SPIRV::OpTypeInt)
127+ .addDef (createTypeVReg (MIRBuilder))
128+ .addImm (Width)
129+ .addImm (IsSigned ? 1 : 0 );
130+ });
126131}
127132
128133SPIRVType *SPIRVGlobalRegistry::getOpTypeFloat (uint32_t Width,
129134 MachineIRBuilder &MIRBuilder) {
130- auto MIB = MIRBuilder.buildInstr (SPIRV::OpTypeFloat)
131- .addDef (createTypeVReg (MIRBuilder))
132- .addImm (Width);
133- return MIB;
135+ return createOpType (MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
136+ return MIRBuilder.buildInstr (SPIRV::OpTypeFloat)
137+ .addDef (createTypeVReg (MIRBuilder))
138+ .addImm (Width);
139+ });
134140}
135141
136142SPIRVType *SPIRVGlobalRegistry::getOpTypeVoid (MachineIRBuilder &MIRBuilder) {
137- return MIRBuilder.buildInstr (SPIRV::OpTypeVoid)
138- .addDef (createTypeVReg (MIRBuilder));
143+ return createOpType (MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
144+ return MIRBuilder.buildInstr (SPIRV::OpTypeVoid)
145+ .addDef (createTypeVReg (MIRBuilder));
146+ });
147+ }
148+
149+ SPIRVType *SPIRVGlobalRegistry::createOpType (
150+ MachineIRBuilder &MIRBuilder,
151+ std::function<MachineInstr *(MachineIRBuilder &)> Op) {
152+ auto oldInsertPoint = MIRBuilder.getInsertPt ();
153+ MachineBasicBlock *OldMBB = &MIRBuilder.getMBB ();
154+
155+ auto LastInsertedType = LastInsertedTypeMap.find (CurMF);
156+ if (LastInsertedType != LastInsertedTypeMap.end ()) {
157+ MIRBuilder.setInsertPt (*MIRBuilder.getMF ().begin (),
158+ LastInsertedType->second ->getIterator ());
159+ } else {
160+ MIRBuilder.setInsertPt (*MIRBuilder.getMF ().begin (),
161+ MIRBuilder.getMF ().begin ()->begin ());
162+ auto Result = LastInsertedTypeMap.try_emplace (CurMF, nullptr );
163+ assert (Result.second );
164+ LastInsertedType = Result.first ;
165+ }
166+
167+ MachineInstr *Type = Op (MIRBuilder);
168+ LastInsertedType->second = Type;
169+
170+ MIRBuilder.setInsertPt (*OldMBB, oldInsertPoint);
171+ return Type;
139172}
140173
141174SPIRVType *SPIRVGlobalRegistry::getOpTypeVector (uint32_t NumElems,
@@ -147,11 +180,12 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeVector(uint32_t NumElems,
147180 EleOpc == SPIRV::OpTypeBool) &&
148181 " Invalid vector element type" );
149182
150- auto MIB = MIRBuilder.buildInstr (SPIRV::OpTypeVector)
151- .addDef (createTypeVReg (MIRBuilder))
152- .addUse (getSPIRVTypeID (ElemType))
153- .addImm (NumElems);
154- return MIB;
183+ return createOpType (MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
184+ return MIRBuilder.buildInstr (SPIRV::OpTypeVector)
185+ .addDef (createTypeVReg (MIRBuilder))
186+ .addUse (getSPIRVTypeID (ElemType))
187+ .addImm (NumElems);
188+ });
155189}
156190
157191std::tuple<Register, ConstantInt *, bool , unsigned >
@@ -688,22 +722,25 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeArray(uint32_t NumElems,
688722 SPIRVType *SpvTypeInt32 = getOrCreateSPIRVIntegerType (32 , MIRBuilder);
689723 Register NumElementsVReg =
690724 buildConstantInt (NumElems, MIRBuilder, SpvTypeInt32, EmitIR);
691- auto MIB = MIRBuilder.buildInstr (SPIRV::OpTypeArray)
692- .addDef (createTypeVReg (MIRBuilder))
693- .addUse (getSPIRVTypeID (ElemType))
694- .addUse (NumElementsVReg);
695- return MIB;
725+ return createOpType (MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
726+ return MIRBuilder.buildInstr (SPIRV::OpTypeArray)
727+ .addDef (createTypeVReg (MIRBuilder))
728+ .addUse (getSPIRVTypeID (ElemType))
729+ .addUse (NumElementsVReg);
730+ });
696731}
697732
698733SPIRVType *SPIRVGlobalRegistry::getOpTypeOpaque (const StructType *Ty,
699734 MachineIRBuilder &MIRBuilder) {
700735 assert (Ty->hasName ());
701736 const StringRef Name = Ty->hasName () ? Ty->getName () : " " ;
702737 Register ResVReg = createTypeVReg (MIRBuilder);
703- auto MIB = MIRBuilder.buildInstr (SPIRV::OpTypeOpaque).addDef (ResVReg);
704- addStringImm (Name, MIB);
705- buildOpName (ResVReg, Name, MIRBuilder);
706- return MIB;
738+ return createOpType (MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
739+ auto MIB = MIRBuilder.buildInstr (SPIRV::OpTypeOpaque).addDef (ResVReg);
740+ addStringImm (Name, MIB);
741+ buildOpName (ResVReg, Name, MIRBuilder);
742+ return MIB;
743+ });
707744}
708745
709746SPIRVType *SPIRVGlobalRegistry::getOpTypeStruct (const StructType *Ty,
@@ -717,14 +754,16 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeStruct(const StructType *Ty,
717754 FieldTypes.push_back (getSPIRVTypeID (ElemTy));
718755 }
719756 Register ResVReg = createTypeVReg (MIRBuilder);
720- auto MIB = MIRBuilder.buildInstr (SPIRV::OpTypeStruct).addDef (ResVReg);
721- for (const auto &Ty : FieldTypes)
722- MIB.addUse (Ty);
723- if (Ty->hasName ())
724- buildOpName (ResVReg, Ty->getName (), MIRBuilder);
725- if (Ty->isPacked ())
726- buildOpDecorate (ResVReg, MIRBuilder, SPIRV::Decoration::CPacked, {});
727- return MIB;
757+ return createOpType (MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
758+ auto MIB = MIRBuilder.buildInstr (SPIRV::OpTypeStruct).addDef (ResVReg);
759+ for (const auto &Ty : FieldTypes)
760+ MIB.addUse (Ty);
761+ if (Ty->hasName ())
762+ buildOpName (ResVReg, Ty->getName (), MIRBuilder);
763+ if (Ty->isPacked ())
764+ buildOpDecorate (ResVReg, MIRBuilder, SPIRV::Decoration::CPacked, {});
765+ return MIB;
766+ });
728767}
729768
730769SPIRVType *SPIRVGlobalRegistry::getOrCreateSpecialType (
@@ -739,17 +778,22 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypePointer(
739778 MachineIRBuilder &MIRBuilder, Register Reg) {
740779 if (!Reg.isValid ())
741780 Reg = createTypeVReg (MIRBuilder);
742- return MIRBuilder.buildInstr (SPIRV::OpTypePointer)
743- .addDef (Reg)
744- .addImm (static_cast <uint32_t >(SC))
745- .addUse (getSPIRVTypeID (ElemType));
781+
782+ return createOpType (MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
783+ return MIRBuilder.buildInstr (SPIRV::OpTypePointer)
784+ .addDef (Reg)
785+ .addImm (static_cast <uint32_t >(SC))
786+ .addUse (getSPIRVTypeID (ElemType));
787+ });
746788}
747789
748790SPIRVType *SPIRVGlobalRegistry::getOpTypeForwardPointer (
749791 SPIRV::StorageClass::StorageClass SC, MachineIRBuilder &MIRBuilder) {
750- return MIRBuilder.buildInstr (SPIRV::OpTypeForwardPointer)
751- .addUse (createTypeVReg (MIRBuilder))
752- .addImm (static_cast <uint32_t >(SC));
792+ return createOpType (MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
793+ return MIRBuilder.buildInstr (SPIRV::OpTypeForwardPointer)
794+ .addUse (createTypeVReg (MIRBuilder))
795+ .addImm (static_cast <uint32_t >(SC));
796+ });
753797}
754798
755799SPIRVType *SPIRVGlobalRegistry::getOpTypeFunction (
0 commit comments