Skip to content

Commit e896cc8

Browse files
authored
[Backport to 9] Support SPV_INTEL_maximum_registers extension (#2344) (#2408)
1 parent 9caf18b commit e896cc8

File tree

14 files changed

+282
-41
lines changed

14 files changed

+282
-41
lines changed

include/LLVMSPIRVExtensions.inc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,3 +38,4 @@ EXT(SPV_INTEL_hw_thread_queries)
3838
EXT(SPV_EXT_relaxed_printf_string_address_space)
3939
EXT(SPV_INTEL_split_barrier)
4040
EXT(SPV_INTEL_global_variable_decorations)
41+
EXT(SPV_INTEL_maximum_registers)

lib/SPIRV/SPIRVReader.cpp

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2823,6 +2823,48 @@ bool SPIRVToLLVM::transMetadata() {
28232823
auto SizeMD = ConstantAsMetadata::get(getUInt32(M, EM->getLiterals()[0]));
28242824
F->setMetadata(kSPIR2MD::SubgroupSize, MDNode::get(*Context, SizeMD));
28252825
}
2826+
if (auto *EM = BF->getExecutionMode(ExecutionModeMaximumRegistersINTEL)) {
2827+
NamedMDNode *ExecModeMD =
2828+
M->getOrInsertNamedMetadata(kSPIRVMD::ExecutionMode);
2829+
2830+
SmallVector<Metadata *, 4> ValueVec;
2831+
ValueVec.push_back(ConstantAsMetadata::get(F));
2832+
ValueVec.push_back(
2833+
ConstantAsMetadata::get(getUInt32(M, EM->getExecutionMode())));
2834+
ValueVec.push_back(
2835+
ConstantAsMetadata::get(getUInt32(M, EM->getLiterals()[0])));
2836+
ExecModeMD->addOperand(MDNode::get(*Context, ValueVec));
2837+
}
2838+
if (auto *EM = BF->getExecutionMode(ExecutionModeMaximumRegistersIdINTEL)) {
2839+
NamedMDNode *ExecModeMD =
2840+
M->getOrInsertNamedMetadata(kSPIRVMD::ExecutionMode);
2841+
2842+
SmallVector<Metadata *, 4> ValueVec;
2843+
ValueVec.push_back(ConstantAsMetadata::get(F));
2844+
ValueVec.push_back(
2845+
ConstantAsMetadata::get(getUInt32(M, EM->getExecutionMode())));
2846+
2847+
auto *ExecOp = BF->getModule()->getValue(EM->getLiterals()[0]);
2848+
ValueVec.push_back(
2849+
MDNode::get(*Context, ConstantAsMetadata::get(cast<ConstantInt>(
2850+
transValue(ExecOp, nullptr, nullptr)))));
2851+
ExecModeMD->addOperand(MDNode::get(*Context, ValueVec));
2852+
}
2853+
if (auto *EM =
2854+
BF->getExecutionMode(ExecutionModeNamedMaximumRegistersINTEL)) {
2855+
NamedMDNode *ExecModeMD =
2856+
M->getOrInsertNamedMetadata(kSPIRVMD::ExecutionMode);
2857+
2858+
SmallVector<Metadata *, 4> ValueVec;
2859+
ValueVec.push_back(ConstantAsMetadata::get(F));
2860+
ValueVec.push_back(
2861+
ConstantAsMetadata::get(getUInt32(M, EM->getExecutionMode())));
2862+
2863+
assert(EM->getLiterals()[0] == 0 &&
2864+
"Invalid named maximum number of registers");
2865+
ValueVec.push_back(MDString::get(*Context, "AutoINTEL"));
2866+
ExecModeMD->addOperand(MDNode::get(*Context, ValueVec));
2867+
}
28262868
}
28272869
return true;
28282870
}

lib/SPIRV/SPIRVWriter.cpp

Lines changed: 51 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -593,6 +593,9 @@ SPIRVFunction *LLVMToSPIRV::transFunctionDecl(Function *F) {
593593
if (BM->isAllowedToUseExtension(ExtensionID::SPV_INTEL_vector_compute))
594594
transVectorComputeMetadata(F);
595595

596+
if (BM->isAllowedToUseExtension(ExtensionID::SPV_INTEL_maximum_registers))
597+
transFunctionMetadataAsExecutionMode(BF, F);
598+
596599
SPIRVDBG(dbgs() << "[transFunction] " << *F << " => ";
597600
spvdbgs() << *BF << '\n';)
598601
return BF;
@@ -669,6 +672,36 @@ void LLVMToSPIRV::transVectorComputeMetadata(Function *F) {
669672
}
670673
}
671674

