Skip to content

Commit 6f80df1

Browse files
authored
[Backport to 10] Support SPV_INTEL_maximum_registers extension (#2344) (#2405)
1 parent fe87ee7 commit 6f80df1

File tree

14 files changed

+288
-44
lines changed

14 files changed

+288
-44
lines changed

include/LLVMSPIRVExtensions.inc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,3 +40,4 @@ EXT(SPV_EXT_relaxed_printf_string_address_space)
4040
EXT(SPV_INTEL_hw_thread_queries)
4141
EXT(SPV_INTEL_split_barrier)
4242
EXT(SPV_INTEL_global_variable_decorations)
43+
EXT(SPV_INTEL_maximum_registers)

lib/SPIRV/SPIRVReader.cpp

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3203,6 +3203,48 @@ bool SPIRVToLLVM::transMetadata() {
32033203
F->setMetadata(kSPIR2MD::NumSIMD,
32043204
getMDNodeStringIntVec(Context, EM->getLiterals()));
32053205
}
3206+
if (auto *EM = BF->getExecutionMode(ExecutionModeMaximumRegistersINTEL)) {
3207+
NamedMDNode *ExecModeMD =
3208+
M->getOrInsertNamedMetadata(kSPIRVMD::ExecutionMode);
3209+
3210+
SmallVector<Metadata *, 4> ValueVec;
3211+
ValueVec.push_back(ConstantAsMetadata::get(F));
3212+
ValueVec.push_back(
3213+
ConstantAsMetadata::get(getUInt32(M, EM->getExecutionMode())));
3214+
ValueVec.push_back(
3215+
ConstantAsMetadata::get(getUInt32(M, EM->getLiterals()[0])));
3216+
ExecModeMD->addOperand(MDNode::get(*Context, ValueVec));
3217+
}
3218+
if (auto *EM = BF->getExecutionMode(ExecutionModeMaximumRegistersIdINTEL)) {
3219+
NamedMDNode *ExecModeMD =
3220+
M->getOrInsertNamedMetadata(kSPIRVMD::ExecutionMode);
3221+
3222+
SmallVector<Metadata *, 4> ValueVec;
3223+
ValueVec.push_back(ConstantAsMetadata::get(F));
3224+
ValueVec.push_back(
3225+
ConstantAsMetadata::get(getUInt32(M, EM->getExecutionMode())));
3226+
3227+
auto *ExecOp = BF->getModule()->getValue(EM->getLiterals()[0]);
3228+
ValueVec.push_back(
3229+
MDNode::get(*Context, ConstantAsMetadata::get(cast<ConstantInt>(
3230+
transValue(ExecOp, nullptr, nullptr)))));
3231+
ExecModeMD->addOperand(MDNode::get(*Context, ValueVec));
3232+
}
3233+
if (auto *EM =
3234+
BF->getExecutionMode(ExecutionModeNamedMaximumRegistersINTEL)) {
3235+
NamedMDNode *ExecModeMD =
3236+
M->getOrInsertNamedMetadata(kSPIRVMD::ExecutionMode);
3237+
3238+
SmallVector<Metadata *, 4> ValueVec;
3239+
ValueVec.push_back(ConstantAsMetadata::get(F));
3240+
ValueVec.push_back(
3241+
ConstantAsMetadata::get(getUInt32(M, EM->getExecutionMode())));
3242+
3243+
assert(EM->getLiterals()[0] == 0 &&
3244+
"Invalid named maximum number of registers");
3245+
ValueVec.push_back(MDString::get(*Context, "AutoINTEL"));
3246+
ExecModeMD->addOperand(MDNode::get(*Context, ValueVec));
3247+
}
32063248
}
32073249
return true;
32083250
}

lib/SPIRV/SPIRVWriter.cpp

Lines changed: 57 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -632,6 +632,9 @@ SPIRVFunction *LLVMToSPIRV::transFunctionDecl(Function *F) {
632632
if (BM->isAllowedToUseExtension(ExtensionID::SPV_INTEL_vector_compute))
633633
transVectorComputeMetadata(F);
634634

635+
if (BM->isAllowedToUseExtension(ExtensionID::SPV_INTEL_maximum_registers))
636+
transFunctionMetadataAsExecutionMode(BF, F);
637+
635638
SPIRVDBG(dbgs() << "[transFunction] " << *F << " => ";
636639
spvdbgs() << *BF << '\n';)
637640
return BF;
@@ -708,6 +711,38 @@ void LLVMToSPIRV::transVectorComputeMetadata(Function *F) {
708711
}
709712
}
710713

