Skip to content

Commit 2a290ad

Browse files
[Backport to 11] Support SPV_INTEL_maximum_registers extension (#2344) (#2404)
Co-authored-by: Viktoria Maximova <viktoria.maksimova@intel.com>
1 parent e29887d commit 2a290ad

File tree

14 files changed

+295
-52
lines changed

14 files changed

+295
-52
lines changed

include/LLVMSPIRVExtensions.inc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,3 +48,4 @@ EXT(SPV_INTEL_tensor_float32_rounding)
4848
EXT(SPV_INTEL_hw_thread_queries)
4949
EXT(SPV_EXT_relaxed_printf_string_address_space)
5050
EXT(SPV_INTEL_global_variable_decorations)
51+
EXT(SPV_INTEL_maximum_registers)

lib/SPIRV/SPIRVReader.cpp

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3776,6 +3776,48 @@ bool SPIRVToLLVM::transMetadata() {
37763776
F->setMetadata(kSPIR2MD::NumSIMD,
37773777
getMDNodeStringIntVec(Context, EM->getLiterals()));
37783778
}
3779+
if (auto *EM = BF->getExecutionMode(ExecutionModeMaximumRegistersINTEL)) {
3780+
NamedMDNode *ExecModeMD =
3781+
M->getOrInsertNamedMetadata(kSPIRVMD::ExecutionMode);
3782+
3783+
SmallVector<Metadata *, 4> ValueVec;
3784+
ValueVec.push_back(ConstantAsMetadata::get(F));
3785+
ValueVec.push_back(
3786+
ConstantAsMetadata::get(getUInt32(M, EM->getExecutionMode())));
3787+
ValueVec.push_back(
3788+
ConstantAsMetadata::get(getUInt32(M, EM->getLiterals()[0])));
3789+
ExecModeMD->addOperand(MDNode::get(*Context, ValueVec));
3790+
}
3791+
if (auto *EM = BF->getExecutionMode(ExecutionModeMaximumRegistersIdINTEL)) {
3792+
NamedMDNode *ExecModeMD =
3793+
M->getOrInsertNamedMetadata(kSPIRVMD::ExecutionMode);
3794+
3795+
SmallVector<Metadata *, 4> ValueVec;
3796+
ValueVec.push_back(ConstantAsMetadata::get(F));
3797+
ValueVec.push_back(
3798+
ConstantAsMetadata::get(getUInt32(M, EM->getExecutionMode())));
3799+
3800+
auto *ExecOp = BF->getModule()->getValue(EM->getLiterals()[0]);
3801+
ValueVec.push_back(
3802+
MDNode::get(*Context, ConstantAsMetadata::get(cast<ConstantInt>(
3803+
transValue(ExecOp, nullptr, nullptr)))));
3804+
ExecModeMD->addOperand(MDNode::get(*Context, ValueVec));
3805+
}
3806+
if (auto *EM =
3807+
BF->getExecutionMode(ExecutionModeNamedMaximumRegistersINTEL)) {
3808+
NamedMDNode *ExecModeMD =
3809+
M->getOrInsertNamedMetadata(kSPIRVMD::ExecutionMode);
3810+
3811+
SmallVector<Metadata *, 4> ValueVec;
3812+
ValueVec.push_back(ConstantAsMetadata::get(F));
3813+
ValueVec.push_back(
3814+
ConstantAsMetadata::get(getUInt32(M, EM->getExecutionMode())));
3815+
3816+
assert(EM->getLiterals()[0] == 0 &&
3817+
"Invalid named maximum number of registers");
3818+
ValueVec.push_back(MDString::get(*Context, "AutoINTEL"));
3819+
ExecModeMD->addOperand(MDNode::get(*Context, ValueVec));
3820+
}
37793821
}
37803822
NamedMDNode *MemoryModelMD =
37813823
M->getOrInsertNamedMetadata(kSPIRVMD::MemoryModel);

lib/SPIRV/SPIRVWriter.cpp

Lines changed: 64 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -686,6 +686,9 @@ SPIRVFunction *LLVMToSPIRVBase::transFunctionDecl(Function *F) {
686686
if (BM->isAllowedToUseExtension(ExtensionID::SPV_INTEL_vector_compute))
687687
transVectorComputeMetadata(F);
688688

689+
if (BM->isAllowedToUseExtension(ExtensionID::SPV_INTEL_maximum_registers))
690+
transFunctionMetadataAsExecutionMode(BF, F);
691+
689692
SPIRVDBG(dbgs() << "[transFunction] " << *F << " => ";
690693
spvdbgs() << *BF << '\n';)
691694
return BF;
@@ -763,6 +766,38 @@ void LLVMToSPIRVBase::transVectorComputeMetadata(Function *F) {
763766
}
764767
}
765768