675+
void LLVMToSPIRV::transFunctionMetadataAsExecutionMode(SPIRVFunction *BF,
676+
Function *F) {
677+
SmallVector<MDNode *, 1> RegisterAllocModeMDs;
678+
F->getMetadata("RegisterAllocMode", RegisterAllocModeMDs);
679+
680+
for (unsigned I = 0; I < RegisterAllocModeMDs.size(); I++) {
681+
auto *RegisterAllocMode = RegisterAllocModeMDs[I]->getOperand(0).get();
682+
if (isa<MDString>(RegisterAllocMode)) {
683+
const std::string Str = getMDOperandAsString(RegisterAllocModeMDs[I], 0);
684+
const NamedMaximumNumberOfRegisters NamedValue =
685+
SPIRVNamedMaximumNumberOfRegistersNameMap::rmap(Str);
686+
BF->addExecutionMode(BM->add(new SPIRVExecutionMode(
687+
OpExecutionMode, BF, ExecutionModeNamedMaximumRegistersINTEL,
688+
NamedValue)));
689+
} else if (isa<MDNode>(RegisterAllocMode)) {
690+
const int Num = getMDOperandAsInt(cast<MDNode>(RegisterAllocMode), 0);
691+
auto *Const =
692+
BM->addConstant(transType(Type::getInt32Ty(F->getContext())), Num);
693+
BF->addExecutionMode(BM->add(new SPIRVExecutionModeId(
694+
BF, ExecutionModeMaximumRegistersIdINTEL, Const->getId())));
695+
} else {
696+
const int64_t RegisterAllocVal =
697+
mdconst::dyn_extract<ConstantInt>(RegisterAllocMode)->getZExtValue();
698+
BF->addExecutionMode(BM->add(new SPIRVExecutionMode(
699+
OpExecutionMode, BF, ExecutionModeMaximumRegistersINTEL,
700+
RegisterAllocVal)));
701+
}
702+
}
703+
}
704+
672705
SPIRVValue *LLVMToSPIRV::transConstant(Value *V) {
673706
if (auto CPNull = dyn_cast<ConstantPointerNull>(V))
674707
return BM->addNullConstant(
@@ -2982,39 +3015,41 @@ bool LLVMToSPIRV::transExecutionMode() {
29823015
case spv::ExecutionModeContractionOff:
29833016
case spv::ExecutionModeInitializer:
29843017
case spv::ExecutionModeFinalizer:
2985-
BF->addExecutionMode(BM->add(
2986-
new SPIRVExecutionMode(BF, static_cast<ExecutionMode>(EMode))));
3018+
BF->addExecutionMode(BM->add(new SPIRVExecutionMode(
3019+
OpExecutionMode, BF, static_cast<ExecutionMode>(EMode))));
3020+
break;
29873021
break;
29883022
case spv::ExecutionModeLocalSize:
29893023
case spv::ExecutionModeLocalSizeHint: {
29903024
unsigned X, Y, Z;
29913025
N.get(X).get(Y).get(Z);
29923026
BF->addExecutionMode(BM->add(new SPIRVExecutionMode(
2993-
BF, static_cast<ExecutionMode>(EMode), X, Y, Z)));
3027+
OpExecutionMode, BF, static_cast<ExecutionMode>(EMode), X, Y, Z)));
29943028
} break;
29953029
case spv::ExecutionModeVecTypeHint:
29963030
case spv::ExecutionModeSubgroupSize:
29973031
case spv::ExecutionModeSubgroupsPerWorkgroup: {
29983032
unsigned X;
29993033
N.get(X);
3000-
BF->addExecutionMode(BM->add(
3001-
new SPIRVExecutionMode(BF, static_cast<ExecutionMode>(EMode), X)));
3034+
BF->addExecutionMode(BM->add(new SPIRVExecutionMode(
3035+
OpExecutionMode, BF, static_cast<ExecutionMode>(EMode), X)));
30023036
} break;
30033037
case spv::ExecutionModeSharedLocalMemorySizeINTEL: {
30043038
if (!BM->isAllowedToUseExtension(ExtensionID::SPV_INTEL_vector_compute))
30053039
break;
30063040
unsigned SLMSize;
30073041
N.get(SLMSize);
30083042
BF->addExecutionMode(BM->add(new SPIRVExecutionMode(
3009-
BF, static_cast<ExecutionMode>(EMode), SLMSize)));
3043+
OpExecutionMode, BF, static_cast<ExecutionMode>(EMode), SLMSize)));
30103044
} break;
30113045
case spv::ExecutionModeNamedBarrierCountINTEL: {
30123046
if (!BM->isAllowedToUseExtension(ExtensionID::SPV_INTEL_vector_compute))
30133047
break;
30143048
unsigned NBarrierCnt = 0;
30153049
N.get(NBarrierCnt);
3016-
BF->addExecutionMode(new SPIRVExecutionMode(
3017-
BF, static_cast<ExecutionMode>(EMode), NBarrierCnt));
3050+
BF->addExecutionMode(BM->add(new SPIRVExecutionMode(
3051+
OpExecutionMode, BF, static_cast<ExecutionMode>(EMode),
3052+
NBarrierCnt)));
30183053
BM->addExtension(ExtensionID::SPV_INTEL_vector_compute);
30193054
BM->addCapability(CapabilityVectorComputeINTEL);
30203055
} break;
@@ -3029,7 +3064,8 @@ bool LLVMToSPIRV::transExecutionMode() {
30293064
unsigned TargetWidth;
30303065
N.get(TargetWidth);
30313066
BF->addExecutionMode(BM->add(new SPIRVExecutionMode(
3032-
BF, static_cast<ExecutionMode>(EMode), TargetWidth)));
3067+
OpExecutionMode, BF, static_cast<ExecutionMode>(EMode),
3068+
TargetWidth)));
30333069
} break;
30343070
case spv::ExecutionModeRoundingModeRTPINTEL:
30353071
case spv::ExecutionModeRoundingModeRTNINTEL:
@@ -3041,12 +3077,13 @@ bool LLVMToSPIRV::transExecutionMode() {
30413077
unsigned TargetWidth;
30423078
N.get(TargetWidth);
30433079
BF->addExecutionMode(BM->add(new SPIRVExecutionMode(
3044-
BF, static_cast<ExecutionMode>(EMode), TargetWidth)));
3080+
OpExecutionMode, BF, static_cast<ExecutionMode>(EMode),
3081+
TargetWidth)));
30453082
} break;
30463083
case spv::ExecutionModeFastCompositeKernelINTEL: {
30473084
if (BM->isAllowedToUseExtension(ExtensionID::SPV_INTEL_fast_composite))
3048-
BF->addExecutionMode(BM->add(
3049-
new SPIRVExecutionMode(BF, static_cast<ExecutionMode>(EMode))));
3085+
BF->addExecutionMode(BM->add(new SPIRVExecutionMode(
3086+
OpExecutionMode, BF, static_cast<ExecutionMode>(EMode))));
30503087
} break;
30513088
default:
30523089
llvm_unreachable("invalid execution mode");
@@ -3091,8 +3128,8 @@ void LLVMToSPIRV::transFPContract() {
30913128
}
30923129