714+
void LLVMToSPIRV::transFunctionMetadataAsExecutionMode(SPIRVFunction *BF,
715+
Function *F) {
716+
SmallVector<MDNode *, 1> RegisterAllocModeMDs;
717+
F->getMetadata("RegisterAllocMode", RegisterAllocModeMDs);
718+
719+
for (unsigned I = 0; I < RegisterAllocModeMDs.size(); I++) {
720+
auto *RegisterAllocMode = RegisterAllocModeMDs[I]->getOperand(0).get();
721+
if (isa<MDString>(RegisterAllocMode)) {
722+
const std::string Str = getMDOperandAsString(RegisterAllocModeMDs[I], 0);
723+
const NamedMaximumNumberOfRegisters NamedValue =
724+
SPIRVNamedMaximumNumberOfRegistersNameMap::rmap(Str);
725+
BF->addExecutionMode(BM->add(new SPIRVExecutionMode(
726+
OpExecutionMode, BF, ExecutionModeNamedMaximumRegistersINTEL,
727+
NamedValue)));
728+
} else if (isa<MDNode>(RegisterAllocMode)) {
729+
auto *RegisterAllocNodeMDOp =
730+
getMDOperandAsMDNode(RegisterAllocModeMDs[I], 0);
731+
const int Num = getMDOperandAsInt(RegisterAllocNodeMDOp, 0);
732+
auto *Const =
733+
BM->addConstant(transType(Type::getInt32Ty(F->getContext())), Num);
734+
BF->addExecutionMode(BM->add(new SPIRVExecutionModeId(
735+
BF, ExecutionModeMaximumRegistersIdINTEL, Const->getId())));
736+
} else {
737+
const int64_t RegisterAllocVal =
738+
mdconst::dyn_extract<ConstantInt>(RegisterAllocMode)->getZExtValue();
739+
BF->addExecutionMode(BM->add(new SPIRVExecutionMode(
740+
OpExecutionMode, BF, ExecutionModeMaximumRegistersINTEL,
741+
RegisterAllocVal)));
742+
}
743+
}
744+
}
745+
711746
SPIRVValue *LLVMToSPIRV::transConstant(Value *V) {
712747
if (auto CPNull = dyn_cast<ConstantPointerNull>(V))
713748
return BM->addNullConstant(
@@ -3254,23 +3289,25 @@ bool LLVMToSPIRV::transExecutionMode() {
32543289
case spv::ExecutionModeContractionOff:
32553290
case spv::ExecutionModeInitializer:
32563291
case spv::ExecutionModeFinalizer:
3257-
BF->addExecutionMode(BM->add(
3258-
new SPIRVExecutionMode(BF, static_cast<ExecutionMode>(EMode))));
3292+
BF->addExecutionMode(BM->add(new SPIRVExecutionMode(
3293+
OpExecutionMode, BF, static_cast<ExecutionMode>(EMode))));
3294+
break;
32593295
break;
32603296
case spv::ExecutionModeLocalSize:
32613297
case spv::ExecutionModeLocalSizeHint: {
32623298
unsigned X, Y, Z;
32633299
N.get(X).get(Y).get(Z);
32643300
BF->addExecutionMode(BM->add(new SPIRVExecutionMode(
3265-
BF, static_cast<ExecutionMode>(EMode), X, Y, Z)));
3301+
OpExecutionMode, BF, static_cast<ExecutionMode>(EMode), X, Y, Z)));
32663302
} break;
32673303
case spv::ExecutionModeMaxWorkgroupSizeINTEL: {
32683304
if (BM->isAllowedToUseExtension(
32693305
ExtensionID::SPV_INTEL_kernel_attributes)) {
32703306
unsigned X, Y, Z;
32713307
N.get(X).get(Y).get(Z);
32723308
BF->addExecutionMode(BM->add(new SPIRVExecutionMode(
3273-
BF, static_cast<ExecutionMode>(EMode), X, Y, Z)));
3309+
OpExecutionMode, BF, static_cast<ExecutionMode>(EMode), X, Y,
3310+
Z)));
32743311
BM->addCapability(CapabilityKernelAttributesINTEL);
32753312
}
32763313
} break;
@@ -3279,16 +3316,16 @@ bool LLVMToSPIRV::transExecutionMode() {
32793316
case spv::ExecutionModeSubgroupsPerWorkgroup: {
32803317
unsigned X;
32813318
N.get(X);
3282-
BF->addExecutionMode(BM->add(
3283-
new SPIRVExecutionMode(BF, static_cast<ExecutionMode>(EMode), X)));
3319+
BF->addExecutionMode(BM->add(new SPIRVExecutionMode(
3320+
OpExecutionMode, BF, static_cast<ExecutionMode>(EMode), X)));
32843321
} break;
32853322
case spv::ExecutionModeNumSIMDWorkitemsINTEL: {
32863323
if (BM->isAllowedToUseExtension(
32873324
ExtensionID::SPV_INTEL_kernel_attributes)) {
32883325
unsigned X;
32893326
N.get(X);
32903327
BF->addExecutionMode(BM->add(new SPIRVExecutionMode(
3291-
BF, static_cast<ExecutionMode>(EMode), X)));
3328+
OpExecutionMode, BF, static_cast<ExecutionMode>(EMode), X)));
32923329
BM->addCapability(CapabilityFPGAKernelAttributesINTEL);
32933330
}
32943331
} break;
@@ -3298,7 +3335,7 @@ bool LLVMToSPIRV::transExecutionMode() {
32983335
unsigned X;
32993336
N.get(X);
33003337
BF->addExecutionMode(BM->add(new SPIRVExecutionMode(
3301-
BF, static_cast<ExecutionMode>(EMode), X)));
3338+
OpExecutionMode, BF, static_cast<ExecutionMode>(EMode), X)));
33023339
BM->addCapability(CapabilityFPGAKernelAttributesINTEL);
33033340
}
33043341
} break;
@@ -3308,15 +3345,16 @@ bool LLVMToSPIRV::transExecutionMode() {
33083345
unsigned SLMSize;
33093346
N.get(SLMSize);
33103347
BF->addExecutionMode(BM->add(new SPIRVExecutionMode(
3311-
BF, static_cast<ExecutionMode>(EMode), SLMSize)));
3348+
OpExecutionMode, BF, static_cast<ExecutionMode>(EMode), SLMSize)));
33123349
} break;
33133350
case spv::ExecutionModeNamedBarrierCountINTEL: {
33143351
if (!BM->isAllowedToUseExtension(ExtensionID::SPV_INTEL_vector_compute))
33153352
break;
33163353
unsigned NBarrierCnt = 0;
33173354
N.get(NBarrierCnt);
3318-
BF->addExecutionMode(new SPIRVExecutionMode(
3319-
BF, static_cast<ExecutionMode>(EMode), NBarrierCnt));
3355+
BF->addExecutionMode(BM->add(new SPIRVExecutionMode(
3356+
OpExecutionMode, BF, static_cast<ExecutionMode>(EMode),
3357+
NBarrierCnt)));
33203358
BM->addExtension(ExtensionID::SPV_INTEL_vector_compute);
33213359
BM->addCapability(CapabilityVectorComputeINTEL);
33223360
} break;
@@ -3331,7 +3369,8 @@ bool LLVMToSPIRV::transExecutionMode() {
33313369
unsigned TargetWidth;
33323370
N.get(TargetWidth);
33333371
BF->addExecutionMode(BM->add(new SPIRVExecutionMode(
3334-
BF, static_cast<ExecutionMode>(EMode), TargetWidth)));
3372+
OpExecutionMode, BF, static_cast<ExecutionMode>(EMode),
3373+
TargetWidth)));
33353374
} break;
33363375
case spv::ExecutionModeRoundingModeRTPINTEL:
33373376
case spv::ExecutionModeRoundingModeRTNINTEL:
@@ -3343,12 +3382,13 @@ bool LLVMToSPIRV::transExecutionMode() {
33433382
unsigned TargetWidth;
33443383
N.get(TargetWidth);
33453384
BF->addExecutionMode(BM->add(new SPIRVExecutionMode(
3346-
BF, static_cast<ExecutionMode>(EMode), TargetWidth)));
3385+
OpExecutionMode, BF, static_cast<ExecutionMode>(EMode),
3386+
TargetWidth)));
33473387
} break;
33483388
case spv::ExecutionModeFastCompositeKernelINTEL: {
33493389
if (BM->isAllowedToUseExtension(ExtensionID::SPV_INTEL_fast_composite))
3350-
BF->addExecutionMode(BM->add(
3351-
new SPIRVExecutionMode(BF, static_cast<ExecutionMode>(EMode))));
3390+
BF->addExecutionMode(BM->add(new SPIRVExecutionMode(
3391+
OpExecutionMode, BF, static_cast<ExecutionMode>(EMode))));
33523392
} break;
33533393
default:
33543394
llvm_unreachable("invalid execution mode");
@@ -3393,8 +3433,8 @@ void LLVMToSPIRV::transFPContract() {
33933433
}
33943434

