Skip to content
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
cbe7a07
Initial support for SPV_KHR_float_controls2.
maarquitos14 Jun 25, 2025
4c325ac
Fix test failures.
maarquitos14 Jun 30, 2025
a664820
Deprecations only apply when the extension is enabled.
maarquitos14 Jul 3, 2025
0648840
Add extension specific tests.
maarquitos14 Jul 3, 2025
34c0628
Further work to apply deprecations only when extension is enabled.
maarquitos14 Jul 3, 2025
1d87705
Remove leftover comment.
maarquitos14 Jul 3, 2025
de07e83
Yet more work to apply deprecations only when extension is enabled.
maarquitos14 Jul 3, 2025
3edf82a
Renamings and bugfixes.
maarquitos14 Jul 3, 2025
c0ef106
Renamings, bugfixes and undo undesired changes.
maarquitos14 Jul 3, 2025
97720d5
Undo undesired changes.
maarquitos14 Jul 3, 2025
3792464
Merge remote-tracking branch 'origin/main' into maronas/float-controls2
maarquitos14 Jul 4, 2025
eccd27a
Fix after-merge bugs.
maarquitos14 Jul 4, 2025
868c6a4
Fix clang-format issues.
maarquitos14 Jul 4, 2025
21c17d1
Replace undef with poison for placeholders.
maarquitos14 Jul 4, 2025
cfdc210
Add SPIR-V Instructions Mapped to LLVM Metadata section.
maarquitos14 Jul 7, 2025
9b8b43d
Address code review.
maarquitos14 Jul 7, 2025
a353295
Address code review feedback.
maarquitos14 Jul 16, 2025
f1403d6
Fix documentation build.
maarquitos14 Jul 17, 2025
37a3abe
Fix another error in documentation build.
maarquitos14 Jul 17, 2025
d76db3e
Ignore spirv.Decorations metadata related to FP flags.
maarquitos14 Aug 1, 2025
00548e5
Merge remote-tracking branch 'origin/main' into maronas/float-controls2
maarquitos14 Aug 1, 2025
f2aefbf
Address code review feedback.
maarquitos14 Aug 8, 2025
6a77fcd
Merge remote-tracking branch 'origin/main' into maronas/float-controls2
maarquitos14 Aug 8, 2025
848611a
Fix clang-format.
maarquitos14 Aug 8, 2025
527311d
Merge remote-tracking branch 'origin/main' into maronas/float-controls2
maarquitos14 Aug 26, 2025
bde0292
Address reviewer feedback.
maarquitos14 Aug 26, 2025
2295de4
Merge branch 'main' into maronas/float-controls2
maarquitos14 Aug 26, 2025
6420f43
Use constant instruction for flags.
maarquitos14 Sep 2, 2025
0bf9ca6
Enable spirv-val and update tests accordingly.
maarquitos14 Sep 2, 2025
5f26101
Merge remote-tracking branch 'origin/main' into maronas/float-controls2
maarquitos14 Sep 2, 2025
368793f
Merge branch 'maronas/float-controls2-conflict' into maronas/float-co…
maarquitos14 Sep 2, 2025
cb1065e
Rename new function to better represent what it does.
maarquitos14 Sep 2, 2025
d337ef2
Merge branch 'main' into maronas/float-controls2
maarquitos14 Sep 2, 2025
c19d296
Fix clang-format issues.
maarquitos14 Sep 2, 2025
c9d226c
Merge branch 'maronas/float-controls2-conflict' into maronas/float-co…
maarquitos14 Sep 2, 2025
b5ce97a
Remove unused variable.
maarquitos14 Sep 2, 2025
0cd1923
Merge remote-tracking branch 'origin/main' into maronas/float-controls2
maarquitos14 Sep 2, 2025
c973c1a
Fix test failure.
maarquitos14 Sep 2, 2025
f684280
Apply code review suggestions.
maarquitos14 Sep 11, 2025
931431b
Merge remote-tracking branch 'origin/main' into maronas/float-controls2
maarquitos14 Sep 11, 2025
f2af0d4
Fix clang-format issue.
maarquitos14 Sep 11, 2025
910cd69
Merge branch 'main' into maronas/float-controls2
MrSidims Sep 29, 2025
b487250
Merge branch 'main' into maronas/float-controls2
maarquitos14 Sep 29, 2025
9768733
Merge remote-tracking branch 'origin/main' into HEAD
maarquitos14 Sep 30, 2025
fe49f8b
Address test failures caused by main branch merge.
maarquitos14 Sep 30, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion llvm/docs/SPIRVUsage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ list of supported SPIR-V extensions, sorted alphabetically by their extension na
* - ``SPV_INTEL_int4``
- Adds support for 4-bit integer type, and allow this type to be used in cooperative matrices.
* - ``SPV_KHR_float_controls2``
- Adds ability to specify the floating-point environment in shaders. It can be used on whole modules and individual instructions.
- Adds execution modes and decorations to control floating-point computations in both kernels and shaders. It can be used on whole modules and individual instructions.

