Skip to content

Commit b45cd0d

Browse files
committed
Add EXT_mesh_shader validation support
1. Each OpEntryPoint with the MeshEXT Execution Model can have at most one global OpVariable of storage class TaskPayloadWorkgroupEXT. 2. PerPrimitiveEXT only be used on a memory object declaration or a member of a structure type 3. PerPrimitiveEXT only Input in Fragment and Output in MeshEXT 4. Added Mesh vulkan validation support for following rules: VUID-Layer-Layer-07039 VUID-PrimitiveId-PrimitiveId-07040,VUID-PrimitivePointIndicesEXT-PrimitivePointIndicesEXT-07042, VUID-PrimitiveLineIndicesEXT-PrimitiveLineIndicesEXT-07048, VUID-PrimitiveTriangleIndicesEXT-PrimitiveTriangleIndicesEXT-07054, VUID-ViewportIndex-ViewportIndex-07060 VUID-StandaloneSpirv-ExecutionModel-07330 VUID-StandaloneSpirv-ExecutionModel-07331
1 parent 9295a8b commit b45cd0d

11 files changed

+1423
-6
lines changed

source/val/validate_annotation.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,7 @@ spv_result_t ValidateDecorationTarget(ValidationState_t& _, spv::Decoration dec,
163163
case spv::Decoration::Stream:
164164
case spv::Decoration::RestrictPointer:
165165
case spv::Decoration::AliasedPointer:
166+
case spv::Decoration::PerPrimitiveNV:
166167
if (target->opcode() != spv::Op::OpVariable &&
167168
target->opcode() != spv::Op::OpUntypedVariableKHR &&
168169
target->opcode() != spv::Op::OpFunctionParameter &&

source/val/validate_builtins.cpp

Lines changed: 136 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ typedef enum VUIDError_ {
122122
VUIDErrorMax,
123123
} VUIDError;
124124

125-
const static uint32_t NumVUIDBuiltins = 39;
125+
const static uint32_t NumVUIDBuiltins = 40;
126126

127127
typedef struct {
128128
spv::BuiltIn builtIn;
@@ -172,6 +172,8 @@ std::array<BuiltinVUIDMapping, NumVUIDBuiltins> builtinVUIDInfo = {{
172172
{spv::BuiltIn::PrimitivePointIndicesEXT, {7041, 7043, 7044}},
173173
{spv::BuiltIn::PrimitiveLineIndicesEXT, {7047, 7049, 7050}},
174174
{spv::BuiltIn::PrimitiveTriangleIndicesEXT, {7053, 7055, 7056}},
175+
{spv::BuiltIn::CullPrimitiveEXT, {7034, 7035, 7036}},
176+
175177
// clang-format on
176178
}};
177179

@@ -249,6 +251,7 @@ bool IsExecutionModelValidForRtBuiltIn(spv::BuiltIn builtin,
249251
return false;
250252
}
251253

254+
252255
// Helper class managing validation of built-ins.
253256
// TODO: Generic functionality of this class can be moved into
254257
// ValidationState_t to be made available to other users.
@@ -671,6 +674,24 @@ class BuiltInsValidator {
671674
// instruction.
672675
void Update(const Instruction& inst);
673676

677+
bool isMeshInterfaceVar(const Instruction& inst) {
678+
for (const uint32_t entry_point : _.entry_points()) {
679+
const auto* models = _.GetExecutionModels(entry_point);
680+
if (models->find(spv::ExecutionModel::MeshEXT ) != models->end() ||
681+
models->find(spv::ExecutionModel::MeshNV ) != models->end()) {
682+
for (const auto& desc : _.entry_point_descriptions(entry_point)) {
683+
for (auto interface : desc.interfaces) {
684+
if (inst.id() == interface) {
685+
return true;
686+
}
687+
}
688+
}
689+
}
690+
}
691+
return false;
692+
}
693+
694+
674695
ValidationState_t& _;
675696

676697
// Mapping id -> list of rules which validate instruction referencing the
@@ -2154,6 +2175,17 @@ spv_result_t BuiltInsValidator::ValidatePrimitiveIdAtDefinition(
21542175
return error;
21552176
}
21562177
}
2178+
2179+
if (_.HasCapability(spv::Capability::MeshShadingEXT)) {
2180+
if (isMeshInterfaceVar(inst) &&
2181+
!_.HasDecoration(inst.id(), spv::Decoration::PerPrimitiveEXT)) {
2182+
return _.diag(SPV_ERROR_INVALID_DATA, &inst)
2183+
<< _.VkErrorID(7040)
2184+
<< "According to the Vulkan spec the variable decorated with "
2185+
"Builtin PrimitiveId within the MeshEXT Execution Model must "
2186+
"also be decorated with the PerPrimitiveEXT decoration. ";
2187+
}
2188+
}
21572189
}
21582190

21592191
// Seed at reference checks with this built-in.
@@ -2765,6 +2797,21 @@ spv_result_t BuiltInsValidator::ValidateLayerOrViewportIndexAtDefinition(
27652797
return error;
27662798
}
27672799
}
2800+
2801+
if (isMeshInterfaceVar(inst) &&
2802+
_.HasCapability(spv::Capability::MeshShadingEXT) &&
2803+
!_.HasDecoration(inst.id(), spv::Decoration::PerPrimitiveEXT)) {
2804+
const spv::BuiltIn label = spv::BuiltIn(decoration.params()[0]);
2805+
uint32_t vkerrid = (label == spv::BuiltIn::Layer) ? 7039 : 7060;
2806+
return _.diag(SPV_ERROR_INVALID_DATA, &inst)
2807+
<< _.VkErrorID(vkerrid)
2808+
<< "According to the Vulkan spec the variable decorated with "
2809+
"Builtin "
2810+
<< _.grammar().lookupOperandName(SPV_OPERAND_TYPE_BUILT_IN,
2811+
decoration.params()[0])
2812+
<< " within the MeshEXT Execution Model must also be decorated "
2813+
"with the PerPrimitiveEXT decoration. ";
2814+
}
27682815
}
27692816

