Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
cbe7a07
Initial support for SPV_KHR_float_controls2.
maarquitos14 Jun 25, 2025
4c325ac
Fix test failures.
maarquitos14 Jun 30, 2025
a664820
Deprecations only apply when the extension is enabled.
maarquitos14 Jul 3, 2025
0648840
Add extension specific tests.
maarquitos14 Jul 3, 2025
34c0628
Further work to apply deprecations only when extension is enabled.
maarquitos14 Jul 3, 2025
1d87705
Remove leftover comment.
maarquitos14 Jul 3, 2025
de07e83
Yet more work to apply deprecations only when extension is enabled.
maarquitos14 Jul 3, 2025
3edf82a
Renamings and bugfixes.
maarquitos14 Jul 3, 2025
c0ef106
Renamings, bugfixes and undo undesired changes.
maarquitos14 Jul 3, 2025
97720d5
Undo undesired changes.
maarquitos14 Jul 3, 2025
3792464
Merge remote-tracking branch 'origin/main' into maronas/float-controls2
maarquitos14 Jul 4, 2025
eccd27a
Fix after-merge bugs.
maarquitos14 Jul 4, 2025
868c6a4
Fix clang-format issues.
maarquitos14 Jul 4, 2025
21c17d1
Replace undef with poison for placeholders.
maarquitos14 Jul 4, 2025
cfdc210
Add SPIR-V Instructions Mapped to LLVM Metadata section.
maarquitos14 Jul 7, 2025
9b8b43d
Address code review.
maarquitos14 Jul 7, 2025
a353295
Address code review feedback.
maarquitos14 Jul 16, 2025
f1403d6
Fix documentation build.
maarquitos14 Jul 17, 2025
37a3abe
Fix another error in documentation build.
maarquitos14 Jul 17, 2025
d76db3e
Ignore spirv.Decorations metadata related to FP flags.
maarquitos14 Aug 1, 2025
00548e5
Merge remote-tracking branch 'origin/main' into maronas/float-controls2
maarquitos14 Aug 1, 2025
f2aefbf
Address code review feedback.
maarquitos14 Aug 8, 2025
6a77fcd
Merge remote-tracking branch 'origin/main' into maronas/float-controls2
maarquitos14 Aug 8, 2025
848611a
Fix clang-format.
maarquitos14 Aug 8, 2025
527311d
Merge remote-tracking branch 'origin/main' into maronas/float-controls2
maarquitos14 Aug 26, 2025
bde0292
Address reviewer feedback.
maarquitos14 Aug 26, 2025
2295de4
Merge branch 'main' into maronas/float-controls2
maarquitos14 Aug 26, 2025
6420f43
Use constant instruction for flags.
maarquitos14 Sep 2, 2025
0bf9ca6
Enable spirv-val and update tests accordingly.
maarquitos14 Sep 2, 2025
5f26101
Merge remote-tracking branch 'origin/main' into maronas/float-controls2
maarquitos14 Sep 2, 2025
368793f
Merge branch 'maronas/float-controls2-conflict' into maronas/float-co…
maarquitos14 Sep 2, 2025
cb1065e
Rename new function to better represent what it does.
maarquitos14 Sep 2, 2025
d337ef2
Merge branch 'main' into maronas/float-controls2
maarquitos14 Sep 2, 2025
c19d296
Fix clang-format issues.
maarquitos14 Sep 2, 2025
c9d226c
Merge branch 'maronas/float-controls2-conflict' into maronas/float-co…
maarquitos14 Sep 2, 2025
b5ce97a
Remove unused variable.
maarquitos14 Sep 2, 2025
0cd1923
Merge remote-tracking branch 'origin/main' into maronas/float-controls2
maarquitos14 Sep 2, 2025
c973c1a
Fix test failure.
maarquitos14 Sep 2, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 29 additions & 1 deletion llvm/docs/SPIRVUsage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ Below is a list of supported SPIR-V extensions, sorted alphabetically by their e
* - ``SPV_INTEL_int4``
- Adds support for 4-bit integer type, and allow this type to be used in cooperative matrices.
* - ``SPV_KHR_float_controls2``
- Adds ability to specify the floating-point environment in shaders. It can be used on whole modules and individual instructions.
- Adds execution modes and decorations to control floating-point computations in both kernels and shaders. It can be used on whole modules and individual instructions.

