Skip to content

Commit ef2f432

Browse files
alan-bakerdneto0
andauthored
Add support for SPV_KHR_float_controls2 (KhronosGroup#5543)
* Test asm/dis for SPV_KHR_float_controls2 * SPV_KHR_float_controls2 validation --------- Co-authored-by: David Neto <[email protected]>
1 parent de3d5ac commit ef2f432

9 files changed

+961
-23
lines changed

source/val/validate.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,10 @@ spv_result_t ValidateEntryPoints(ValidationState_t& _) {
141141
}
142142
}
143143

144+
if (auto error = ValidateFloatControls2(_)) {
145+
return error;
146+
}
147+
144148
return SPV_SUCCESS;
145149
}
146150

source/val/validate.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,18 @@ spv_result_t ValidateAdjacency(ValidationState_t& _);
8282
/// @return SPV_SUCCESS if no errors are found.
8383
spv_result_t ValidateInterfaces(ValidationState_t& _);
8484

85+
/// @brief Validates entry point call tree requirements of
86+
/// SPV_KHR_float_controls2
87+
///
88+
/// Checks that no entry point using FPFastMathDefault uses:
89+
/// * FPFastMathMode Fast
90+
/// * NoContraction
91+
///
92+
/// @param[in] _ the validation state of the module
93+
///
94+
/// @return SPV_SUCCESS if no errors are found.
95+
spv_result_t ValidateFloatControls2(ValidationState_t& _);
96+
8597
/// @brief Validates memory instructions
8698
///
8799
/// @param[in] _ the validation state of the module

source/val/validate_annotation.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,34 @@ spv_result_t ValidateDecorate(ValidationState_t& _, const Instruction* inst) {
267267
}
268268
}
269269

