-
Notifications
You must be signed in to change notification settings - Fork 15.3k
[SPIR-V] Implement SPV_KHR_float_controls2 #146941
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 19 commits
cbe7a07
4c325ac
a664820
0648840
34c0628
1d87705
de07e83
3edf82a
c0ef106
97720d5
3792464
eccd27a
868c6a4
21c17d1
cfdc210
9b8b43d
a353295
f1403d6
37a3abe
d76db3e
00548e5
f2aefbf
6a77fcd
848611a
527311d
bde0292
2295de4
6420f43
0bf9ca6
5f26101
368793f
cb1065e
d337ef2
c19d296
c9d226c
b5ce97a
0cd1923
c973c1a
f684280
931431b
f2af0d4
910cd69
b487250
9768733
fe49f8b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -80,6 +80,7 @@ class SPIRVAsmPrinter : public AsmPrinter { | |
| void outputExecutionMode(const Module &M); | ||
| void outputAnnotations(const Module &M); | ||
| void outputModuleSections(); | ||
| void outputFPFastMathDefaultInfo(); | ||
| bool isHidden() { | ||
| return MF->getFunction() | ||
| .getFnAttribute(SPIRV_BACKEND_SERVICE_FUN_NAME) | ||
|
|
@@ -497,11 +498,27 @@ void SPIRVAsmPrinter::outputExecutionMode(const Module &M) { | |
| NamedMDNode *Node = M.getNamedMetadata("spirv.ExecutionMode"); | ||
| if (Node) { | ||
| for (unsigned i = 0; i < Node->getNumOperands(); i++) { | ||
| // If SPV_KHR_float_controls2 is enabled and we find any of | ||
| // FPFastMathDefault, ContractionOff or SignedZeroInfNanPreserve execution | ||
| // modes, skip it, it'll be done somewhere else. | ||
| if (ST->canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2)) { | ||
| const auto EM = | ||
| cast<ConstantInt>( | ||
| cast<ConstantAsMetadata>((Node->getOperand(i))->getOperand(1)) | ||
| ->getValue()) | ||
| ->getZExtValue(); | ||
| if (EM == SPIRV::ExecutionMode::FPFastMathDefault || | ||
| EM == SPIRV::ExecutionMode::ContractionOff || | ||
| EM == SPIRV::ExecutionMode::SignedZeroInfNanPreserve) | ||
| continue; | ||
| } | ||
|
|
||
| MCInst Inst; | ||
| Inst.setOpcode(SPIRV::OpExecutionMode); | ||
| addOpsFromMDNode(cast<MDNode>(Node->getOperand(i)), Inst, MAI); | ||
| outputMCInst(Inst); | ||
| } | ||
| outputFPFastMathDefaultInfo(); | ||
| } | ||
| for (auto FI = M.begin(), E = M.end(); FI != E; ++FI) { | ||
| const Function &F = *FI; | ||
|
|
@@ -551,12 +568,53 @@ void SPIRVAsmPrinter::outputExecutionMode(const Module &M) { | |
| } | ||
| if (ST->isKernel() && !M.getNamedMetadata("spirv.ExecutionMode") && | ||
| !M.getNamedMetadata("opencl.enable.FP_CONTRACT")) { | ||
| MCInst Inst; | ||
| Inst.setOpcode(SPIRV::OpExecutionMode); | ||
| Inst.addOperand(MCOperand::createReg(FReg)); | ||
| unsigned EM = static_cast<unsigned>(SPIRV::ExecutionMode::ContractionOff); | ||
| Inst.addOperand(MCOperand::createImm(EM)); | ||
| outputMCInst(Inst); | ||
| if (ST->canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2)) { | ||
| // When SPV_KHR_float_controls2 is enabled, ContractionOff is | ||
| // deprecated. We need to use FPFastMathDefault with the appropriate | ||
| // flags instead. Since FPFastMathDefault takes a target type, we need | ||
| // to emit it for each floating-point type that exists in the module | ||
| // to match the effect of ContractionOff. As of now, there are 4 FP | ||
| // types: fp16, fp32 and fp64. | ||
| for (const MachineInstr *MI : | ||
| MAI->getMSInstrs(SPIRV::MB_TypeConstVars)) { | ||
| // Skip if the instruction is not OpTypeFloat. | ||
| if (MI->getOpcode() != SPIRV::OpTypeFloat) | ||
|
||
| continue; | ||
|
|
||
| // Skip if the target type is not fp16, fp32, fp64. | ||
| const unsigned OpTypeFloatSize = MI->getOperand(1).getImm(); | ||
| if (OpTypeFloatSize != 16 && OpTypeFloatSize != 32 && | ||
| OpTypeFloatSize != 64) { | ||
| continue; | ||
| } | ||
|
|
||
| MCInst Inst; | ||
| Inst.setOpcode(SPIRV::OpExecutionMode); | ||
Keenuts marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| Inst.addOperand(MCOperand::createReg(FReg)); | ||
| unsigned EM = | ||
| static_cast<unsigned>(SPIRV::ExecutionMode::FPFastMathDefault); | ||
| Inst.addOperand(MCOperand::createImm(EM)); | ||
| const MachineFunction *MF = MI->getMF(); | ||
| MCRegister TypeReg = | ||
| MAI->getRegisterAlias(MF, MI->getOperand(0).getReg()); | ||
| Inst.addOperand(MCOperand::createReg(TypeReg)); | ||
| // We only end up here because there is no "spirv.ExecutionMode" | ||
| // metadata, so that means no FPFastMathDefault. Therefore, we only | ||
| // need to make sure AllowContract is set to 0, as the rest of flags. | ||
| // We still need to emit the OpExecutionMode instruction, otherwise | ||
| // it's up to the client API to define the flags. | ||
| Inst.addOperand(MCOperand::createImm(SPIRV::FPFastMathMode::None)); | ||
| outputMCInst(Inst); | ||
| } | ||
| } else { | ||
| MCInst Inst; | ||
| Inst.setOpcode(SPIRV::OpExecutionMode); | ||
| Inst.addOperand(MCOperand::createReg(FReg)); | ||
| unsigned EM = | ||
| static_cast<unsigned>(SPIRV::ExecutionMode::ContractionOff); | ||
| Inst.addOperand(MCOperand::createImm(EM)); | ||
| outputMCInst(Inst); | ||
| } | ||
| } | ||
| } | ||
| } | ||
|
|
@@ -603,6 +661,76 @@ void SPIRVAsmPrinter::outputAnnotations(const Module &M) { | |
| } | ||
| } | ||
|
|
||
| void SPIRVAsmPrinter::outputFPFastMathDefaultInfo() { | ||
| // Collect the SPIRVTypes that are OpTypeFloat. | ||
| std::vector<const MachineInstr *> SPIRVFloatTypes; | ||
| for (const MachineInstr *MI : MAI->getMSInstrs(SPIRV::MB_TypeConstVars)) { | ||
| // Skip if the instruction is not OpTypeFloat. | ||
| if (MI->getOpcode() != SPIRV::OpTypeFloat) | ||
| continue; | ||
|
|
||
| // Collect the SPIRV type. | ||
| SPIRVFloatTypes.push_back(MI); | ||
| } | ||
|
|
||
| for (const auto &[Func, FPFastMathDefaultInfoVec] : | ||
| MAI->FPFastMathDefaultInfoMap) { | ||
| if (FPFastMathDefaultInfoVec.empty()) | ||
| continue; | ||
|
|
||
| for (const MachineInstr *MI : SPIRVFloatTypes) { | ||
| unsigned OpTypeFloatSize = MI->getOperand(1).getImm(); | ||
| unsigned Index = computeFPFastMathDefaultInfoVecIndex(OpTypeFloatSize); | ||
| assert(Index < FPFastMathDefaultInfoVec.size() && | ||
| "Index out of bounds for FPFastMathDefaultInfoVec"); | ||
| const auto &FPFastMathDefaultInfo = FPFastMathDefaultInfoVec[Index]; | ||
| assert(FPFastMathDefaultInfo.Ty && | ||
| "Expected target type for FPFastMathDefaultInfo"); | ||
| assert(FPFastMathDefaultInfo.Ty->getScalarSizeInBits() == | ||
| OpTypeFloatSize && | ||
| "Mismatched float type size"); | ||
| MCInst Inst; | ||
| Inst.setOpcode(SPIRV::OpExecutionMode); | ||
Keenuts marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| MCRegister FuncReg = MAI->getFuncReg(Func); | ||
| assert(FuncReg.isValid()); | ||
| Inst.addOperand(MCOperand::createReg(FuncReg)); | ||
| Inst.addOperand( | ||
| MCOperand::createImm(SPIRV::ExecutionMode::FPFastMathDefault)); | ||
| MCRegister TypeReg = | ||
| MAI->getRegisterAlias(MI->getMF(), MI->getOperand(0).getReg()); | ||
| Inst.addOperand(MCOperand::createReg(TypeReg)); | ||
| unsigned Flags = FPFastMathDefaultInfo.FastMathFlags; | ||
| if (FPFastMathDefaultInfo.ContractionOff && | ||
| (Flags & SPIRV::FPFastMathMode::AllowContract) && | ||
| FPFastMathDefaultInfo.FPFastMathDefault) | ||
Keenuts marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| report_fatal_error( | ||
| "Conflicting FPFastMathFlags: ContractionOff and AllowContract"); | ||
|
|
||
| if (FPFastMathDefaultInfo.SignedZeroInfNanPreserve && | ||
| !(Flags & | ||
| (SPIRV::FPFastMathMode::NotNaN | SPIRV::FPFastMathMode::NotInf | | ||
| SPIRV::FPFastMathMode::NSZ))) { | ||
| if (FPFastMathDefaultInfo.FPFastMathDefault) | ||
| report_fatal_error("Conflicting FPFastMathFlags: " | ||
Keenuts marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| "SignedZeroInfNanPreserve but at least one of " | ||
| "NotNaN/NotInf/NSZ is disabled."); | ||
|
|
||
| Flags |= SPIRV::FPFastMathMode::NotNaN | SPIRV::FPFastMathMode::NotInf | | ||
| SPIRV::FPFastMathMode::NSZ; | ||
Keenuts marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| } | ||
|
|
||
| // Don't emit if none of the execution modes was used. | ||
| if (Flags == SPIRV::FPFastMathMode::None && | ||
| !FPFastMathDefaultInfo.ContractionOff && | ||
| !FPFastMathDefaultInfo.SignedZeroInfNanPreserve && | ||
| !FPFastMathDefaultInfo.FPFastMathDefault) | ||
| continue; | ||
| Inst.addOperand(MCOperand::createImm(Flags)); | ||
Keenuts marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| outputMCInst(Inst); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| void SPIRVAsmPrinter::outputModuleSections() { | ||
| const Module *M = MMI->getModule(); | ||
| // Get the global subtarget to output module-level info. | ||
|
|
@@ -611,15 +739,17 @@ void SPIRVAsmPrinter::outputModuleSections() { | |
| MAI = &SPIRVModuleAnalysis::MAI; | ||
| assert(ST && TII && MAI && M && "Module analysis is required"); | ||
| // Output instructions according to the Logical Layout of a Module: | ||
| // 1,2. All OpCapability instructions, then optional OpExtension instructions. | ||
| // 1,2. All OpCapability instructions, then optional OpExtension | ||
| // instructions. | ||
| outputGlobalRequirements(); | ||
| // 3. Optional OpExtInstImport instructions. | ||
| outputOpExtInstImports(*M); | ||
| // 4. The single required OpMemoryModel instruction. | ||
| outputOpMemoryModel(); | ||
| // 5. All entry point declarations, using OpEntryPoint. | ||
| outputEntryPoints(); | ||
| // 6. Execution-mode declarations, using OpExecutionMode or OpExecutionModeId. | ||
| // 6. Execution-mode declarations, using OpExecutionMode or | ||
| // OpExecutionModeId. | ||
| outputExecutionMode(*M); | ||
| // 7a. Debug: all OpString, OpSourceExtension, OpSource, and | ||
| // OpSourceContinued, without forward references. | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.