SPIR-V representation in LLVM IR
================================
Expand Down Expand Up @@ -589,3 +589,31 @@ Group and Subgroup Operations
For workgroup and subgroup operations, LLVM uses function calls to represent SPIR-V's
group-based instructions. These builtins facilitate group synchronization, data sharing,
and collective operations essential for efficient parallel computation.

SPIR-V Instructions Mapped to LLVM Metadata
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Some SPIR-V instructions don't have a direct equivalent in the LLVM IR language. To
address this, the SPIR-V Target uses different specific LLVM named metadata to convey
the necessary information. The SPIR-V specification allows multiple module-scope
instructions, where as LLVM named metadata must be unique. Therefore, the encoding of
such instructions has the following format:

.. code-block:: llvm

!spirv.<OpCodeName> = !{!<InstructionMetadata1>, !<InstructionMetadata2>, ..}
!<InstructionMetadata1> = !{<Operand1>, <Operand2>, ..}
!<InstructionMetadata2> = !{<Operand1>, <Operand2>, ..}

Below, you will find the mappings between SPIR-V instruction and their corresponding
LLVM IR representations.

+--------------------+---------------------------------------------------------+
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Formatting is screwed, I suspect an empty line is missing before the table.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in f2aefbf.

| SPIR-V instruction | LLVM IR |
+====================+=========================================================+
| OpExecutionMode | .. code-block:: llvm |
| | |
| | !spirv.ExecutionMode = !{!0} |
| | !0 = !{void @worker, i32 30, i32 262149} |
| | ; Set execution mode with id 30 (VecTypeHint) and |
| | ; literal `262149` operand. |
+--------------------+---------------------------------------------------------+
201 changes: 193 additions & 8 deletions llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ class SPIRVAsmPrinter : public AsmPrinter {
void outputExecutionMode(const Module &M);
void outputAnnotations(const Module &M);
void outputModuleSections();
void outputFPFastMathDefaultInfo();
bool isHidden() {
return MF->getFunction()
.getFnAttribute(SPIRV_BACKEND_SERVICE_FUN_NAME)
Expand Down Expand Up @@ -498,11 +499,27 @@ void SPIRVAsmPrinter::outputExecutionMode(const Module &M) {
NamedMDNode *Node = M.getNamedMetadata("spirv.ExecutionMode");
if (Node) {
for (unsigned i = 0; i < Node->getNumOperands(); i++) {
// If SPV_KHR_float_controls2 is enabled and we find any of
// FPFastMathDefault, ContractionOff or SignedZeroInfNanPreserve execution
// modes, skip it, it'll be done somewhere else.
if (ST->canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2)) {
const auto EM =
cast<ConstantInt>(
cast<ConstantAsMetadata>((Node->getOperand(i))->getOperand(1))
->getValue())
->getZExtValue();
if (EM == SPIRV::ExecutionMode::FPFastMathDefault ||
EM == SPIRV::ExecutionMode::ContractionOff ||
EM == SPIRV::ExecutionMode::SignedZeroInfNanPreserve)
continue;
}

MCInst Inst;
Inst.setOpcode(SPIRV::OpExecutionMode);
addOpsFromMDNode(cast<MDNode>(Node->getOperand(i)), Inst, MAI);
outputMCInst(Inst);
}
outputFPFastMathDefaultInfo();
}
for (auto FI = M.begin(), E = M.end(); FI != E; ++FI) {
const Function &F = *FI;
Expand Down Expand Up @@ -552,12 +569,84 @@ void SPIRVAsmPrinter::outputExecutionMode(const Module &M) {
}
if (ST->isKernel() && !M.getNamedMetadata("spirv.ExecutionMode") &&
!M.getNamedMetadata("opencl.enable.FP_CONTRACT")) {
MCInst Inst;
Inst.setOpcode(SPIRV::OpExecutionMode);
Inst.addOperand(MCOperand::createReg(FReg));
unsigned EM = static_cast<unsigned>(SPIRV::ExecutionMode::ContractionOff);
Inst.addOperand(MCOperand::createImm(EM));
outputMCInst(Inst);
if (ST->canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2)) {
// When SPV_KHR_float_controls2 is enabled, ContractionOff is
// deprecated. We need to use FPFastMathDefault with the appropriate
// flags instead. Since FPFastMathDefault takes a target type, we need
// to emit it for each floating-point type that exists in the module
// to match the effect of ContractionOff. As of now, there are 3 FP
// types: fp16, fp32 and fp64.

// We only end up here because there is no "spirv.ExecutionMode"
// metadata, so that means no FPFastMathDefault. Therefore, we only
// need to make sure AllowContract is set to 0, as the rest of flags.
// We still need to emit the OpExecutionMode instruction, otherwise
// it's up to the client API to define the flags. Therefore, we need
// to find the constant with 0 value.

// Collect the SPIRVTypes for fp16, fp32, and fp64 and the constant of
// type int32 with 0 value to represent the FP Fast Math Mode.
std::vector<const MachineInstr *> SPIRVFloatTypes;
const MachineInstr *ConstZero = nullptr;
for (const MachineInstr *MI :
MAI->getMSInstrs(SPIRV::MB_TypeConstVars)) {
// Skip if the instruction is not OpTypeFloat or OpConstant.
unsigned OpCode = MI->getOpcode();
if (OpCode != SPIRV::OpTypeFloat && OpCode != SPIRV::OpConstantNull)
continue;

// Collect the SPIRV type if it's a float.
if (OpCode == SPIRV::OpTypeFloat) {
// Skip if the target type is not fp16, fp32, fp64.
const unsigned OpTypeFloatSize = MI->getOperand(1).getImm();
if (OpTypeFloatSize != 16 && OpTypeFloatSize != 32 &&
OpTypeFloatSize != 64) {
continue;
}
SPIRVFloatTypes.push_back(MI);
} else {
// Check if the constant is int32, if not skip it.
const MachineRegisterInfo &MRI = MI->getMF()->getRegInfo();
MachineInstr *TypeMI = MRI.getVRegDef(MI->getOperand(1).getReg());
if (!TypeMI || TypeMI->getOperand(1).getImm() != 32)
continue;

ConstZero = MI;
}
}

// When SPV_KHR_float_controls2 is enabled, ContractionOff is
// deprecated. We need to use FPFastMathDefault with the appropriate
// flags instead. Since FPFastMathDefault takes a target type, we need
// to emit it for each floating-point type that exists in the module
// to match the effect of ContractionOff. As of now, there are 3 FP
// types: fp16, fp32 and fp64.
for (const MachineInstr *MI : SPIRVFloatTypes) {
MCInst Inst;
Inst.setOpcode(SPIRV::OpExecutionModeId);
Inst.addOperand(MCOperand::createReg(FReg));
unsigned EM =
static_cast<unsigned>(SPIRV::ExecutionMode::FPFastMathDefault);
Inst.addOperand(MCOperand::createImm(EM));
const MachineFunction *MF = MI->getMF();
MCRegister TypeReg =
MAI->getRegisterAlias(MF, MI->getOperand(0).getReg());
Inst.addOperand(MCOperand::createReg(TypeReg));
assert(ConstZero && "There should be a constant zero.");
MCRegister ConstReg = MAI->getRegisterAlias(
ConstZero->getMF(), ConstZero->getOperand(0).getReg());
Inst.addOperand(MCOperand::createReg(ConstReg));
outputMCInst(Inst);
}
} else {
MCInst Inst;
Inst.setOpcode(SPIRV::OpExecutionMode);
Inst.addOperand(MCOperand::createReg(FReg));
unsigned EM =
static_cast<unsigned>(SPIRV::ExecutionMode::ContractionOff);
Inst.addOperand(MCOperand::createImm(EM));
outputMCInst(Inst);
}
}
}
}
Expand Down Expand Up @@ -606,6 +695,100 @@ void SPIRVAsmPrinter::outputAnnotations(const Module &M) {
}
}