To enable multiple extensions, list them separated by comma. For example, to enable support for atomic operations on floating-point numbers and arbitrary precision integers, use:

Expand Down
183 changes: 175 additions & 8 deletions llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ class SPIRVAsmPrinter : public AsmPrinter {
void outputExecutionMode(const Module &M);
void outputAnnotations(const Module &M);
void outputModuleSections();
void outputFPFastMathDefaultInfo();
bool isHidden() {
return MF->getFunction()
.getFnAttribute(SPIRV_BACKEND_SERVICE_FUN_NAME)
Expand Down Expand Up @@ -458,6 +459,7 @@ void SPIRVAsmPrinter::outputExecutionModeFromMDNode(
unsigned ExpectMDOps, int64_t DefVal) {
MCInst Inst;
Inst.setOpcode(SPIRV::OpExecutionMode);

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change

Inst.addOperand(MCOperand::createReg(Reg));
Inst.addOperand(MCOperand::createImm(static_cast<unsigned>(EM)));
addOpsFromMDNode(Node, Inst, MAI);
Expand Down Expand Up @@ -497,11 +499,27 @@ void SPIRVAsmPrinter::outputExecutionMode(const Module &M) {
NamedMDNode *Node = M.getNamedMetadata("spirv.ExecutionMode");
if (Node) {
for (unsigned i = 0; i < Node->getNumOperands(); i++) {
// If SPV_KHR_float_controls2 is enabled and we find any of
// FPFastMathDefault, ContractionOff or SignedZeroInfNanPreserve execution
// modes, skip it, it'll be done somewhere else.
if (ST->canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2)) {
const auto EM =
cast<ConstantInt>(
cast<ConstantAsMetadata>((Node->getOperand(i))->getOperand(1))
->getValue())
->getZExtValue();
if (EM == SPIRV::ExecutionMode::FPFastMathDefault ||
EM == SPIRV::ExecutionMode::ContractionOff ||
EM == SPIRV::ExecutionMode::SignedZeroInfNanPreserve)
continue;
}

MCInst Inst;
Inst.setOpcode(SPIRV::OpExecutionMode);
addOpsFromMDNode(cast<MDNode>(Node->getOperand(i)), Inst, MAI);
outputMCInst(Inst);
}
outputFPFastMathDefaultInfo();
}
for (auto FI = M.begin(), E = M.end(); FI != E; ++FI) {
const Function &F = *FI;
Expand Down Expand Up @@ -551,12 +569,85 @@ void SPIRVAsmPrinter::outputExecutionMode(const Module &M) {
}
if (ST->isKernel() && !M.getNamedMetadata("spirv.ExecutionMode") &&
!M.getNamedMetadata("opencl.enable.FP_CONTRACT")) {
MCInst Inst;
Inst.setOpcode(SPIRV::OpExecutionMode);
Inst.addOperand(MCOperand::createReg(FReg));
unsigned EM = static_cast<unsigned>(SPIRV::ExecutionMode::ContractionOff);
Inst.addOperand(MCOperand::createImm(EM));
outputMCInst(Inst);
if (ST->canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2)) {
// When SPV_KHR_float_controls2 is enabled, ContractionOff is
// deprecated. We need to use FPFastMathDefault with the appropriate
// flags instead. Since FPFastMathDefault takes a target type, we need
// to emit it for each floating-point type to match the effect of
// ContractionOff. As of now, there are 4 FP types: fp16, fp32, fp64 and
// fp128.
constexpr size_t NumFPTypes = 4;
for (size_t i = 0; i < NumFPTypes; ++i) {
MCInst Inst;
Inst.setOpcode(SPIRV::OpExecutionMode);
Inst.addOperand(MCOperand::createReg(FReg));
unsigned EM =
static_cast<unsigned>(SPIRV::ExecutionMode::FPFastMathDefault);
Inst.addOperand(MCOperand::createImm(EM));

Type *TargetType = nullptr;
switch (i) {
case 0:
TargetType = Type::getHalfTy(M.getContext());
break;
case 1:
TargetType = Type::getFloatTy(M.getContext());
break;
case 2:
TargetType = Type::getDoubleTy(M.getContext());
break;
case 3:
TargetType = Type::getFP128Ty(M.getContext());
break;
}
assert(TargetType && "Invalid target type for FPFastMathDefault");

// Find the SPIRV type matching the target type. We'll go over all the
// TypeConstVars instructions in the SPIRV module and find the one
// that matches the target type. We know the target type is a
// floating-point type, so we can skip anything different than
// OpTypeFloat. Then, we need to check the bitwidth.
bool SPIRVTypeFound = false;
for (const MachineInstr *MI :
MAI->getMSInstrs(SPIRV::MB_TypeConstVars)) {
// Skip if the instruction is not OpTypeFloat.
if (MI->getOpcode() != SPIRV::OpTypeFloat)
continue;

// Skip if TargetTy bitwidth doesn't match MI->getOperand(1), which
// is the SPIRV type bit width.
if (TargetType->getScalarSizeInBits() != MI->getOperand(1).getImm())
continue;

SPIRVTypeFound = true;
const MachineFunction *MF = MI->getMF();
MCRegister TypeReg =
MAI->getRegisterAlias(MF, MI->getOperand(0).getReg());
Inst.addOperand(MCOperand::createReg(TypeReg));
}

if (!SPIRVTypeFound) {
// The module does not contain this FP type, so we don't need to
// emit FPFastMathDefault for it.
continue;
}
// We only end up here because there is no "spirv.ExecutionMode"
// metadata, so that means no FPFastMathDefault. Therefore, we only
// need to make sure AllowContract is set to 0, as the rest of flags.
// We still need to emit the OpExecutionMode instruction, otherwise
// it's up to the client API to define the flags.
Inst.addOperand(MCOperand::createImm(SPIRV::FPFastMathMode::None));
outputMCInst(Inst);
}
} else {
MCInst Inst;
Inst.setOpcode(SPIRV::OpExecutionMode);
Inst.addOperand(MCOperand::createReg(FReg));
unsigned EM =
static_cast<unsigned>(SPIRV::ExecutionMode::ContractionOff);
Inst.addOperand(MCOperand::createImm(EM));
outputMCInst(Inst);
}
}
}
}
Expand Down Expand Up @@ -603,6 +694,80 @@ void SPIRVAsmPrinter::outputAnnotations(const Module &M) {
}
}