33953435
if (DisableContraction) {
3396-
BF->addExecutionMode(BF->getModule()->add(
3397-
new SPIRVExecutionMode(BF, spv::ExecutionModeContractionOff)));
3436+
BF->addExecutionMode(BF->getModule()->add(new SPIRVExecutionMode(
3437+
OpExecutionMode, BF, spv::ExecutionModeContractionOff)));
33983438
}
33993439
}
34003440
}

lib/SPIRV/SPIRVWriter.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ class LLVMToSPIRV : public ModulePass {
113113
SPIRVWord transFunctionControlMask(Function *);
114114
SPIRVFunction *transFunctionDecl(Function *F);
115115
void transVectorComputeMetadata(Function *F);
116+
void transFunctionMetadataAsExecutionMode(SPIRVFunction *BF, Function *F);
116117
bool transGlobalVariables();
117118

118119
Op transBoolOpCode(SPIRVValue *Opn, Op OC);

lib/SPIRV/libSPIRV/SPIRVEntry.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -540,7 +540,7 @@ SPIRVEntryPoint::SPIRVEntryPoint(SPIRVModule *TheModule,
540540
SPIRVExecutionModelKind TheExecModel,
541541
SPIRVId TheId, const std::string &TheName,
542542
std::vector<SPIRVId> Variables)
543-
: SPIRVAnnotation(TheModule->get<SPIRVFunction>(TheId),
543+
: SPIRVAnnotation(OpEntryPoint, TheModule->get<SPIRVFunction>(TheId),
544544
getSizeInWords(TheName) + Variables.size() + 3),
545545
ExecModel(TheExecModel), Name(TheName), Variables(Variables) {}
546546

@@ -560,7 +560,7 @@ void SPIRVExecutionMode::encode(spv_ostream &O) const {
560560

561561
void SPIRVExecutionMode::decode(std::istream &I) {
562562
getDecoder(I) >> Target >> ExecMode;
563-
switch (ExecMode) {
563+
switch (static_cast<uint32_t>(ExecMode)) {
564564
case ExecutionModeLocalSize:
565565
case ExecutionModeLocalSizeHint:
566566
case ExecutionModeMaxWorkgroupSizeINTEL:
@@ -583,6 +583,9 @@ void SPIRVExecutionMode::decode(std::istream &I) {
583583
case ExecutionModeSubgroupSize:
584584
case ExecutionModeMaxWorkDimINTEL:
585585
case ExecutionModeNumSIMDWorkitemsINTEL:
586+
case ExecutionModeMaximumRegistersINTEL:
587+
case ExecutionModeMaximumRegistersIdINTEL:
588+
case ExecutionModeNamedMaximumRegistersINTEL:
586589
WordLiterals.resize(1);
587590
break;
588591
default:
@@ -604,7 +607,8 @@ SPIRVForward *SPIRVAnnotationGeneric::getOrCreateTarget() const {
604607
}
605608

606609
SPIRVName::SPIRVName(const SPIRVEntry *TheTarget, const std::string &TheStr)
607-
: SPIRVAnnotation(TheTarget, getSizeInWords(TheStr) + 2), Str(TheStr) {}
610+
: SPIRVAnnotation(OpName, TheTarget, getSizeInWords(TheStr) + 2),
611+
Str(TheStr) {}
608612

609613
void SPIRVName::encode(spv_ostream &O) const { getEncoder(O) << Target << Str; }
610614

0 commit comments

Comments
 (0)