270+
if (decoration == spv::Decoration::FPFastMathMode) {
271+
if (_.HasDecoration(target_id, spv::Decoration::NoContraction)) {
272+
return _.diag(SPV_ERROR_INVALID_ID, inst)
273+
<< "FPFastMathMode and NoContraction cannot decorate the same "
274+
"target";
275+
}
276+
auto mask = inst->GetOperandAs<spv::FPFastMathModeMask>(2);
277+
if ((mask & spv::FPFastMathModeMask::AllowTransform) !=
278+
spv::FPFastMathModeMask::MaskNone &&
279+
((mask & (spv::FPFastMathModeMask::AllowContract |
280+
spv::FPFastMathModeMask::AllowReassoc)) !=
281+
(spv::FPFastMathModeMask::AllowContract |
282+
spv::FPFastMathModeMask::AllowReassoc))) {
283+
return _.diag(SPV_ERROR_INVALID_DATA, inst)
284+
<< "AllowReassoc and AllowContract must be specified when "
285+
"AllowTransform is specified";
286+
}
287+
}
288+
289+
// This is checked from both sides since we register decorations as we go.
290+
if (decoration == spv::Decoration::NoContraction) {
291+
if (_.HasDecoration(target_id, spv::Decoration::FPFastMathMode)) {
292+
return _.diag(SPV_ERROR_INVALID_ID, inst)
293+
<< "FPFastMathMode and NoContraction cannot decorate the same "
294+
"target";
295+
}
296+
}
297+
270298
if (DecorationTakesIdParameters(decoration)) {
271299
return _.diag(SPV_ERROR_INVALID_ID, inst)
272300
<< "Decorations taking ID parameters may not be used with "

source/val/validate_instruction.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -470,7 +470,8 @@ spv_result_t InstructionPass(ValidationState_t& _, const Instruction* inst) {
470470
}
471471
_.set_addressing_model(inst->GetOperandAs<spv::AddressingModel>(0));
472472
_.set_memory_model(inst->GetOperandAs<spv::MemoryModel>(1));
473-
} else if (opcode == spv::Op::OpExecutionMode) {
473+
} else if (opcode == spv::Op::OpExecutionMode ||
474+
opcode == spv::Op::OpExecutionModeId) {
474475
const uint32_t entry_point = inst->word(1);
475476
_.RegisterExecutionModeForEntryPoint(entry_point,
476477
spv::ExecutionMode(inst->word(2)));

source/val/validate_mode_setting.cpp

Lines changed: 157 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -340,29 +340,92 @@ spv_result_t ValidateExecutionMode(ValidationState_t& _,
340340

341341
const auto mode = inst->GetOperandAs<spv::ExecutionMode>(1);
342342
if (inst->opcode() == spv::Op::OpExecutionModeId) {
343+
bool valid_mode = false;
344+
switch (mode) {
345+
case spv::ExecutionMode::SubgroupsPerWorkgroupId:
346+
case spv::ExecutionMode::LocalSizeHintId:
347+
case spv::ExecutionMode::LocalSizeId:
348+
case spv::ExecutionMode::FPFastMathDefault:
349+
valid_mode = true;
350+
break;
351+
default:
352+
valid_mode = false;
353+
break;
354+
}
355+
if (!valid_mode) {
356+
return _.diag(SPV_ERROR_INVALID_ID, inst)
357+
<< "OpExecutionModeId is only valid when the Mode operand is an "
358+
"execution mode that takes Extra Operands that are id "
359+
"operands.";
360+
}
361+
343362
size_t operand_count = inst->operands().size();
344363
for (size_t i = 2; i < operand_count; ++i) {
345-
const auto operand_id = inst->GetOperandAs<uint32_t>(2);
364+
const auto operand_id = inst->GetOperandAs<uint32_t>(i);
346365
const auto* operand_inst = _.FindDef(operand_id);
347-
if (mode == spv::ExecutionMode::SubgroupsPerWorkgroupId ||
348-
mode == spv::ExecutionMode::LocalSizeHintId ||
349-
mode == spv::ExecutionMode::LocalSizeId) {
350-
if (!spvOpcodeIsConstant(operand_inst->opcode())) {
351-
return _.diag(SPV_ERROR_INVALID_ID, inst)
352-
<< "For OpExecutionModeId all Extra Operand ids must be "
353-
"constant "
354-
"instructions.";
355-
}
356-
} else {
357-
return _.diag(SPV_ERROR_INVALID_ID, inst)
358-
<< "OpExecutionModeId is only valid when the Mode operand is an "
359-
"execution mode that takes Extra Operands that are id "
360-
"operands.";
366+
switch (mode) {
367+
case spv::ExecutionMode::SubgroupsPerWorkgroupId:
368+
case spv::ExecutionMode::LocalSizeHintId:
369+
case spv::ExecutionMode::LocalSizeId:
370+
if (!spvOpcodeIsConstant(operand_inst->opcode())) {
371+
return _.diag(SPV_ERROR_INVALID_ID, inst)
372+
<< "For OpExecutionModeId all Extra Operand ids must be "
373+
"constant instructions.";
374+
}
375+
break;
376+
case spv::ExecutionMode::FPFastMathDefault:
377+
if (i == 2) {
378+
if (!_.IsFloatScalarType(operand_id)) {
379+
return _.diag(SPV_ERROR_INVALID_ID, inst)
380+
<< "The Target Type operand must be a floating-point "
381+
"scalar type";
382+
}
383+
} else {
384+
bool is_int32 = false;
385+
bool is_const = false;
386+
uint32_t value = 0;
387+
std::tie(is_int32, is_const, value) =
388+
_.EvalInt32IfConst(operand_id);
389+
if (is_int32 && is_const) {
390+
// Valid values include up to 0x00040000 (AllowTransform).
391+
uint32_t invalid_mask = 0xfff80000;
392+
if ((invalid_mask & value) != 0) {
393+
return _.diag(SPV_ERROR_INVALID_ID, inst)
394+
<< "The Fast Math Default operand is an invalid bitmask "
395+
"value";
396+
}
397+
if (value &
398+
static_cast<uint32_t>(spv::FPFastMathModeMask::Fast)) {
399+
return _.diag(SPV_ERROR_INVALID_ID, inst)
400+
<< "The Fast Math Default operand must not include Fast";
401+
}
402+
const auto reassoc_contract =
403+
spv::FPFastMathModeMask::AllowContract |
404+
spv::FPFastMathModeMask::AllowReassoc;
405+
if ((value & static_cast<uint32_t>(
406+
spv::FPFastMathModeMask::AllowTransform)) != 0 &&
407+
((value & static_cast<uint32_t>(reassoc_contract)) !=
408+
static_cast<uint32_t>(reassoc_contract))) {
409+
return _.diag(SPV_ERROR_INVALID_ID, inst)
410+
<< "The Fast Math Default operand must include "
411+
"AllowContract and AllowReassoc when AllowTransform "
412+
"is specified";
413+
}
414+
} else {
415+
return _.diag(SPV_ERROR_INVALID_ID, inst)
416+
<< "The Fast Math Default operand must be a "
417+
"non-specialization constant";
418+
}
419+
}
420+
break;
421+
default:
422+
break;
361423
}
362424
}
363425
} else if (mode == spv::ExecutionMode::SubgroupsPerWorkgroupId ||
364426
mode == spv::ExecutionMode::LocalSizeHintId ||
365-
mode == spv::ExecutionMode::LocalSizeId) {
427+
mode == spv::ExecutionMode::LocalSizeId ||
428+
mode == spv::ExecutionMode::FPFastMathDefault) {
366429
return _.diag(SPV_ERROR_INVALID_DATA, inst)
367430
<< "OpExecutionMode is only valid when the Mode operand is an "
368431
"execution mode that takes no Extra Operands, or takes Extra "
@@ -579,6 +642,20 @@ spv_result_t ValidateExecutionMode(ValidationState_t& _,
579642
break;
580643
}
581644

645+
if (mode == spv::ExecutionMode::FPFastMathDefault) {
646+
const auto* modes = _.GetExecutionModes(entry_point_id);
647+
if (modes && modes->count(spv::ExecutionMode::ContractionOff)) {
648+
return _.diag(SPV_ERROR_INVALID_DATA, inst)
649+
<< "FPFastMathDefault and ContractionOff execution modes cannot "
650+
"be applied to the same entry point";
651+
}
652+
if (modes && modes->count(spv::ExecutionMode::SignedZeroInfNanPreserve)) {
653+
return _.diag(SPV_ERROR_INVALID_DATA, inst)
654+
<< "FPFastMathDefault and SignedZeroInfNanPreserve execution "
655+
"modes cannot be applied to the same entry point";
656+
}
657+
}
658+
582659
if (spvIsVulkanEnv(_.context()->target_env)) {
583660
if (mode == spv::ExecutionMode::OriginLowerLeft) {
584661
return _.diag(SPV_ERROR_INVALID_DATA, inst)
@@ -636,6 +713,70 @@ spv_result_t ValidateMemoryModel(ValidationState_t& _,
636713

637714
} // namespace
638715

716+
spv_result_t ValidateFloatControls2(ValidationState_t& _) {
717+
std::unordered_set<uint32_t> fp_fast_math_default_entry_points;
718+
for (auto entry_point : _.entry_points()) {
719+
const auto* exec_modes = _.GetExecutionModes(entry_point);
720+
if (exec_modes &&
721+
exec_modes->count(spv::ExecutionMode::FPFastMathDefault)) {
722+
fp_fast_math_default_entry_points.insert(entry_point);
723+
}
724+
}
725+
726+
std::vector<std::pair<const Instruction*, spv::Decoration>> worklist;
727+
for (const auto& inst : _.ordered_instructions()) {
728+
if (inst.opcode() != spv::Op::OpDecorate) {
729+
continue;
730+
}
731+
732+
const auto decoration = inst.GetOperandAs<spv::Decoration>(1);
733+
const auto target_id = inst.GetOperandAs<uint32_t>(0);
734+
const auto target = _.FindDef(target_id);
735+
if (decoration == spv::Decoration::NoContraction) {
736+
worklist.push_back(std::make_pair(target, decoration));
737+
} else if (decoration == spv::Decoration::FPFastMathMode) {
738+
auto mask = inst.GetOperandAs<spv::FPFastMathModeMask>(2);
739+
if ((mask & spv::FPFastMathModeMask::Fast) !=
740+
spv::FPFastMathModeMask::MaskNone) {
741+
worklist.push_back(std::make_pair(target, decoration));
742+
}
743+
}
744+
}
745+
746+
std::unordered_set<const Instruction*> visited;
747+
while (!worklist.empty()) {
748+
const auto inst = worklist.back().first;
749+
const auto decoration = worklist.back().second;
750+
worklist.pop_back();
751+
752+
if (!visited.insert(inst).second) {
753+
continue;
754+
}
755+
756+
const auto function = inst->function();
757+
if (function) {
758+
const auto& entry_points = _.FunctionEntryPoints(function->id());
759+
for (auto entry_point : entry_points) {
760+
if (fp_fast_math_default_entry_points.count(entry_point)) {
761+
const std::string dec = decoration == spv::Decoration::NoContraction
762+
? "NoContraction"
763+
: "FPFastMathMode Fast";
764+
return _.diag(SPV_ERROR_INVALID_DATA, inst)
765+
<< dec
766+
<< " cannot be used by an entry point with the "
767+
"FPFastMathDefault execution mode";
768+
}
769+
}
770+
} else {
771+
for (const auto& pair : inst->uses()) {
772+
worklist.push_back(std::make_pair(pair.first, decoration));
773+
}
774+
}
775+
}
776+
777+
return SPV_SUCCESS;
778+
}
779+
639780
spv_result_t ModeSettingPass(ValidationState_t& _, const Instruction* inst) {
640781
switch (inst->opcode()) {
641782
case spv::Op::OpEntryPoint:

test/operand_capabilities_test.cpp

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@
2121
#include "source/assembly_grammar.h"
2222
#include "source/enum_set.h"
2323
#include "source/operand.h"
24+
#include "source/spirv_target_env.h"
25+
#include "source/table.h"
26+
#include "spirv-tools/libspirv.h"
2427
#include "test/unit_spirv.h"
2528

2629
namespace spvtools {
@@ -58,15 +61,17 @@ struct EnumCapabilityCase {
5861
uint32_t value;
5962
CapabilitySet expected_capabilities;
6063
};
61-
// Emits an EnumCapabilityCase to the ostream, returning the ostream.
62-
inline std::ostream& operator<<(std::ostream& out,
63-
const EnumCapabilityCase& ecc) {
64-
out << "EnumCapabilityCase{ " << spvOperandTypeStr(ecc.type) << "("
65-
<< unsigned(ecc.type) << "), " << ecc.value << ", "
66-
<< ecc.expected_capabilities << "}";
64+
65+
// Emits an EnumCapabilityCase to the given output stream. This is used
66+
// to emit failure cases when they occur, which helps debug tests.
67+
inline std::ostream& operator<<(std::ostream& out, EnumCapabilityCase e) {
68+
out << "{" << spvOperandTypeStr(e.type) << " " << e.value << " "
69+
<< e.expected_capabilities << " }";
6770
return out;
6871
}
6972

73+
using EnvEnumCapabilityCase = std::tuple<spv_target_env, EnumCapabilityCase>;
74+
7075
// Test fixture for testing EnumCapabilityCases.
7176
using EnumCapabilityTest =
7277
TestWithParam<std::tuple<spv_target_env, EnumCapabilityCase>>;

test/text_to_binary.extension_test.cpp

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1264,5 +1264,41 @@ INSTANTIATE_TEST_SUITE_P(
12641264
{1, (uint32_t)spv::ExecutionMode::MaximallyReconvergesKHR})},
12651265
})));
12661266

1267+
// SPV_KHR_float_controls2
1268+
1269+
INSTANTIATE_TEST_SUITE_P(
1270+
SPV_KHR_float_controls2, ExtensionRoundTripTest,
1271+
Combine(
1272+
Values(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_5, SPV_ENV_VULKAN_1_0,
1273+
SPV_ENV_VULKAN_1_1, SPV_ENV_VULKAN_1_2, SPV_ENV_VULKAN_1_3),
1274+
ValuesIn(std::vector<AssemblyCase>{
1275+
{"OpExtension \"SPV_KHR_float_controls2\"\n",
1276+
MakeInstruction(spv::Op::OpExtension,
1277+
MakeVector("SPV_KHR_float_controls2"))},
1278+
{"OpCapability FloatControls2\n",
1279+
MakeInstruction(spv::Op::OpCapability,
1280+
{(uint32_t)spv::Capability::FloatControls2})},
1281+
{"OpExecutionMode %1 FPFastMathDefault %2 %3\n",
1282+
// The operands are: target type, flags constant
1283+
MakeInstruction(
1284+
spv::Op::OpExecutionMode,
1285+
{1, (uint32_t)spv::ExecutionMode::FPFastMathDefault, 2, 3})},
1286+
{"OpDecorate %1 FPFastMathMode AllowContract\n",
1287+
MakeInstruction(
1288+
spv::Op::OpDecorate,
1289+
{1, (uint32_t)spv::Decoration::FPFastMathMode,
1290+
(uint32_t)spv::FPFastMathModeMask::AllowContract})},
1291+
{"OpDecorate %1 FPFastMathMode AllowReassoc\n",
1292+
MakeInstruction(
1293+
spv::Op::OpDecorate,
1294+
{1, (uint32_t)spv::Decoration::FPFastMathMode,
1295+
(uint32_t)spv::FPFastMathModeMask::AllowReassoc})},
1296+
{"OpDecorate %1 FPFastMathMode AllowTransform\n",
1297+
MakeInstruction(
1298+
spv::Op::OpDecorate,
1299+
{1, (uint32_t)spv::Decoration::FPFastMathMode,
1300+
(uint32_t)spv::FPFastMathModeMask::AllowTransform})},
1301+
})));
1302+
12671303
} // namespace
12681304
} // namespace spvtools

0 commit comments

Comments
 (0)