30933130
if (DisableContraction) {
3094-
BF->addExecutionMode(BF->getModule()->add(
3095-
new SPIRVExecutionMode(BF, spv::ExecutionModeContractionOff)));
3131+
BF->addExecutionMode(BF->getModule()->add(new SPIRVExecutionMode(
3132+
OpExecutionMode, BF, spv::ExecutionModeContractionOff)));
30963133
}
30973134
}
30983135
}

lib/SPIRV/SPIRVWriter.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ class LLVMToSPIRV : public ModulePass {
113113
SPIRVWord transFunctionControlMask(Function *);
114114
SPIRVFunction *transFunctionDecl(Function *F);
115115
void transVectorComputeMetadata(Function *F);
116+
void transFunctionMetadataAsExecutionMode(SPIRVFunction *BF, Function *F);
116117
bool transGlobalVariables();
117118

118119
Op transBoolOpCode(SPIRVValue *Opn, Op OC);

lib/SPIRV/libSPIRV/SPIRVEntry.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -518,7 +518,7 @@ SPIRVEntryPoint::SPIRVEntryPoint(SPIRVModule *TheModule,
518518
SPIRVExecutionModelKind TheExecModel,
519519
SPIRVId TheId, const std::string &TheName,
520520
std::vector<SPIRVId> Variables)
521-
: SPIRVAnnotation(TheModule->get<SPIRVFunction>(TheId),
521+
: SPIRVAnnotation(OpEntryPoint, TheModule->get<SPIRVFunction>(TheId),
522522
getSizeInWords(TheName) + Variables.size() + 3),
523523
ExecModel(TheExecModel), Name(TheName), Variables(Variables) {}
524524

@@ -538,7 +538,7 @@ void SPIRVExecutionMode::encode(spv_ostream &O) const {
538538

539539
void SPIRVExecutionMode::decode(std::istream &I) {
540540
getDecoder(I) >> Target >> ExecMode;
541-
switch (ExecMode) {
541+
switch (static_cast<uint32_t>(ExecMode)) {
542542
case ExecutionModeLocalSize:
543543
case ExecutionModeLocalSizeHint:
544544
WordLiterals.resize(3);
@@ -558,6 +558,9 @@ void SPIRVExecutionMode::decode(std::istream &I) {
558558
case ExecutionModeSharedLocalMemorySizeINTEL:
559559
case ExecutionModeNamedBarrierCountINTEL:
560560
case ExecutionModeSubgroupSize:
561+
case ExecutionModeMaximumRegistersINTEL:
562+
case ExecutionModeMaximumRegistersIdINTEL:
563+
case ExecutionModeNamedMaximumRegistersINTEL:
561564
WordLiterals.resize(1);
562565
break;
563566
default:
@@ -579,7 +582,8 @@ SPIRVForward *SPIRVAnnotationGeneric::getOrCreateTarget() const {
579582
}
580583

581584
SPIRVName::SPIRVName(const SPIRVEntry *TheTarget, const std::string &TheStr)
582-
: SPIRVAnnotation(TheTarget, getSizeInWords(TheStr) + 2), Str(TheStr) {}
585+
: SPIRVAnnotation(OpName, TheTarget, getSizeInWords(TheStr) + 2),
586+
Str(TheStr) {}
583587

584588
void SPIRVName::encode(spv_ostream &O) const { getEncoder(O) << Target << Str; }
585589

0 commit comments

Comments
 (0)