27702817
// Seed at reference checks with this built-in.
@@ -3473,6 +3520,7 @@ spv_result_t BuiltInsValidator::ValidateViewIndexAtReference(
34733520
referenced_from_inst, execution_model);
34743521
}
34753522
}
3523+
34763524
}
34773525

34783526
if (function_id_ == 0) {
@@ -4273,7 +4321,6 @@ spv_result_t BuiltInsValidator::ValidateMeshShadingEXTBuiltinsAtReference(
42734321
referenced_from_inst)
42744322
<< " " << GetStorageClassDesc(referenced_from_inst);
42754323
}
4276-
42774324
for (const spv::ExecutionModel execution_model : execution_models_) {
42784325
if (execution_model != spv::ExecutionModel::MeshEXT) {
42794326
uint32_t vuid = GetVUIDForBuiltin(builtin, VUIDErrorExecutionModel);
@@ -4288,6 +4335,92 @@ spv_result_t BuiltInsValidator::ValidateMeshShadingEXTBuiltinsAtReference(
42884335
referenced_from_inst, execution_model);
42894336
}
42904337
}
4338+
4339+
for (const uint32_t entry_point : *entry_points_) {
4340+
// Every entry point from which this function is called needs to have
4341+
// Execution Mode DepthReplacing.
4342+
const auto* modes = _.GetExecutionModes(entry_point);
4343+
uint32_t maxOutputPrimitives = _.GetOutputPrimitivesEXT(entry_point);
4344+
uint32_t underlying_type = 0;
4345+
if (spv_result_t error =
4346+
GetUnderlyingType(_, decoration, referenced_inst, &underlying_type)) {
4347+
return error;
4348+
}
4349+
4350+
uint32_t primitiveArrayDim = 0;
4351+
// Strip the array, if present.
4352+
if (_.GetIdOpcode(underlying_type) == spv::Op::OpTypeArray) {
4353+
underlying_type = _.FindDef(underlying_type)->word(3u);
4354+
if (_.GetIdOpcode(underlying_type) == spv::Op::OpConstant) {
4355+
assert(_.GetIdOpcode(_.FindDef(underlying_type)->word(1)) == spv::Op::OpTypeInt);
4356+
primitiveArrayDim = _.FindDef(underlying_type)->word(3);
4357+
}
4358+
}
4359+
switch (builtin) {
4360+
case spv::BuiltIn::PrimitivePointIndicesEXT:
4361+
if (!modes || !modes->count(spv::ExecutionMode::OutputPoints)) {
4362+
return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst)
4363+
<< _.VkErrorID(7042)
4364+
<< "The PrimitivePointIndicesEXT decoration must be used "
4365+
"with "
4366+
"the OutputPoints Execution Mode. "
4367+
<< GetReferenceDesc(decoration, built_in_inst,
4368+
referenced_inst, referenced_from_inst);
4369+
}
4370+
if (primitiveArrayDim && primitiveArrayDim != maxOutputPrimitives) {
4371+
return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst)
4372+
<< _.VkErrorID(7046)
4373+
<< "The size of the array decorated with "
4374+
"PrimitivePointIndicesEXT must match the value specified "
4375+
"by OutputPrimitivesEXT. "
4376+
<< GetReferenceDesc(decoration, built_in_inst,
4377+
referenced_inst, referenced_from_inst);
4378+
}
4379+
break;
4380+
case spv::BuiltIn::PrimitiveLineIndicesEXT:
4381+
if (!modes || !modes->count(spv::ExecutionMode::OutputLinesEXT)) {
4382+
return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst)
4383+
<< _.VkErrorID(7048)
4384+
<< "The PrimitiveLineIndicesEXT decoration must be used "
4385+
"with "
4386+
"the OutputLinesEXT Execution Mode. "
4387+
<< GetReferenceDesc(decoration, built_in_inst,
4388+
referenced_inst, referenced_from_inst);
4389+
}
4390+
if (primitiveArrayDim && primitiveArrayDim != maxOutputPrimitives) {
4391+
return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst)
4392+
<< _.VkErrorID(7052)
4393+
<< "The size of the array decorated with "
4394+
"PrimitiveLineIndicesEXT must match the value specified "
4395+
"by OutputPrimitivesEXT. "
4396+
<< GetReferenceDesc(decoration, built_in_inst,
4397+
referenced_inst, referenced_from_inst);
4398+
}
4399+
break;
4400+
case spv::BuiltIn::PrimitiveTriangleIndicesEXT:
4401+
if (!modes || !modes->count(spv::ExecutionMode::OutputTrianglesEXT)) {
4402+
return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst)
4403+
<< _.VkErrorID(7054)
4404+
<< "The PrimitiveTriangleIndicesEXT decoration must be used "
4405+
"with "
4406+
"the OutputTrianglesEXT Execution Mode. "
4407+
<< GetReferenceDesc(decoration, built_in_inst,
4408+
referenced_inst, referenced_from_inst);
4409+
}
4410+
if (primitiveArrayDim && primitiveArrayDim != maxOutputPrimitives) {
4411+
return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst)
4412+
<< _.VkErrorID(7058)
4413+
<< "The size of the array decorated with "
4414+
"PrimitiveTriangleIndicesEXT must match the value specified "
4415+
"by OutputPrimitivesEXT. "
4416+
<< GetReferenceDesc(decoration, built_in_inst,
4417+
referenced_inst, referenced_from_inst);
4418+
}
4419+
break;
4420+
default:
4421+
break; // no validation rules
4422+
}
4423+
}
42914424
}
42924425