void SPIRVAsmPrinter::outputFPFastMathDefaultInfo() {
// Collect the SPIRVTypes that are OpTypeFloat and the constants of type
// int32, that might be used as FP Fast Math Mode.
std::vector<const MachineInstr *> SPIRVFloatTypes;
// Hashtable to associate immediate values with the constant holding them.
std::unordered_map<unsigned, const MachineInstr *> ConstMap;
for (const MachineInstr *MI : MAI->getMSInstrs(SPIRV::MB_TypeConstVars)) {
// Skip if the instruction is not OpTypeFloat or OpConstant.
unsigned OpCode = MI->getOpcode();
if (OpCode != SPIRV::OpTypeFloat && OpCode != SPIRV::OpConstantI &&
OpCode != SPIRV::OpConstantNull)
continue;

// Collect the SPIRV type if it's a float.
if (OpCode == SPIRV::OpTypeFloat) {
SPIRVFloatTypes.push_back(MI);
} else {
// Check if the constant is int32, if not skip it.
const MachineRegisterInfo &MRI = MI->getMF()->getRegInfo();
MachineInstr *TypeMI = MRI.getVRegDef(MI->getOperand(1).getReg());
if (!TypeMI || TypeMI->getOpcode() != SPIRV::OpTypeInt ||
TypeMI->getOperand(1).getImm() != 32)
continue;

if (OpCode == SPIRV::OpConstantI)
ConstMap[MI->getOperand(2).getImm()] = MI;
else
ConstMap[0] = MI;
}
}

for (const auto &[Func, FPFastMathDefaultInfoVec] :
MAI->FPFastMathDefaultInfoMap) {
if (FPFastMathDefaultInfoVec.empty())
continue;

for (const MachineInstr *MI : SPIRVFloatTypes) {
unsigned OpTypeFloatSize = MI->getOperand(1).getImm();
unsigned Index = computeFPFastMathDefaultInfoVecIndex(OpTypeFloatSize);
assert(Index < FPFastMathDefaultInfoVec.size() &&
"Index out of bounds for FPFastMathDefaultInfoVec");
const auto &FPFastMathDefaultInfo = FPFastMathDefaultInfoVec[Index];
assert(FPFastMathDefaultInfo.Ty &&
"Expected target type for FPFastMathDefaultInfo");
assert(FPFastMathDefaultInfo.Ty->getScalarSizeInBits() ==
OpTypeFloatSize &&
"Mismatched float type size");
MCInst Inst;
Inst.setOpcode(SPIRV::OpExecutionModeId);
MCRegister FuncReg = MAI->getFuncReg(Func);
assert(FuncReg.isValid());
Inst.addOperand(MCOperand::createReg(FuncReg));
Inst.addOperand(
MCOperand::createImm(SPIRV::ExecutionMode::FPFastMathDefault));
MCRegister TypeReg =
MAI->getRegisterAlias(MI->getMF(), MI->getOperand(0).getReg());
Inst.addOperand(MCOperand::createReg(TypeReg));
unsigned Flags = FPFastMathDefaultInfo.FastMathFlags;
if (FPFastMathDefaultInfo.ContractionOff &&
(Flags & SPIRV::FPFastMathMode::AllowContract))
report_fatal_error(
"Conflicting FPFastMathFlags: ContractionOff and AllowContract");

if (FPFastMathDefaultInfo.SignedZeroInfNanPreserve &&
!(Flags &
(SPIRV::FPFastMathMode::NotNaN | SPIRV::FPFastMathMode::NotInf |
SPIRV::FPFastMathMode::NSZ))) {
if (FPFastMathDefaultInfo.FPFastMathDefault)
report_fatal_error("Conflicting FPFastMathFlags: "
"SignedZeroInfNanPreserve but at least one of "
"NotNaN/NotInf/NSZ is enabled.");
}

// Don't emit if none of the execution modes was used.
if (Flags == SPIRV::FPFastMathMode::None &&
!FPFastMathDefaultInfo.ContractionOff &&
!FPFastMathDefaultInfo.SignedZeroInfNanPreserve &&
!FPFastMathDefaultInfo.FPFastMathDefault)
continue;

// Retrieve the constant instruction for the immediate value.
auto It = ConstMap.find(Flags);
if (It == ConstMap.end())
report_fatal_error("Expected constant instruction for FP Fast Math "
"Mode operand of FPFastMathDefault execution mode.");
const MachineInstr *ConstMI = It->second;
MCRegister ConstReg = MAI->getRegisterAlias(
ConstMI->getMF(), ConstMI->getOperand(0).getReg());
Inst.addOperand(MCOperand::createReg(ConstReg));
outputMCInst(Inst);
}
}
}