void SPIRVAsmPrinter::outputFPFastMathDefaultInfo() {
for (const auto &[Func, FPFastMathDefaultInfoVec] :
MAI->FPFastMathDefaultInfoMap) {
for (const auto &FPFastMathDefaultInfo : FPFastMathDefaultInfoVec) {
MCInst Inst;
Inst.setOpcode(SPIRV::OpExecutionMode);
MCRegister FuncReg = MAI->getFuncReg(Func);
assert(FuncReg.isValid());
Inst.addOperand(MCOperand::createReg(FuncReg));
Inst.addOperand(
MCOperand::createImm(SPIRV::ExecutionMode::FPFastMathDefault));

// Find the SPIRV type matching the target type. We'll go over all the
// TypeConstVars instructions in the SPIRV module and find the one that
// matches the target type. We know the target type is a floating-point
// type, so we can skip anything different than OpTypeFloat. Then, we
// need to check the bitwidth.
const Type *TargetTy = FPFastMathDefaultInfo.Ty;
assert(TargetTy && "Expected target type");
bool SPIRVTypeFound = false;
for (const MachineInstr *MI : MAI->getMSInstrs(SPIRV::MB_TypeConstVars)) {
// Skip if the instruction is not OpTypeFloat.
if (MI->getOpcode() != SPIRV::OpTypeFloat)
continue;

// Skip if TargetTy bitwidth doesn't match MI->getOperand(1), which is
// the SPIRV type bit width.
if (TargetTy->getScalarSizeInBits() != MI->getOperand(1).getImm())
continue;

SPIRVTypeFound = true;
const MachineFunction *MF = MI->getMF();
MCRegister TypeReg =
MAI->getRegisterAlias(MF, MI->getOperand(0).getReg());
Inst.addOperand(MCOperand::createReg(TypeReg));
}
if (!SPIRVTypeFound) {
// The module does not contain this FP type, so we don't need to emit
// FPFastMathDefault for it.
continue;
}

unsigned Flags = FPFastMathDefaultInfo.FastMathFlags;
if (FPFastMathDefaultInfo.ContractionOff &&
(Flags & SPIRV::FPFastMathMode::AllowContract) &&
FPFastMathDefaultInfo.FPFastMathDefault)
report_fatal_error(
"Conflicting FPFastMathFlags: ContractionOff and AllowContract");

if (FPFastMathDefaultInfo.SignedZeroInfNanPreserve &&
!(Flags &
(SPIRV::FPFastMathMode::NotNaN | SPIRV::FPFastMathMode::NotInf |
SPIRV::FPFastMathMode::NSZ))) {
if (FPFastMathDefaultInfo.FPFastMathDefault)
report_fatal_error("Conflicting FPFastMathFlags: "
"SignedZeroInfNanPreserve but at least one of "
"NotNaN/NotInf/NSZ is disabled.");

Flags |= SPIRV::FPFastMathMode::NotNaN | SPIRV::FPFastMathMode::NotInf |
SPIRV::FPFastMathMode::NSZ;
}

// Don't emit if none of the execution modes was used.
if (Flags == SPIRV::FPFastMathMode::None &&
!FPFastMathDefaultInfo.ContractionOff &&
!FPFastMathDefaultInfo.SignedZeroInfNanPreserve &&
!FPFastMathDefaultInfo.FPFastMathDefault)
continue;
Inst.addOperand(MCOperand::createImm(Flags));
outputMCInst(Inst);
}
}
}