42934426
if (function_id_ == 0) {
@@ -4476,6 +4609,7 @@ spv_result_t BuiltInsValidator::ValidateSingleBuiltInAtDefinition(
44764609
case spv::BuiltIn::CullMaskKHR: {
44774610
return ValidateRayTracingBuiltinsAtDefinition(decoration, inst);
44784611
}
4612+
case spv::BuiltIn::CullPrimitiveEXT:
44794613
case spv::BuiltIn::PrimitivePointIndicesEXT:
44804614
case spv::BuiltIn::PrimitiveLineIndicesEXT:
44814615
case spv::BuiltIn::PrimitiveTriangleIndicesEXT: {

source/val/validate_decorations.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -767,6 +767,7 @@ spv_result_t CheckDecorationsOfEntryPoints(ValidationState_t& vstate) {
767767
int num_workgroup_variables = 0;
768768
int num_workgroup_variables_with_block = 0;
769769
int num_workgroup_variables_with_aliased = 0;
770+
bool has_task_payload = false;
770771
for (const auto& desc : descs) {
771772
std::unordered_set<Instruction*> seen_vars;
772773
std::unordered_set<spv::BuiltIn> input_var_builtin;
@@ -786,6 +787,19 @@ spv_result_t CheckDecorationsOfEntryPoints(ValidationState_t& vstate) {
786787
const auto sc_index = 2u;
787788
const spv::StorageClass storage_class =
788789
var_instr->GetOperandAs<spv::StorageClass>(sc_index);
790+
if (vstate.version() >= SPV_SPIRV_VERSION_WORD(1, 4)) {
791+
// SPV_EXT_mesh_shader, at most one task payload is permitted
792+
// per entry point
793+
if (storage_class == spv::StorageClass::TaskPayloadWorkgroupEXT) {
794+
if (has_task_payload) {
795+
return vstate.diag(SPV_ERROR_INVALID_ID, var_instr)
796+
<< "There can be at most one OpVariable with storage "
797+
"class TaskPayloadWorkgroupEXT associated with "
798+
"an OpEntryPoint";
799+
}
800+
has_task_payload = true;
801+
}
802+
}
789803
if (vstate.version() >= SPV_SPIRV_VERSION_WORD(1, 4)) {
790804
// Starting in 1.4, OpEntryPoint must list all global variables
791805
// it statically uses and those interfaces must be unique.

source/val/validate_instruction.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -481,6 +481,10 @@ spv_result_t InstructionPass(ValidationState_t& _, const Instruction* inst) {
481481
spv::ExecutionMode::LocalSizeId) {
482482
_.RegisterEntryPointLocalSize(entry_point, inst);
483483
}
484+
if (inst->GetOperandAs<spv::ExecutionMode>(1) ==
485+
spv::ExecutionMode::OutputPrimitivesEXT) {
486+
_.RegisterEntryPointOutputPrimitivesEXT(entry_point, inst);
487+
}
484488
} else if (opcode == spv::Op::OpVariable) {
485489
const auto storage_class = inst->GetOperandAs<spv::StorageClass>(2);
486490
if (auto error = LimitCheckNumVars(_, inst->id(), storage_class)) {

source/val/validate_mesh_shading.cpp

Lines changed: 51 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,32 @@
1515
// Validates ray query instructions from SPV_KHR_ray_query
1616

1717
#include "source/opcode.h"
18+
#include "source/spirv_target_env.h"
1819
#include "source/val/instruction.h"
1920
#include "source/val/validate.h"
2021
#include "source/val/validation_state.h"
2122

2223
namespace spvtools {
2324
namespace val {
2425

26+
bool IsInterfaceVariable(ValidationState_t& _, const Instruction* inst,
27+
spv::ExecutionModel model) {
28+
bool foundInterface = false;
29+
for (auto entry_point : _.entry_points()) {
30+
const auto* models = _.GetExecutionModels(entry_point);
31+
if (models->find(model) == models->end()) return false;
32+
for (const auto& desc : _.entry_point_descriptions(entry_point)) {
33+
for (auto interface : desc.interfaces) {
34+
if (inst->id() == interface) {
35+
foundInterface = true;
36+
break;
37+
}
38+
}
39+
}
40+
}
41+
return foundInterface;
42+
}
43+
2544
spv_result_t MeshShadingPass(ValidationState_t& _, const Instruction* inst) {
2645
const spv::Op opcode = inst->opcode();
2746
switch (opcode) {
@@ -103,15 +122,45 @@ spv_result_t MeshShadingPass(ValidationState_t& _, const Instruction* inst) {
103122
return _.diag(SPV_ERROR_INVALID_DATA, inst)
104123
<< "Primitive Count must be a 32-bit unsigned int scalar";
105124
}
106-
125+
107126
break;
108127
}
109128

110129
case spv::Op::OpWritePackedPrimitiveIndices4x8NV: {
111130
// No validation rules (for the moment).
112131
break;
113132
}
114-
133+
case spv::Op::OpVariable: {
134+
if (_.HasCapability(spv::Capability::MeshShadingEXT)) {
135+
bool meshInterfaceVar = IsInterfaceVariable(
136+
_, inst, spv::ExecutionModel::MeshEXT);
137+
bool fragInterfaceVar = IsInterfaceVariable(
138+
_, inst, spv::ExecutionModel::Fragment);
139+
140+
const spv::StorageClass storage_class =
141+
inst->GetOperandAs<spv::StorageClass>(2);
142+
bool storage_output = (storage_class == spv::StorageClass::Output);
143+
bool storage_input = (storage_class == spv::StorageClass::Input);
144+
145+
if (_.HasDecoration(inst->id(), spv::Decoration::PerPrimitiveEXT)) {
146+
if (fragInterfaceVar && !storage_input) {
147+
return _.diag(SPV_ERROR_INVALID_DATA, inst)
148+
<< "PerPrimitiveEXT decoration must be applied only to "
149+
"variables in the Input Storage Class in the Fragment "
150+
"Execution Model.";
151+
}
152+
153+
if (meshInterfaceVar && !storage_output) {
154+
return _.diag(SPV_ERROR_INVALID_DATA, inst)
155+
<< _.VkErrorID(4336)
156+
<< "PerPrimitiveEXT decoration must be applied only to "
157+
"variables in the Output Storage Class in the "
158+
"Storage Class in the MeshEXT Execution Model.";
159+
}
160+
}
161+
}
162+
break;
163+
}
115164
default:
116165
break;
117166
}

source/val/validate_mode_setting.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -543,6 +543,15 @@ spv_result_t ValidateExecutionMode(ValidationState_t& _,
543543
"tessellation execution model.";
544544
}
545545
}
546+
if (spvIsVulkanEnv(_.context()->target_env)) {
547+
if (_.HasCapability(spv::Capability::MeshShadingEXT) &&
548+
inst->GetOperandAs<uint32_t>(2) == 0) {
549+
return _.diag(SPV_ERROR_INVALID_DATA, inst)
550+
<< _.VkErrorID(7330)
551+
<< "In mesh shaders using the MeshEXT Execution Model the "
552+
"OutputVertices Execution Mode must be greater than 0";
553+
}
554+
}
546555
break;
547556
case spv::ExecutionMode::OutputLinesEXT:
548557
case spv::ExecutionMode::OutputTrianglesEXT:
@@ -557,6 +566,16 @@ spv_result_t ValidateExecutionMode(ValidationState_t& _,
557566
"execution "
558567
"model.";
559568
}
569+
if (mode == spv::ExecutionMode::OutputPrimitivesEXT &&
570+
spvIsVulkanEnv(_.context()->target_env)) {
571+
if (_.HasCapability(spv::Capability::MeshShadingEXT) &&
572+
inst->GetOperandAs<uint32_t>(2) == 0) {
573+
return _.diag(SPV_ERROR_INVALID_DATA, inst)
574+
<< _.VkErrorID(7331)
575+
<< "In mesh shaders using the MeshEXT Execution Model the "
576+
"OutputPrimitivesEXT Execution Mode must be greater than 0";
577+
}
578+
}
560579
break;
561580
case spv::ExecutionMode::QuadDerivativesKHR:
562581
if (!std::all_of(models->begin(), models->end(),

0 commit comments

Comments
 (0)