-
Notifications
You must be signed in to change notification settings - Fork 15k
[SPIR-V] Implement SPV_KHR_float_controls2 #146941
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 19 commits
cbe7a07
4c325ac
a664820
0648840
34c0628
1d87705
de07e83
3edf82a
c0ef106
97720d5
3792464
eccd27a
868c6a4
21c17d1
cfdc210
9b8b43d
a353295
f1403d6
37a3abe
d76db3e
00548e5
f2aefbf
6a77fcd
848611a
527311d
bde0292
2295de4
6420f43
0bf9ca6
5f26101
368793f
cb1065e
d337ef2
c19d296
c9d226c
b5ce97a
0cd1923
c973c1a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -80,6 +80,7 @@ class SPIRVAsmPrinter : public AsmPrinter { | |||||
void outputExecutionMode(const Module &M); | ||||||
void outputAnnotations(const Module &M); | ||||||
void outputModuleSections(); | ||||||
void outputFPFastMathDefaultInfo(); | ||||||
bool isHidden() { | ||||||
return MF->getFunction() | ||||||
.getFnAttribute(SPIRV_BACKEND_SERVICE_FUN_NAME) | ||||||
|
@@ -497,11 +498,27 @@ void SPIRVAsmPrinter::outputExecutionMode(const Module &M) { | |||||
NamedMDNode *Node = M.getNamedMetadata("spirv.ExecutionMode"); | ||||||
if (Node) { | ||||||
for (unsigned i = 0; i < Node->getNumOperands(); i++) { | ||||||
// If SPV_KHR_float_controls2 is enabled and we find any of | ||||||
// FPFastMathDefault, ContractionOff or SignedZeroInfNanPreserve execution | ||||||
// modes, skip it, it'll be done somewhere else. | ||||||
if (ST->canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2)) { | ||||||
const auto EM = | ||||||
cast<ConstantInt>( | ||||||
cast<ConstantAsMetadata>((Node->getOperand(i))->getOperand(1)) | ||||||
->getValue()) | ||||||
->getZExtValue(); | ||||||
if (EM == SPIRV::ExecutionMode::FPFastMathDefault || | ||||||
EM == SPIRV::ExecutionMode::ContractionOff || | ||||||
EM == SPIRV::ExecutionMode::SignedZeroInfNanPreserve) | ||||||
continue; | ||||||
} | ||||||
|
||||||
MCInst Inst; | ||||||
Inst.setOpcode(SPIRV::OpExecutionMode); | ||||||
addOpsFromMDNode(cast<MDNode>(Node->getOperand(i)), Inst, MAI); | ||||||
outputMCInst(Inst); | ||||||
} | ||||||
outputFPFastMathDefaultInfo(); | ||||||
} | ||||||
for (auto FI = M.begin(), E = M.end(); FI != E; ++FI) { | ||||||
const Function &F = *FI; | ||||||
|
@@ -551,12 +568,53 @@ void SPIRVAsmPrinter::outputExecutionMode(const Module &M) { | |||||
} | ||||||
if (ST->isKernel() && !M.getNamedMetadata("spirv.ExecutionMode") && | ||||||
!M.getNamedMetadata("opencl.enable.FP_CONTRACT")) { | ||||||
MCInst Inst; | ||||||
Inst.setOpcode(SPIRV::OpExecutionMode); | ||||||
Inst.addOperand(MCOperand::createReg(FReg)); | ||||||
unsigned EM = static_cast<unsigned>(SPIRV::ExecutionMode::ContractionOff); | ||||||
Inst.addOperand(MCOperand::createImm(EM)); | ||||||
outputMCInst(Inst); | ||||||
if (ST->canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2)) { | ||||||
// When SPV_KHR_float_controls2 is enabled, ContractionOff is | ||||||
// deprecated. We need to use FPFastMathDefault with the appropriate | ||||||
// flags instead. Since FPFastMathDefault takes a target type, we need | ||||||
// to emit it for each floating-point type that exists in the module | ||||||
// to match the effect of ContractionOff. As of now, there are 4 FP | ||||||
// types: fp16, fp32 and fp64. | ||||||
for (const MachineInstr *MI : | ||||||
MAI->getMSInstrs(SPIRV::MB_TypeConstVars)) { | ||||||
// Skip if the instruction is not OpTypeFloat. | ||||||
if (MI->getOpcode() != SPIRV::OpTypeFloat) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should also handle vector of floats. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Somewhere should also be a OpTypeMatrix check, but I'm unsure, how SPIR-V backend works with shader matrices. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Actually, this is the wording in the spec:
So I think we should not. However, I will double-check if There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You are right, thanks! Yet, would be nice to see a test for vector variant of annotated floating point arithmetic operation to check correctness of per-instruction floating controls. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, absolutely. Actually, adding the test you suggested I discovered that There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done in f2aefbf. |
||||||
continue; | ||||||
|
||||||
// Skip if the target type is not fp16, fp32, fp64. | ||||||
const unsigned OpTypeFloatSize = MI->getOperand(1).getImm(); | ||||||
if (OpTypeFloatSize != 16 && OpTypeFloatSize != 32 && | ||||||
OpTypeFloatSize != 64) { | ||||||
continue; | ||||||
} | ||||||
|
||||||
MCInst Inst; | ||||||
Inst.setOpcode(SPIRV::OpExecutionMode); | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Float controls requires
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. |
||||||
Inst.addOperand(MCOperand::createReg(FReg)); | ||||||
unsigned EM = | ||||||
static_cast<unsigned>(SPIRV::ExecutionMode::FPFastMathDefault); | ||||||
Inst.addOperand(MCOperand::createImm(EM)); | ||||||
const MachineFunction *MF = MI->getMF(); | ||||||
MCRegister TypeReg = | ||||||
MAI->getRegisterAlias(MF, MI->getOperand(0).getReg()); | ||||||
Inst.addOperand(MCOperand::createReg(TypeReg)); | ||||||
// We only end up here because there is no "spirv.ExecutionMode" | ||||||
// metadata, so that means no FPFastMathDefault. Therefore, we only | ||||||
// need to make sure AllowContract is set to 0, as the rest of flags. | ||||||
// We still need to emit the OpExecutionMode instruction, otherwise | ||||||
// it's up to the client API to define the flags. | ||||||
Inst.addOperand(MCOperand::createImm(SPIRV::FPFastMathMode::None)); | ||||||
outputMCInst(Inst); | ||||||
} | ||||||
} else { | ||||||
MCInst Inst; | ||||||
Inst.setOpcode(SPIRV::OpExecutionMode); | ||||||
Inst.addOperand(MCOperand::createReg(FReg)); | ||||||
unsigned EM = | ||||||
static_cast<unsigned>(SPIRV::ExecutionMode::ContractionOff); | ||||||
Inst.addOperand(MCOperand::createImm(EM)); | ||||||
outputMCInst(Inst); | ||||||
} | ||||||
} | ||||||
} | ||||||
} | ||||||
|
@@ -603,6 +661,76 @@ void SPIRVAsmPrinter::outputAnnotations(const Module &M) { | |||||
} | ||||||
} | ||||||
|
||||||
void SPIRVAsmPrinter::outputFPFastMathDefaultInfo() { | ||||||
// Collect the SPIRVTypes that are OpTypeFloat. | ||||||
std::vector<const MachineInstr *> SPIRVFloatTypes; | ||||||
for (const MachineInstr *MI : MAI->getMSInstrs(SPIRV::MB_TypeConstVars)) { | ||||||
// Skip if the instruction is not OpTypeFloat. | ||||||
if (MI->getOpcode() != SPIRV::OpTypeFloat) | ||||||
continue; | ||||||
|
||||||
// Collect the SPIRV type. | ||||||
SPIRVFloatTypes.push_back(MI); | ||||||
} | ||||||
|
||||||
for (const auto &[Func, FPFastMathDefaultInfoVec] : | ||||||
MAI->FPFastMathDefaultInfoMap) { | ||||||
if (FPFastMathDefaultInfoVec.empty()) | ||||||
continue; | ||||||
|
||||||
for (const MachineInstr *MI : SPIRVFloatTypes) { | ||||||
unsigned OpTypeFloatSize = MI->getOperand(1).getImm(); | ||||||
unsigned Index = computeFPFastMathDefaultInfoVecIndex(OpTypeFloatSize); | ||||||
assert(Index < FPFastMathDefaultInfoVec.size() && | ||||||
"Index out of bounds for FPFastMathDefaultInfoVec"); | ||||||
const auto &FPFastMathDefaultInfo = FPFastMathDefaultInfoVec[Index]; | ||||||
assert(FPFastMathDefaultInfo.Ty && | ||||||
"Expected target type for FPFastMathDefaultInfo"); | ||||||
assert(FPFastMathDefaultInfo.Ty->getScalarSizeInBits() == | ||||||
OpTypeFloatSize && | ||||||
"Mismatched float type size"); | ||||||
MCInst Inst; | ||||||
Inst.setOpcode(SPIRV::OpExecutionMode); | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. |
||||||
MCRegister FuncReg = MAI->getFuncReg(Func); | ||||||
assert(FuncReg.isValid()); | ||||||
Inst.addOperand(MCOperand::createReg(FuncReg)); | ||||||
Inst.addOperand( | ||||||
MCOperand::createImm(SPIRV::ExecutionMode::FPFastMathDefault)); | ||||||
MCRegister TypeReg = | ||||||
MAI->getRegisterAlias(MI->getMF(), MI->getOperand(0).getReg()); | ||||||
Inst.addOperand(MCOperand::createReg(TypeReg)); | ||||||
unsigned Flags = FPFastMathDefaultInfo.FastMathFlags; | ||||||
if (FPFastMathDefaultInfo.ContractionOff && | ||||||
(Flags & SPIRV::FPFastMathMode::AllowContract) && | ||||||
FPFastMathDefaultInfo.FPFastMathDefault) | ||||||
Keenuts marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
report_fatal_error( | ||||||
"Conflicting FPFastMathFlags: ContractionOff and AllowContract"); | ||||||
|
||||||
if (FPFastMathDefaultInfo.SignedZeroInfNanPreserve && | ||||||
!(Flags & | ||||||
(SPIRV::FPFastMathMode::NotNaN | SPIRV::FPFastMathMode::NotInf | | ||||||
SPIRV::FPFastMathMode::NSZ))) { | ||||||
if (FPFastMathDefaultInfo.FPFastMathDefault) | ||||||
report_fatal_error("Conflicting FPFastMathFlags: " | ||||||
Keenuts marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
"SignedZeroInfNanPreserve but at least one of " | ||||||
"NotNaN/NotInf/NSZ is disabled."); | ||||||
|
||||||
Flags |= SPIRV::FPFastMathMode::NotNaN | SPIRV::FPFastMathMode::NotInf | | ||||||
SPIRV::FPFastMathMode::NSZ; | ||||||
Keenuts marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
} | ||||||
|
||||||
// Don't emit if none of the execution modes was used. | ||||||
if (Flags == SPIRV::FPFastMathMode::None && | ||||||
!FPFastMathDefaultInfo.ContractionOff && | ||||||
!FPFastMathDefaultInfo.SignedZeroInfNanPreserve && | ||||||
!FPFastMathDefaultInfo.FPFastMathDefault) | ||||||
continue; | ||||||
Inst.addOperand(MCOperand::createImm(Flags)); | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do you have an issue to track the OpConstant emission intead of the immediate value? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for bringing this up! I haven't noticed when reading the spec this had to be a constant. I will address it in this PR, since going for the immediate is directly illegal according to the spec. |
||||||
outputMCInst(Inst); | ||||||
} | ||||||
} | ||||||
} | ||||||
|
||||||
void SPIRVAsmPrinter::outputModuleSections() { | ||||||
const Module *M = MMI->getModule(); | ||||||
// Get the global subtarget to output module-level info. | ||||||
|
@@ -611,15 +739,17 @@ void SPIRVAsmPrinter::outputModuleSections() { | |||||
MAI = &SPIRVModuleAnalysis::MAI; | ||||||
assert(ST && TII && MAI && M && "Module analysis is required"); | ||||||
// Output instructions according to the Logical Layout of a Module: | ||||||
// 1,2. All OpCapability instructions, then optional OpExtension instructions. | ||||||
// 1,2. All OpCapability instructions, then optional OpExtension | ||||||
// instructions. | ||||||
outputGlobalRequirements(); | ||||||
// 3. Optional OpExtInstImport instructions. | ||||||
outputOpExtInstImports(*M); | ||||||
// 4. The single required OpMemoryModel instruction. | ||||||
outputOpMemoryModel(); | ||||||
// 5. All entry point declarations, using OpEntryPoint. | ||||||
outputEntryPoints(); | ||||||
// 6. Execution-mode declarations, using OpExecutionMode or OpExecutionModeId. | ||||||
// 6. Execution-mode declarations, using OpExecutionMode or | ||||||
// OpExecutionModeId. | ||||||
outputExecutionMode(*M); | ||||||
// 7a. Debug: all OpString, OpSourceExtension, OpSource, and | ||||||
// OpSourceContinued, without forward references. | ||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Formatting is screwed, I suspect an empty line is missing before the table.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done in f2aefbf.