Skip to content

Commit ebcf025

Browse files
[SPIR-V] Implement SPV_KHR_float_controls2 (#146941)
Implementation of [SPV_KHR_float_controls2](https://github.khronos.org/SPIRV-Registry/extensions/KHR/SPV_KHR_float_controls2.html) extension, and corresponding tests. Some of the tests make use of `!spirv.ExecutionMode` LLVM named metadata. This is because some SPIR-V instructions don't have a direct equivalent in LLVM IR, so the SPIR-V Target uses different LLVM named metadata to convey the necessary information. Below, you will find an example from one of the newly added tests: ``` !spirv.ExecutionMode = !{!19, !20, !21, !22, !23, !24, !25, !26, !27} !19 = !{ptr @k_float_controls_float, i32 6028, float poison, i32 131079} !20 = !{ptr @k_float_controls_all, i32 6028, float poison, i32 131079} !21 = !{ptr @k_float_controls_float, i32 31} !22 = !{ptr @k_float_controls_all, i32 31} !23 = !{ptr @k_float_controls_float, i32 4461, i32 32} !24 = !{ptr @k_float_controls_all, i32 4461, i32 16} !25 = !{ptr @k_float_controls_all, i32 4461, i32 32} !26 = !{ptr @k_float_controls_all, i32 4461, i32 64} !27 = !{ptr @k_float_controls_all, i32 4461, i32 128} ``` `!spirv.ExecutionMode` contains a list of metadata nodes, and each of them specifies the required operands for expressing a particular `OpExecutionMode` instruction in SPIR-V. For example, `!19 = !{ptr @k_float_controls_float, i32 6028, float poison, i32 131079}` will be lowered to `OpExecutionMode [[k_float_controls_float_ID]] FPFastMathDefault [[float_type_ID]] 131079`. --------- Co-authored-by: Dmitry Sidorov <[email protected]>
1 parent e6425a7 commit ebcf025

22 files changed

+1324
-56
lines changed

llvm/docs/SPIRVUsage.rst

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ Below is a list of supported SPIR-V extensions, sorted alphabetically by their e
232232
* - ``SPV_INTEL_int4``
233233
- Adds support for 4-bit integer type, and allow this type to be used in cooperative matrices.
234234
* - ``SPV_KHR_float_controls2``
235-
- Adds ability to specify the floating-point environment in shaders. It can be used on whole modules and individual instructions.
235+
- Adds execution modes and decorations to control floating-point computations in both kernels and shaders. It can be used on whole modules and individual instructions.
236236

237237
SPIR-V representation in LLVM IR
238238
================================
@@ -589,3 +589,31 @@ Group and Subgroup Operations
589589
For workgroup and subgroup operations, LLVM uses function calls to represent SPIR-V's
590590
group-based instructions. These builtins facilitate group synchronization, data sharing,
591591
and collective operations essential for efficient parallel computation.
592+
593+
SPIR-V Instructions Mapped to LLVM Metadata
594+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
595+
Some SPIR-V instructions don't have a direct equivalent in the LLVM IR language. To
596+
address this, the SPIR-V Target uses different specific LLVM named metadata to convey
597+
the necessary information. The SPIR-V specification allows multiple module-scope
598+
instructions, where as LLVM named metadata must be unique. Therefore, the encoding of
599+
such instructions has the following format:
600+
601+
.. code-block:: llvm
602+
603+
!spirv.<OpCodeName> = !{!<InstructionMetadata1>, !<InstructionMetadata2>, ..}
604+
!<InstructionMetadata1> = !{<Operand1>, <Operand2>, ..}
605+
!<InstructionMetadata2> = !{<Operand1>, <Operand2>, ..}
606+
607+
Below, you will find the mappings between SPIR-V instruction and their corresponding
608+
LLVM IR representations.
609+
610+
+--------------------+---------------------------------------------------------+
611+
| SPIR-V instruction | LLVM IR |
612+
+====================+=========================================================+
613+
| OpExecutionMode | .. code-block:: llvm |
614+
| | |
615+
| | !spirv.ExecutionMode = !{!0} |
616+
| | !0 = !{void @worker, i32 30, i32 262149} |
617+
| | ; Set execution mode with id 30 (VecTypeHint) and |
618+
| | ; literal `262149` operand. |
619+
+--------------------+---------------------------------------------------------+

llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp

Lines changed: 194 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ class SPIRVAsmPrinter : public AsmPrinter {
8181
void outputExecutionMode(const Module &M);
8282
void outputAnnotations(const Module &M);
8383
void outputModuleSections();
84+
void outputFPFastMathDefaultInfo();
8485
bool isHidden() {
8586
return MF->getFunction()
8687
.getFnAttribute(SPIRV_BACKEND_SERVICE_FUN_NAME)
@@ -498,11 +499,27 @@ void SPIRVAsmPrinter::outputExecutionMode(const Module &M) {
498499
NamedMDNode *Node = M.getNamedMetadata("spirv.ExecutionMode");
499500
if (Node) {
500501
for (unsigned i = 0; i < Node->getNumOperands(); i++) {
502+
// If SPV_KHR_float_controls2 is enabled and we find any of
503+
// FPFastMathDefault, ContractionOff or SignedZeroInfNanPreserve execution
504+
// modes, skip it, it'll be done somewhere else.
505+
if (ST->canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2)) {
506+
const auto EM =
507+
cast<ConstantInt>(
508+
cast<ConstantAsMetadata>((Node->getOperand(i))->getOperand(1))
509+
->getValue())
510+
->getZExtValue();
511+
if (EM == SPIRV::ExecutionMode::FPFastMathDefault ||
512+
EM == SPIRV::ExecutionMode::ContractionOff ||
513+
EM == SPIRV::ExecutionMode::SignedZeroInfNanPreserve)
514+
continue;
515+
}
516+
501517
MCInst Inst;
502518
Inst.setOpcode(SPIRV::OpExecutionMode);
503519
addOpsFromMDNode(cast<MDNode>(Node->getOperand(i)), Inst, MAI);
504520
outputMCInst(Inst);
505521
}
522+
outputFPFastMathDefaultInfo();
506523
}
507524
for (auto FI = M.begin(), E = M.end(); FI != E; ++FI) {
508525
const Function &F = *FI;
@@ -552,12 +569,84 @@ void SPIRVAsmPrinter::outputExecutionMode(const Module &M) {
552569
}
553570
if (ST->isKernel() && !M.getNamedMetadata("spirv.ExecutionMode") &&
554571
!M.getNamedMetadata("opencl.enable.FP_CONTRACT")) {
555-
MCInst Inst;
556-
Inst.setOpcode(SPIRV::OpExecutionMode);
557-
Inst.addOperand(MCOperand::createReg(FReg));
558-
unsigned EM = static_cast<unsigned>(SPIRV::ExecutionMode::ContractionOff);
559-
Inst.addOperand(MCOperand::createImm(EM));
560-
outputMCInst(Inst);
572+
if (ST->canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2)) {
573+
// When SPV_KHR_float_controls2 is enabled, ContractionOff is
574+
// deprecated. We need to use FPFastMathDefault with the appropriate
575+
// flags instead. Since FPFastMathDefault takes a target type, we need
576+
// to emit it for each floating-point type that exists in the module
577+
// to match the effect of ContractionOff. As of now, there are 3 FP
578+
// types: fp16, fp32 and fp64.
579+
580+
// We only end up here because there is no "spirv.ExecutionMode"
581+
// metadata, so that means no FPFastMathDefault. Therefore, we only
582+
// need to make sure AllowContract is set to 0, as the rest of flags.
583+
// We still need to emit the OpExecutionMode instruction, otherwise
584+
// it's up to the client API to define the flags. Therefore, we need
585+
// to find the constant with 0 value.
586+
587+
// Collect the SPIRVTypes for fp16, fp32, and fp64 and the constant of
588+
// type int32 with 0 value to represent the FP Fast Math Mode.
589+
std::vector<const MachineInstr *> SPIRVFloatTypes;
590+
const MachineInstr *ConstZero = nullptr;
591+
for (const MachineInstr *MI :
592+
MAI->getMSInstrs(SPIRV::MB_TypeConstVars)) {
593+
// Skip if the instruction is not OpTypeFloat or OpConstant.
594+
unsigned OpCode = MI->getOpcode();
595+
if (OpCode != SPIRV::OpTypeFloat && OpCode != SPIRV::OpConstantNull)
596+
continue;
597+
598+
// Collect the SPIRV type if it's a float.
599+
if (OpCode == SPIRV::OpTypeFloat) {
600+
// Skip if the target type is not fp16, fp32, fp64.
601+
const unsigned OpTypeFloatSize = MI->getOperand(1).getImm();
602+
if (OpTypeFloatSize != 16 && OpTypeFloatSize != 32 &&
603+
OpTypeFloatSize != 64) {
604+
continue;
605+
}
606+
SPIRVFloatTypes.push_back(MI);
607+
} else {
608+
// Check if the constant is int32, if not skip it.
609+
const MachineRegisterInfo &MRI = MI->getMF()->getRegInfo();
610+
MachineInstr *TypeMI = MRI.getVRegDef(MI->getOperand(1).getReg());
611+
if (!TypeMI || TypeMI->getOperand(1).getImm() != 32)
612+
continue;
613+
614+
ConstZero = MI;
615+
}
616+
}
617+
618+
// When SPV_KHR_float_controls2 is enabled, ContractionOff is
619+
// deprecated. We need to use FPFastMathDefault with the appropriate
620+
// flags instead. Since FPFastMathDefault takes a target type, we need
621+
// to emit it for each floating-point type that exists in the module
622+
// to match the effect of ContractionOff. As of now, there are 3 FP
623+
// types: fp16, fp32 and fp64.
624+
for (const MachineInstr *MI : SPIRVFloatTypes) {
625+
MCInst Inst;
626+
Inst.setOpcode(SPIRV::OpExecutionModeId);
627+
Inst.addOperand(MCOperand::createReg(FReg));
628+
unsigned EM =
629+
static_cast<unsigned>(SPIRV::ExecutionMode::FPFastMathDefault);
630+
Inst.addOperand(MCOperand::createImm(EM));
631+
const MachineFunction *MF = MI->getMF();
632+
MCRegister TypeReg =
633+
MAI->getRegisterAlias(MF, MI->getOperand(0).getReg());
634+
Inst.addOperand(MCOperand::createReg(TypeReg));
635+
assert(ConstZero && "There should be a constant zero.");
636+
MCRegister ConstReg = MAI->getRegisterAlias(
637+
ConstZero->getMF(), ConstZero->getOperand(0).getReg());
638+
Inst.addOperand(MCOperand::createReg(ConstReg));
639+
outputMCInst(Inst);
640+
}
641+
} else {
642+
MCInst Inst;
643+
Inst.setOpcode(SPIRV::OpExecutionMode);
644+
Inst.addOperand(MCOperand::createReg(FReg));
645+
unsigned EM =
646+
static_cast<unsigned>(SPIRV::ExecutionMode::ContractionOff);
647+
Inst.addOperand(MCOperand::createImm(EM));
648+
outputMCInst(Inst);
649+
}
561650
}
562651
}
563652
}
@@ -606,6 +695,101 @@ void SPIRVAsmPrinter::outputAnnotations(const Module &M) {
606695
}
607696
}
608697

698+
void SPIRVAsmPrinter::outputFPFastMathDefaultInfo() {
699+
// Collect the SPIRVTypes that are OpTypeFloat and the constants of type
700+
// int32, that might be used as FP Fast Math Mode.
701+
std::vector<const MachineInstr *> SPIRVFloatTypes;
702+
// Hashtable to associate immediate values with the constant holding them.
703+
std::unordered_map<int, const MachineInstr *> ConstMap;
704+
for (const MachineInstr *MI : MAI->getMSInstrs(SPIRV::MB_TypeConstVars)) {
705+
// Skip if the instruction is not OpTypeFloat or OpConstant.
706+
unsigned OpCode = MI->getOpcode();
707+
if (OpCode != SPIRV::OpTypeFloat && OpCode != SPIRV::OpConstantI &&
708+
OpCode != SPIRV::OpConstantNull)
709+
continue;
710+
711+
// Collect the SPIRV type if it's a float.
712+
if (OpCode == SPIRV::OpTypeFloat) {
713+
SPIRVFloatTypes.push_back(MI);
714+
} else {
715+
// Check if the constant is int32, if not skip it.
716+
const MachineRegisterInfo &MRI = MI->getMF()->getRegInfo();
717+
MachineInstr *TypeMI = MRI.getVRegDef(MI->getOperand(1).getReg());
718+
if (!TypeMI || TypeMI->getOpcode() != SPIRV::OpTypeInt ||
719+
TypeMI->getOperand(1).getImm() != 32)
720+
continue;
721+
722+
if (OpCode == SPIRV::OpConstantI)
723+
ConstMap[MI->getOperand(2).getImm()] = MI;
724+
else
725+
ConstMap[0] = MI;
726+
}
727+
}
728+
729+
for (const auto &[Func, FPFastMathDefaultInfoVec] :
730+
MAI->FPFastMathDefaultInfoMap) {
731+
if (FPFastMathDefaultInfoVec.empty())
732+
continue;
733+
734+
for (const MachineInstr *MI : SPIRVFloatTypes) {
735+
unsigned OpTypeFloatSize = MI->getOperand(1).getImm();
736+
unsigned Index = SPIRV::FPFastMathDefaultInfoVector::
737+
computeFPFastMathDefaultInfoVecIndex(OpTypeFloatSize);
738+
assert(Index < FPFastMathDefaultInfoVec.size() &&
739+
"Index out of bounds for FPFastMathDefaultInfoVec");
740+
const auto &FPFastMathDefaultInfo = FPFastMathDefaultInfoVec[Index];
741+
assert(FPFastMathDefaultInfo.Ty &&
742+
"Expected target type for FPFastMathDefaultInfo");
743+
assert(FPFastMathDefaultInfo.Ty->getScalarSizeInBits() ==
744+
OpTypeFloatSize &&
745+
"Mismatched float type size");
746+
MCInst Inst;
747+
Inst.setOpcode(SPIRV::OpExecutionModeId);
748+
MCRegister FuncReg = MAI->getFuncReg(Func);
749+
assert(FuncReg.isValid());
750+
Inst.addOperand(MCOperand::createReg(FuncReg));
751+
Inst.addOperand(
752+
MCOperand::createImm(SPIRV::ExecutionMode::FPFastMathDefault));
753+
MCRegister TypeReg =
754+
MAI->getRegisterAlias(MI->getMF(), MI->getOperand(0).getReg());
755+
Inst.addOperand(MCOperand::createReg(TypeReg));
756+
unsigned Flags = FPFastMathDefaultInfo.FastMathFlags;
757+
if (FPFastMathDefaultInfo.ContractionOff &&
758+
(Flags & SPIRV::FPFastMathMode::AllowContract))
759+
report_fatal_error(
760+
"Conflicting FPFastMathFlags: ContractionOff and AllowContract");
761+
762+
if (FPFastMathDefaultInfo.SignedZeroInfNanPreserve &&
763+
!(Flags &
764+
(SPIRV::FPFastMathMode::NotNaN | SPIRV::FPFastMathMode::NotInf |
765+
SPIRV::FPFastMathMode::NSZ))) {
766+
if (FPFastMathDefaultInfo.FPFastMathDefault)
767+
report_fatal_error("Conflicting FPFastMathFlags: "
768+
"SignedZeroInfNanPreserve but at least one of "
769+
"NotNaN/NotInf/NSZ is enabled.");
770+
}
771+
772+
// Don't emit if none of the execution modes was used.
773+
if (Flags == SPIRV::FPFastMathMode::None &&
774+
!FPFastMathDefaultInfo.ContractionOff &&
775+
!FPFastMathDefaultInfo.SignedZeroInfNanPreserve &&
776+
!FPFastMathDefaultInfo.FPFastMathDefault)
777+
continue;
778+
779+
// Retrieve the constant instruction for the immediate value.
780+
auto It = ConstMap.find(Flags);
781+
if (It == ConstMap.end())
782+
report_fatal_error("Expected constant instruction for FP Fast Math "
783+
"Mode operand of FPFastMathDefault execution mode.");
784+
const MachineInstr *ConstMI = It->second;
785+
MCRegister ConstReg = MAI->getRegisterAlias(
786+
ConstMI->getMF(), ConstMI->getOperand(0).getReg());
787+
Inst.addOperand(MCOperand::createReg(ConstReg));
788+
outputMCInst(Inst);
789+
}
790+
}
791+
}
792+
609793
void SPIRVAsmPrinter::outputModuleSections() {
610794
const Module *M = MMI->getModule();
611795
// Get the global subtarget to output module-level info.
@@ -614,15 +798,17 @@ void SPIRVAsmPrinter::outputModuleSections() {
614798
MAI = &SPIRVModuleAnalysis::MAI;
615799
assert(ST && TII && MAI && M && "Module analysis is required");
616800
// Output instructions according to the Logical Layout of a Module:
617-
// 1,2. All OpCapability instructions, then optional OpExtension instructions.
801+
// 1,2. All OpCapability instructions, then optional OpExtension
802+
// instructions.
618803
outputGlobalRequirements();
619804
// 3. Optional OpExtInstImport instructions.
620805
outputOpExtInstImports(*M);
621806
// 4. The single required OpMemoryModel instruction.
622807
outputOpMemoryModel();
623808
// 5. All entry point declarations, using OpEntryPoint.
624809
outputEntryPoints();
625-
// 6. Execution-mode declarations, using OpExecutionMode or OpExecutionModeId.
810+
// 6. Execution-mode declarations, using OpExecutionMode or
811+
// OpExecutionModeId.
626812
outputExecutionMode(*M);
627813
// 7a. Debug: all OpString, OpSourceExtension, OpSource, and
628814
// OpSourceContinued, without forward references.

llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1162,11 +1162,24 @@ static unsigned getNumSizeComponents(SPIRVType *imgType) {
11621162

11631163
static bool generateExtInst(const SPIRV::IncomingCall *Call,
11641164
MachineIRBuilder &MIRBuilder,
1165-
SPIRVGlobalRegistry *GR) {
1165+
SPIRVGlobalRegistry *GR, const CallBase &CB) {
11661166
// Lookup the extended instruction number in the TableGen records.
11671167
const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
11681168
uint32_t Number =
11691169
SPIRV::lookupExtendedBuiltin(Builtin->Name, Builtin->Set)->Number;
1170+
// fmin_common and fmax_common are now deprecated, and we should use fmin and
1171+
// fmax with NotInf and NotNaN flags instead. Keep original number to add
1172+
// later the NoNans and NoInfs flags.
1173+
uint32_t OrigNumber = Number;
1174+
const SPIRVSubtarget &ST =
1175+
cast<SPIRVSubtarget>(MIRBuilder.getMF().getSubtarget());
1176+
if (ST.canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2) &&
1177+
(Number == SPIRV::OpenCLExtInst::fmin_common ||
1178+
Number == SPIRV::OpenCLExtInst::fmax_common)) {
1179+
Number = (Number == SPIRV::OpenCLExtInst::fmin_common)
1180+
? SPIRV::OpenCLExtInst::fmin
1181+
: SPIRV::OpenCLExtInst::fmax;
1182+
}
11701183

11711184
// Build extended instruction.
11721185
auto MIB =
@@ -1178,6 +1191,13 @@ static bool generateExtInst(const SPIRV::IncomingCall *Call,
11781191

11791192
for (auto Argument : Call->Arguments)
11801193
MIB.addUse(Argument);
1194+
MIB.getInstr()->copyIRFlags(CB);
1195+
if (OrigNumber == SPIRV::OpenCLExtInst::fmin_common ||
1196+
OrigNumber == SPIRV::OpenCLExtInst::fmax_common) {
1197+
// Add NoNans and NoInfs flags to fmin/fmax instruction.
1198+
MIB.getInstr()->setFlag(MachineInstr::MIFlag::FmNoNans);
1199+
MIB.getInstr()->setFlag(MachineInstr::MIFlag::FmNoInfs);
1200+
}
11811201
return true;
11821202
}
11831203

@@ -2908,7 +2928,7 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
29082928
MachineIRBuilder &MIRBuilder,
29092929
const Register OrigRet, const Type *OrigRetTy,
29102930
const SmallVectorImpl<Register> &Args,
2911-
SPIRVGlobalRegistry *GR) {
2931+
SPIRVGlobalRegistry *GR, const CallBase &CB) {
29122932
LLVM_DEBUG(dbgs() << "Lowering builtin call: " << DemangledCall << "\n");
29132933

29142934
// Lookup the builtin in the TableGen records.
@@ -2931,7 +2951,7 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
29312951
// Match the builtin with implementation based on the grouping.
29322952
switch (Call->Builtin->Group) {
29332953
case SPIRV::Extended:
2934-
return generateExtInst(Call.get(), MIRBuilder, GR);
2954+
return generateExtInst(Call.get(), MIRBuilder, GR, CB);
29352955
case SPIRV::Relational:
29362956
return generateRelationalInst(Call.get(), MIRBuilder, GR);
29372957
case SPIRV::Group:

llvm/lib/Target/SPIRV/SPIRVBuiltins.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
3939
MachineIRBuilder &MIRBuilder,
4040
const Register OrigRet, const Type *OrigRetTy,
4141
const SmallVectorImpl<Register> &Args,
42-
SPIRVGlobalRegistry *GR);
42+
SPIRVGlobalRegistry *GR, const CallBase &CB);
4343

4444
/// Helper function for finding a builtin function attributes
4545
/// by a demangled function name. Defined in SPIRVBuiltins.cpp.

llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -641,9 +641,9 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
641641
GR->getPointerSize()));
642642
}
643643
}
644-
if (auto Res =
645-
SPIRV::lowerBuiltin(DemangledName, ST->getPreferredInstructionSet(),
646-
MIRBuilder, ResVReg, OrigRetTy, ArgVRegs, GR))
644+
if (auto Res = SPIRV::lowerBuiltin(
645+
DemangledName, ST->getPreferredInstructionSet(), MIRBuilder,
646+
ResVReg, OrigRetTy, ArgVRegs, GR, *Info.CB))
647647
return *Res;
648648
}
649649

0 commit comments

Comments
 (0)