769+
void LLVMToSPIRVBase::transFunctionMetadataAsExecutionMode(SPIRVFunction *BF,
770+
Function *F) {
771+
SmallVector<MDNode *, 1> RegisterAllocModeMDs;
772+
F->getMetadata("RegisterAllocMode", RegisterAllocModeMDs);
773+
774+
for (unsigned I = 0; I < RegisterAllocModeMDs.size(); I++) {
775+
auto *RegisterAllocMode = RegisterAllocModeMDs[I]->getOperand(0).get();
776+
if (isa<MDString>(RegisterAllocMode)) {
777+
const std::string Str = getMDOperandAsString(RegisterAllocModeMDs[I], 0);
778+
const NamedMaximumNumberOfRegisters NamedValue =
779+
SPIRVNamedMaximumNumberOfRegistersNameMap::rmap(Str);
780+
BF->addExecutionMode(BM->add(new SPIRVExecutionMode(
781+
OpExecutionMode, BF, ExecutionModeNamedMaximumRegistersINTEL,
782+
NamedValue)));
783+
} else if (isa<MDNode>(RegisterAllocMode)) {
784+
auto *RegisterAllocNodeMDOp =
785+
getMDOperandAsMDNode(RegisterAllocModeMDs[I], 0);
786+
const int Num = getMDOperandAsInt(RegisterAllocNodeMDOp, 0);
787+
auto *Const =
788+
BM->addConstant(transType(Type::getInt32Ty(F->getContext())), Num);
789+
BF->addExecutionMode(BM->add(new SPIRVExecutionModeId(
790+
BF, ExecutionModeMaximumRegistersIdINTEL, Const->getId())));
791+
} else {
792+
const int64_t RegisterAllocVal =
793+
mdconst::dyn_extract<ConstantInt>(RegisterAllocMode)->getZExtValue();
794+
BF->addExecutionMode(BM->add(new SPIRVExecutionMode(
795+
OpExecutionMode, BF, ExecutionModeMaximumRegistersINTEL,
796+
RegisterAllocVal)));
797+
}
798+
}
799+
}
800+
766801
SPIRVValue *LLVMToSPIRVBase::transConstant(Value *V) {
767802
if (auto CPNull = dyn_cast<ConstantPointerNull>(V))
768803
return BM->addNullConstant(
@@ -3857,14 +3892,14 @@ bool LLVMToSPIRVBase::transExecutionMode() {
38573892

38583893
switch (EMode) {
38593894
case spv::ExecutionModeContractionOff:
3860-
BF->addExecutionMode(BM->add(
3861-
new SPIRVExecutionMode(BF, static_cast<ExecutionMode>(EMode))));
3895+
BF->addExecutionMode(BM->add(new SPIRVExecutionMode(
3896+
OpExecutionMode, BF, static_cast<ExecutionMode>(EMode))));
38623897
break;
38633898
case spv::ExecutionModeInitializer:
38643899
case spv::ExecutionModeFinalizer:
38653900
if (BM->isAllowedToUseVersion(VersionNumber::SPIRV_1_1)) {
3866-
BF->addExecutionMode(BM->add(
3867-
new SPIRVExecutionMode(BF, static_cast<ExecutionMode>(EMode))));
3901+
BF->addExecutionMode(BM->add(new SPIRVExecutionMode(
3902+
OpExecutionMode, BF, static_cast<ExecutionMode>(EMode))));
38683903
} else {
38693904
getErrorLog().checkError(false, SPIRVEC_Requires1_1,
38703905
"Initializer/Finalizer Execution Mode");
@@ -3876,41 +3911,42 @@ bool LLVMToSPIRVBase::transExecutionMode() {
38763911
unsigned X, Y, Z;
38773912
N.get(X).get(Y).get(Z);
38783913
BF->addExecutionMode(BM->add(new SPIRVExecutionMode(
3879-
BF, static_cast<ExecutionMode>(EMode), X, Y, Z)));
3914+
OpExecutionMode, BF, static_cast<ExecutionMode>(EMode), X, Y, Z)));
38803915
} break;
38813916
case spv::ExecutionModeMaxWorkgroupSizeINTEL: {
38823917
if (BM->isAllowedToUseExtension(
38833918
ExtensionID::SPV_INTEL_kernel_attributes)) {
38843919
unsigned X, Y, Z;
38853920
N.get(X).get(Y).get(Z);
38863921
BF->addExecutionMode(BM->add(new SPIRVExecutionMode(
3887-
BF, static_cast<ExecutionMode>(EMode), X, Y, Z)));
3922+
OpExecutionMode, BF, static_cast<ExecutionMode>(EMode), X, Y,
3923+
Z)));
38883924
BM->addCapability(CapabilityKernelAttributesINTEL);
38893925
}
38903926
} break;
38913927
case spv::ExecutionModeNoGlobalOffsetINTEL: {
3892-
if (BM->isAllowedToUseExtension(
3893-
ExtensionID::SPV_INTEL_kernel_attributes)) {
3894-
BF->addExecutionMode(BM->add(
3895-
new SPIRVExecutionMode(BF, static_cast<ExecutionMode>(EMode))));
3896-
BM->addCapability(CapabilityKernelAttributesINTEL);
3897-
}
3928+
if (!BM->isAllowedToUseExtension(
3929+
ExtensionID::SPV_INTEL_kernel_attributes))
3930+
break;
3931+
BF->addExecutionMode(BM->add(new SPIRVExecutionMode(
3932+
OpExecutionMode, BF, static_cast<ExecutionMode>(EMode))));
3933+
BM->addCapability(CapabilityKernelAttributesINTEL);
38983934
} break;
38993935
case spv::ExecutionModeVecTypeHint:
39003936
case spv::ExecutionModeSubgroupSize:
39013937
case spv::ExecutionModeSubgroupsPerWorkgroup: {
39023938
unsigned X;
39033939
N.get(X);
3904-
BF->addExecutionMode(BM->add(
3905-
new SPIRVExecutionMode(BF, static_cast<ExecutionMode>(EMode), X)));
3940+
BF->addExecutionMode(BM->add(new SPIRVExecutionMode(
3941+
OpExecutionMode, BF, static_cast<ExecutionMode>(EMode), X)));
39063942
} break;
39073943
case spv::ExecutionModeNumSIMDWorkitemsINTEL: {
39083944
if (BM->isAllowedToUseExtension(
39093945
ExtensionID::SPV_INTEL_kernel_attributes)) {
39103946
unsigned X;
39113947
N.get(X);
39123948
BF->addExecutionMode(BM->add(new SPIRVExecutionMode(
3913-
BF, static_cast<ExecutionMode>(EMode), X)));
3949+
OpExecutionMode, BF, static_cast<ExecutionMode>(EMode), X)));
39143950
BM->addCapability(CapabilityFPGAKernelAttributesINTEL);
39153951
}
39163952
} break;
@@ -3920,7 +3956,7 @@ bool LLVMToSPIRVBase::transExecutionMode() {
39203956
unsigned X;
39213957
N.get(X);
39223958
BF->addExecutionMode(BM->add(new SPIRVExecutionMode(
3923-
BF, static_cast<ExecutionMode>(EMode), X)));
3959+
OpExecutionMode, BF, static_cast<ExecutionMode>(EMode), X)));
39243960
BM->addCapability(CapabilityFPGAKernelAttributesINTEL);
39253961
}
39263962
} break;
@@ -3930,15 +3966,16 @@ bool LLVMToSPIRVBase::transExecutionMode() {
39303966
unsigned SLMSize;
39313967
N.get(SLMSize);
39323968
BF->addExecutionMode(BM->add(new SPIRVExecutionMode(
3933-
BF, static_cast<ExecutionMode>(EMode), SLMSize)));
3969+
OpExecutionMode, BF, static_cast<ExecutionMode>(EMode), SLMSize)));
39343970
} break;
39353971
case spv::ExecutionModeNamedBarrierCountINTEL: {
39363972
if (!BM->isAllowedToUseExtension(ExtensionID::SPV_INTEL_vector_compute))
39373973
break;
39383974
unsigned NBarrierCnt = 0;
39393975
N.get(NBarrierCnt);
3940-
BF->addExecutionMode(new SPIRVExecutionMode(
3941-
BF, static_cast<ExecutionMode>(EMode), NBarrierCnt));
3976+
BF->addExecutionMode(BM->add(new SPIRVExecutionMode(
3977+
OpExecutionMode, BF, static_cast<ExecutionMode>(EMode),
3978+
NBarrierCnt)));
39423979
BM->addExtension(ExtensionID::SPV_INTEL_vector_compute);
39433980
BM->addCapability(CapabilityVectorComputeINTEL);
39443981
} break;
@@ -3953,7 +3990,8 @@ bool LLVMToSPIRVBase::transExecutionMode() {
39533990
unsigned TargetWidth;
39543991
N.get(TargetWidth);
39553992
BF->addExecutionMode(BM->add(new SPIRVExecutionMode(
3956-
BF, static_cast<ExecutionMode>(EMode), TargetWidth)));
3993+
OpExecutionMode, BF, static_cast<ExecutionMode>(EMode),
3994+
TargetWidth)));
39573995
} break;
39583996
case spv::ExecutionModeRoundingModeRTPINTEL:
39593997
case spv::ExecutionModeRoundingModeRTNINTEL:
@@ -3965,12 +4003,13 @@ bool LLVMToSPIRVBase::transExecutionMode() {
39654003
unsigned TargetWidth;
39664004
N.get(TargetWidth);
39674005
BF->addExecutionMode(BM->add(new SPIRVExecutionMode(
3968-
BF, static_cast<ExecutionMode>(EMode), TargetWidth)));
4006+
OpExecutionMode, BF, static_cast<ExecutionMode>(EMode),
4007+
TargetWidth)));
39694008
} break;
39704009
case spv::ExecutionModeFastCompositeKernelINTEL: {
39714010
if (BM->isAllowedToUseExtension(ExtensionID::SPV_INTEL_fast_composite))
3972-
BF->addExecutionMode(BM->add(
3973-
new SPIRVExecutionMode(BF, static_cast<ExecutionMode>(EMode))));
4011+
BF->addExecutionMode(BM->add(new SPIRVExecutionMode(
4012+
OpExecutionMode, BF, static_cast<ExecutionMode>(EMode))));
39744013
} break;
39754014
default:
39764015
llvm_unreachable("invalid execution mode");
@@ -4015,8 +4054,8 @@ void LLVMToSPIRVBase::transFPContract() {
40154054
}
40164055