void SPIRVAsmPrinter::outputModuleSections() {
const Module *M = MMI->getModule();
// Get the global subtarget to output module-level info.
Expand All @@ -614,15 +797,17 @@ void SPIRVAsmPrinter::outputModuleSections() {
MAI = &SPIRVModuleAnalysis::MAI;
assert(ST && TII && MAI && M && "Module analysis is required");
// Output instructions according to the Logical Layout of a Module:
// 1,2. All OpCapability instructions, then optional OpExtension instructions.
// 1,2. All OpCapability instructions, then optional OpExtension
// instructions.
outputGlobalRequirements();
// 3. Optional OpExtInstImport instructions.
outputOpExtInstImports(*M);
// 4. The single required OpMemoryModel instruction.
outputOpMemoryModel();
// 5. All entry point declarations, using OpEntryPoint.
outputEntryPoints();
// 6. Execution-mode declarations, using OpExecutionMode or OpExecutionModeId.
// 6. Execution-mode declarations, using OpExecutionMode or
// OpExecutionModeId.
outputExecutionMode(*M);
// 7a. Debug: all OpString, OpSourceExtension, OpSource, and
// OpSourceContinued, without forward references.
Expand Down
26 changes: 23 additions & 3 deletions llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1127,11 +1127,24 @@ static unsigned getNumSizeComponents(SPIRVType *imgType) {

static bool generateExtInst(const SPIRV::IncomingCall *Call,
MachineIRBuilder &MIRBuilder,
SPIRVGlobalRegistry *GR) {
SPIRVGlobalRegistry *GR, const CallBase &CB) {
// Lookup the extended instruction number in the TableGen records.
const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
uint32_t Number =
SPIRV::lookupExtendedBuiltin(Builtin->Name, Builtin->Set)->Number;
// fmin_common and fmax_common are now deprecated, and we should use fmin and
// fmax with NotInf and NotNaN flags instead. Keep original number to add
// later the NoNans and NoInfs flags.
uint32_t OrigNumber = Number;
const SPIRVSubtarget &ST =
cast<SPIRVSubtarget>(MIRBuilder.getMF().getSubtarget());
if (ST.canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2) &&
(Number == SPIRV::OpenCLExtInst::fmin_common ||
Number == SPIRV::OpenCLExtInst::fmax_common)) {
Number = (Number == SPIRV::OpenCLExtInst::fmin_common)
? SPIRV::OpenCLExtInst::fmin
: SPIRV::OpenCLExtInst::fmax;
}

// Build extended instruction.
auto MIB =
Expand All @@ -1143,6 +1156,13 @@ static bool generateExtInst(const SPIRV::IncomingCall *Call,

for (auto Argument : Call->Arguments)
MIB.addUse(Argument);
MIB.getInstr()->copyIRFlags(CB);
if (OrigNumber == SPIRV::OpenCLExtInst::fmin_common ||
OrigNumber == SPIRV::OpenCLExtInst::fmax_common) {
// Add NoNans and NoInfs flags to fmin/fmax instruction.
MIB.getInstr()->setFlag(MachineInstr::MIFlag::FmNoNans);
MIB.getInstr()->setFlag(MachineInstr::MIFlag::FmNoInfs);
}
return true;
}

Expand Down Expand Up @@ -2859,7 +2879,7 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
MachineIRBuilder &MIRBuilder,
const Register OrigRet, const Type *OrigRetTy,
const SmallVectorImpl<Register> &Args,
SPIRVGlobalRegistry *GR) {
SPIRVGlobalRegistry *GR, const CallBase &CB) {
LLVM_DEBUG(dbgs() << "Lowering builtin call: " << DemangledCall << "\n");

// Lookup the builtin in the TableGen records.
Expand All @@ -2882,7 +2902,7 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
// Match the builtin with implementation based on the grouping.
switch (Call->Builtin->Group) {
case SPIRV::Extended:
return generateExtInst(Call.get(), MIRBuilder, GR);
return generateExtInst(Call.get(), MIRBuilder, GR, CB);
case SPIRV::Relational:
return generateRelationalInst(Call.get(), MIRBuilder, GR);
case SPIRV::Group:
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Target/SPIRV/SPIRVBuiltins.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
MachineIRBuilder &MIRBuilder,
const Register OrigRet, const Type *OrigRetTy,
const SmallVectorImpl<Register> &Args,
SPIRVGlobalRegistry *GR);
SPIRVGlobalRegistry *GR, const CallBase &CB);

/// Helper function for finding a builtin function attributes
/// by a demangled function name. Defined in SPIRVBuiltins.cpp.
Expand Down
6 changes: 3 additions & 3 deletions llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -641,9 +641,9 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
GR->getPointerSize()));
}
}
if (auto Res =
SPIRV::lowerBuiltin(DemangledName, ST->getPreferredInstructionSet(),
MIRBuilder, ResVReg, OrigRetTy, ArgVRegs, GR))
if (auto Res = SPIRV::lowerBuiltin(
DemangledName, ST->getPreferredInstructionSet(), MIRBuilder,
ResVReg, OrigRetTy, ArgVRegs, GR, *Info.CB))
return *Res;
}

Expand Down
Loading