void SPIRVAsmPrinter::outputModuleSections() {
const Module *M = MMI->getModule();
// Get the global subtarget to output module-level info.
Expand All @@ -611,15 +776,17 @@ void SPIRVAsmPrinter::outputModuleSections() {
MAI = &SPIRVModuleAnalysis::MAI;
assert(ST && TII && MAI && M && "Module analysis is required");
// Output instructions according to the Logical Layout of a Module:
// 1,2. All OpCapability instructions, then optional OpExtension instructions.
// 1,2. All OpCapability instructions, then optional OpExtension
// instructions.
outputGlobalRequirements();
// 3. Optional OpExtInstImport instructions.
outputOpExtInstImports(*M);
// 4. The single required OpMemoryModel instruction.
outputOpMemoryModel();
// 5. All entry point declarations, using OpEntryPoint.
outputEntryPoints();
// 6. Execution-mode declarations, using OpExecutionMode or OpExecutionModeId.
// 6. Execution-mode declarations, using OpExecutionMode or
// OpExecutionModeId.
outputExecutionMode(*M);
// 7a. Debug: all OpString, OpSourceExtension, OpSource, and
// OpSourceContinued, without forward references.
Expand Down
29 changes: 25 additions & 4 deletions llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -697,7 +697,8 @@ static bool buildAtomicStoreInst(const SPIRV::IncomingCall *Call,
MachineIRBuilder &MIRBuilder,
SPIRVGlobalRegistry *GR) {
if (Call->isSpirvOp())
return buildOpFromWrapper(MIRBuilder, SPIRV::OpAtomicStore, Call, Register(0));
return buildOpFromWrapper(MIRBuilder, SPIRV::OpAtomicStore, Call,
Register(0));

Register ScopeRegister =
buildConstantIntReg32(SPIRV::Scope::Device, MIRBuilder, GR);
Expand Down Expand Up @@ -1125,11 +1126,24 @@ static unsigned getNumSizeComponents(SPIRVType *imgType) {

static bool generateExtInst(const SPIRV::IncomingCall *Call,
MachineIRBuilder &MIRBuilder,
SPIRVGlobalRegistry *GR) {
SPIRVGlobalRegistry *GR, const CallBase &CB) {
// Lookup the extended instruction number in the TableGen records.
const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
uint32_t Number =
SPIRV::lookupExtendedBuiltin(Builtin->Name, Builtin->Set)->Number;
// fmin_common and fmax_common are now deprecated, and we should use fmin and
// fmax with NotInf and NotNaN flags instead. Keep original number to add
// later the NoNans and NoInfs flags.
uint32_t OrigNumber = Number;
const SPIRVSubtarget &ST =
cast<SPIRVSubtarget>(MIRBuilder.getMF().getSubtarget());
if (ST.canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2) &&
(Number == SPIRV::OpenCLExtInst::fmin_common ||
Number == SPIRV::OpenCLExtInst::fmax_common)) {
Number = (Number == SPIRV::OpenCLExtInst::fmin_common)
? SPIRV::OpenCLExtInst::fmin
: SPIRV::OpenCLExtInst::fmax;
}

// Build extended instruction.
auto MIB =
Expand All @@ -1141,6 +1155,13 @@ static bool generateExtInst(const SPIRV::IncomingCall *Call,

for (auto Argument : Call->Arguments)
MIB.addUse(Argument);
MIB.getInstr()->copyIRFlags(CB);
if (OrigNumber == SPIRV::OpenCLExtInst::fmin_common ||
OrigNumber == SPIRV::OpenCLExtInst::fmax_common) {
// Add NoNans and NoInfs flags to fmin/fmax instruction.
MIB.getInstr()->setFlag(MachineInstr::MIFlag::FmNoNans);
MIB.getInstr()->setFlag(MachineInstr::MIFlag::FmNoInfs);
}
return true;
}

Expand Down Expand Up @@ -2844,7 +2865,7 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
MachineIRBuilder &MIRBuilder,
const Register OrigRet, const Type *OrigRetTy,
const SmallVectorImpl<Register> &Args,
SPIRVGlobalRegistry *GR) {
SPIRVGlobalRegistry *GR, const CallBase &CB) {
LLVM_DEBUG(dbgs() << "Lowering builtin call: " << DemangledCall << "\n");

// Lookup the builtin in the TableGen records.
Expand All @@ -2867,7 +2888,7 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
// Match the builtin with implementation based on the grouping.
switch (Call->Builtin->Group) {
case SPIRV::Extended:
return generateExtInst(Call.get(), MIRBuilder, GR);
return generateExtInst(Call.get(), MIRBuilder, GR, CB);
case SPIRV::Relational:
return generateRelationalInst(Call.get(), MIRBuilder, GR);
case SPIRV::Group:
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Target/SPIRV/SPIRVBuiltins.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
MachineIRBuilder &MIRBuilder,
const Register OrigRet, const Type *OrigRetTy,
const SmallVectorImpl<Register> &Args,
SPIRVGlobalRegistry *GR);
SPIRVGlobalRegistry *GR, const CallBase &CB);

/// Helper function for finding a builtin function attributes
/// by a demangled function name. Defined in SPIRVBuiltins.cpp.
Expand Down
6 changes: 3 additions & 3 deletions llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -641,9 +641,9 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
GR->getPointerSize()));
}
}
if (auto Res =
SPIRV::lowerBuiltin(DemangledName, ST->getPreferredInstructionSet(),
MIRBuilder, ResVReg, OrigRetTy, ArgVRegs, GR))
if (auto Res = SPIRV::lowerBuiltin(
DemangledName, ST->getPreferredInstructionSet(), MIRBuilder,
ResVReg, OrigRetTy, ArgVRegs, GR, *Info.CB))
return *Res;
}

Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -794,7 +794,7 @@ Register SPIRVGlobalRegistry::buildGlobalVariable(
// arguments.
MDNode *GVarMD = nullptr;
if (GVar && (GVarMD = GVar->getMetadata("spirv.Decorations")) != nullptr)
buildOpSpirvDecorations(Reg, MIRBuilder, GVarMD);
buildOpSpirvDecorations(Reg, MIRBuilder, GVarMD, ST);

return Reg;
}
Expand Down
21 changes: 20 additions & 1 deletion llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,8 @@ bool SPIRVInstrInfo::isHeaderInstr(const MachineInstr &MI) const {
}
}