40174056
if (DisableContraction) {
4018-
BF->addExecutionMode(BF->getModule()->add(
4019-
new SPIRVExecutionMode(BF, spv::ExecutionModeContractionOff)));
4057+
BF->addExecutionMode(BF->getModule()->add(new SPIRVExecutionMode(
4058+
OpExecutionMode, BF, spv::ExecutionModeContractionOff)));
40204059
}
40214060
}
40224061
}

lib/SPIRV/SPIRVWriter.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ class LLVMToSPIRVBase {
107107
SPIRVWord transFunctionControlMask(Function *);
108108
SPIRVFunction *transFunctionDecl(Function *F);
109109
void transVectorComputeMetadata(Function *F);
110+
void transFunctionMetadataAsExecutionMode(SPIRVFunction *BF, Function *F);
110111
bool transGlobalVariables();
111112

112113
Op transBoolOpCode(SPIRVValue *Opn, Op OC);

lib/SPIRV/libSPIRV/SPIRVEntry.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -540,7 +540,7 @@ SPIRVEntryPoint::SPIRVEntryPoint(SPIRVModule *TheModule,
540540
SPIRVExecutionModelKind TheExecModel,
541541
SPIRVId TheId, const std::string &TheName,
542542
std::vector<SPIRVId> Variables)
543-
: SPIRVAnnotation(TheModule->get<SPIRVFunction>(TheId),
543+
: SPIRVAnnotation(OpEntryPoint, TheModule->get<SPIRVFunction>(TheId),
544544
getSizeInWords(TheName) + Variables.size() + 3),
545545
ExecModel(TheExecModel), Name(TheName), Variables(Variables) {}
546546

@@ -560,7 +560,7 @@ void SPIRVExecutionMode::encode(spv_ostream &O) const {
560560

561561
void SPIRVExecutionMode::decode(std::istream &I) {
562562
getDecoder(I) >> Target >> ExecMode;
563-
switch (ExecMode) {
563+
switch (static_cast<uint32_t>(ExecMode)) {
564564
case ExecutionModeLocalSize:
565565
case ExecutionModeLocalSizeHint:
566566
case ExecutionModeMaxWorkgroupSizeINTEL:
@@ -583,6 +583,9 @@ void SPIRVExecutionMode::decode(std::istream &I) {
583583
case ExecutionModeSubgroupSize:
584584
case ExecutionModeMaxWorkDimINTEL:
585585
case ExecutionModeNumSIMDWorkitemsINTEL:
586+
case ExecutionModeMaximumRegistersINTEL:
587+
case ExecutionModeMaximumRegistersIdINTEL:
588+
case ExecutionModeNamedMaximumRegistersINTEL:
586589
WordLiterals.resize(1);
587590
break;
588591
default:
@@ -604,7 +607,8 @@ SPIRVForward *SPIRVAnnotationGeneric::getOrCreateTarget() const {
604607
}
605608

606609
SPIRVName::SPIRVName(const SPIRVEntry *TheTarget, const std::string &TheStr)
607-
: SPIRVAnnotation(TheTarget, getSizeInWords(TheStr) + 2), Str(TheStr) {}
610+
: SPIRVAnnotation(OpName, TheTarget, getSizeInWords(TheStr) + 2),
611+
Str(TheStr) {}
608612

609613
void SPIRVName::encode(spv_ostream &O) const { getEncoder(O) << Target << Str; }
610614

0 commit comments

Comments
 (0)