bool SPIRVInstrInfo::canUseFastMathFlags(const MachineInstr &MI) const {
bool SPIRVInstrInfo::canUseFastMathFlags(const MachineInstr &MI,
bool KHRFloatControls2) const {
switch (MI.getOpcode()) {
case SPIRV::OpFAddS:
case SPIRV::OpFSubS:
Expand All @@ -144,6 +145,24 @@ bool SPIRVInstrInfo::canUseFastMathFlags(const MachineInstr &MI) const {
case SPIRV::OpFRemV:
case SPIRV::OpFMod:
return true;
case SPIRV::OpFNegateV:
case SPIRV::OpFNegate:
case SPIRV::OpOrdered:
case SPIRV::OpUnordered:
case SPIRV::OpFOrdEqual:
case SPIRV::OpFOrdNotEqual:
case SPIRV::OpFOrdLessThan:
case SPIRV::OpFOrdLessThanEqual:
case SPIRV::OpFOrdGreaterThan:
case SPIRV::OpFOrdGreaterThanEqual:
case SPIRV::OpFUnordEqual:
case SPIRV::OpFUnordNotEqual:
case SPIRV::OpFUnordLessThan:
case SPIRV::OpFUnordLessThanEqual:
case SPIRV::OpFUnordGreaterThan:
case SPIRV::OpFUnordGreaterThanEqual:
case SPIRV::OpExtInst:
return KHRFloatControls2 ? true : false;
default:
return false;
}
Expand Down
3 changes: 2 additions & 1 deletion llvm/lib/Target/SPIRV/SPIRVInstrInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ class SPIRVInstrInfo : public SPIRVGenInstrInfo {
bool isTypeDeclInstr(const MachineInstr &MI) const;
bool isDecorationInstr(const MachineInstr &MI) const;
bool isAliasingInstr(const MachineInstr &MI) const;
bool canUseFastMathFlags(const MachineInstr &MI) const;
bool canUseFastMathFlags(const MachineInstr &MI,
bool KHRFloatControls2) const;
bool canUseNSW(const MachineInstr &MI) const;
bool canUseNUW(const MachineInstr &MI) const;

Expand Down
Loading
Loading