Skip to content

Commit 44936c4

Browse files
Add support for SPV_KHR_compute_shader_derivative (#5817)
* Add support for SPV_KHR_compute_shader_derivative * Update tests for SPV_KHR_compute_shader_derivatives --------- Co-authored-by: MagicPoncho <[email protected]>
1 parent 362ce7c commit 44936c4

9 files changed

+90
-69
lines changed

source/opt/aggressive_dead_code_elim_pass.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1010,7 +1010,7 @@ void AggressiveDCEPass::InitExtensions() {
10101010
"SPV_NV_bindless_texture",
10111011
"SPV_EXT_shader_atomic_float_add",
10121012
"SPV_EXT_fragment_shader_interlock",
1013-
"SPV_NV_compute_shader_derivatives",
1013+
"SPV_KHR_compute_shader_derivatives",
10141014
"SPV_NV_cooperative_matrix",
10151015
"SPV_KHR_cooperative_matrix",
10161016
"SPV_KHR_ray_tracing_position_fetch"

source/opt/local_access_chain_convert_pass.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -428,9 +428,9 @@ void LocalAccessChainConvertPass::InitExtensions() {
428428
"SPV_KHR_uniform_group_instructions",
429429
"SPV_KHR_fragment_shader_barycentric", "SPV_KHR_vulkan_memory_model",
430430
"SPV_NV_bindless_texture", "SPV_EXT_shader_atomic_float_add",
431-
"SPV_EXT_fragment_shader_interlock", "SPV_NV_compute_shader_derivatives",
432-
"SPV_NV_cooperative_matrix", "SPV_KHR_cooperative_matrix",
433-
"SPV_KHR_ray_tracing_position_fetch"});
431+
"SPV_EXT_fragment_shader_interlock",
432+
"SPV_KHR_compute_shader_derivatives", "SPV_NV_cooperative_matrix",
433+
"SPV_KHR_cooperative_matrix", "SPV_KHR_ray_tracing_position_fetch"});
434434
}
435435

436436
bool LocalAccessChainConvertPass::AnyIndexIsOutOfBounds(

source/opt/local_single_block_elim_pass.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,7 @@ void LocalSingleBlockLoadStoreElimPass::InitExtensions() {
291291
"SPV_NV_bindless_texture",
292292
"SPV_EXT_shader_atomic_float_add",
293293
"SPV_EXT_fragment_shader_interlock",
294-
"SPV_NV_compute_shader_derivatives",
294+
"SPV_KHR_compute_shader_derivatives",
295295
"SPV_NV_cooperative_matrix",
296296
"SPV_KHR_cooperative_matrix",
297297
"SPV_KHR_ray_tracing_position_fetch"});

source/opt/local_single_store_elim_pass.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ void LocalSingleStoreElimPass::InitExtensionAllowList() {
141141
"SPV_NV_bindless_texture",
142142
"SPV_EXT_shader_atomic_float_add",
143143
"SPV_EXT_fragment_shader_interlock",
144-
"SPV_NV_compute_shader_derivatives",
144+
"SPV_KHR_compute_shader_derivatives",
145145
"SPV_NV_cooperative_matrix",
146146
"SPV_KHR_cooperative_matrix",
147147
"SPV_KHR_ray_tracing_position_fetch"});

source/opt/trim_capabilities_pass.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,8 @@ class TrimCapabilitiesPass : public Pass {
7474
// contains unsupported instruction, the pass could yield bad results.
7575
static constexpr std::array kSupportedCapabilities{
7676
// clang-format off
77-
spv::Capability::ComputeDerivativeGroupLinearNV,
78-
spv::Capability::ComputeDerivativeGroupQuadsNV,
77+
spv::Capability::ComputeDerivativeGroupLinearKHR,
78+
spv::Capability::ComputeDerivativeGroupQuadsKHR,
7979
spv::Capability::Float16,
8080
spv::Capability::Float64,
8181
spv::Capability::FragmentShaderPixelInterlockEXT,

source/val/validate_derivatives.cpp

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -60,12 +60,14 @@ spv_result_t DerivativesPass(ValidationState_t& _, const Instruction* inst) {
6060
->RegisterExecutionModelLimitation([opcode](spv::ExecutionModel model,
6161
std::string* message) {
6262
if (model != spv::ExecutionModel::Fragment &&
63-
model != spv::ExecutionModel::GLCompute) {
63+
model != spv::ExecutionModel::GLCompute &&
64+
model != spv::ExecutionModel::MeshEXT &&
65+
model != spv::ExecutionModel::TaskEXT) {
6466
if (message) {
6567
*message =
6668
std::string(
67-
"Derivative instructions require Fragment or GLCompute "
68-
"execution model: ") +
69+
"Derivative instructions require Fragment, GLCompute, "
70+
"MeshEXT or TaskEXT execution model: ") +
6971
spvOpcodeString(opcode);
7072
}
7173
return false;
@@ -79,19 +81,23 @@ spv_result_t DerivativesPass(ValidationState_t& _, const Instruction* inst) {
7981
const auto* models = state.GetExecutionModels(entry_point->id());
8082
const auto* modes = state.GetExecutionModes(entry_point->id());
8183
if (models &&
82-
models->find(spv::ExecutionModel::GLCompute) != models->end() &&
84+
(models->find(spv::ExecutionModel::GLCompute) !=
85+
models->end() ||
86+
models->find(spv::ExecutionModel::MeshEXT) != models->end() ||
87+
models->find(spv::ExecutionModel::TaskEXT) != models->end()) &&
8388
(!modes ||
84-
(modes->find(spv::ExecutionMode::DerivativeGroupLinearNV) ==
89+
(modes->find(spv::ExecutionMode::DerivativeGroupLinearKHR) ==
8590
modes->end() &&
86-
modes->find(spv::ExecutionMode::DerivativeGroupQuadsNV) ==
91+
modes->find(spv::ExecutionMode::DerivativeGroupQuadsKHR) ==
8792
modes->end()))) {
8893
if (message) {
89-
*message = std::string(
90-
"Derivative instructions require "
91-
"DerivativeGroupQuadsNV "
92-
"or DerivativeGroupLinearNV execution mode for "
93-
"GLCompute execution model: ") +
94-
spvOpcodeString(opcode);
94+
*message =
95+
std::string(
96+
"Derivative instructions require "
97+
"DerivativeGroupQuadsKHR "
98+
"or DerivativeGroupLinearKHR execution mode for "
99+
"GLCompute, MeshEXT or TaskEXT execution model: ") +
100+
spvOpcodeString(opcode);
95101
}
96102
return false;
97103
}

source/val/validate_image.cpp

Lines changed: 34 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2026,11 +2026,13 @@ spv_result_t ValidateImageQueryLod(ValidationState_t& _,
20262026
->RegisterExecutionModelLimitation(
20272027
[&](spv::ExecutionModel model, std::string* message) {
20282028
if (model != spv::ExecutionModel::Fragment &&
2029-
model != spv::ExecutionModel::GLCompute) {
2029+
model != spv::ExecutionModel::GLCompute &&
2030+
model != spv::ExecutionModel::MeshEXT &&
2031+
model != spv::ExecutionModel::TaskEXT) {
20302032
if (message) {
20312033
*message = std::string(
2032-
"OpImageQueryLod requires Fragment or GLCompute execution "
2033-
"model");
2034+
"OpImageQueryLod requires Fragment, GLCompute, MeshEXT or "
2035+
"TaskEXT execution model");
20342036
}
20352037
return false;
20362038
}
@@ -2042,16 +2044,20 @@ spv_result_t ValidateImageQueryLod(ValidationState_t& _,
20422044
std::string* message) {
20432045
const auto* models = state.GetExecutionModels(entry_point->id());
20442046
const auto* modes = state.GetExecutionModes(entry_point->id());
2045-
if (models->find(spv::ExecutionModel::GLCompute) != models->end() &&
2046-
modes->find(spv::ExecutionMode::DerivativeGroupLinearNV) ==
2047-
modes->end() &&
2048-
modes->find(spv::ExecutionMode::DerivativeGroupQuadsNV) ==
2049-
modes->end()) {
2047+
if (models &&
2048+
(models->find(spv::ExecutionModel::GLCompute) != models->end() ||
2049+
models->find(spv::ExecutionModel::MeshEXT) != models->end() ||
2050+
models->find(spv::ExecutionModel::TaskEXT) != models->end()) &&
2051+
(!modes ||
2052+
(modes->find(spv::ExecutionMode::DerivativeGroupLinearKHR) ==
2053+
modes->end() &&
2054+
modes->find(spv::ExecutionMode::DerivativeGroupQuadsKHR) ==
2055+
modes->end()))) {
20502056
if (message) {
20512057
*message = std::string(
2052-
"OpImageQueryLod requires DerivativeGroupQuadsNV "
2053-
"or DerivativeGroupLinearNV execution mode for GLCompute "
2054-
"execution model");
2058+
"OpImageQueryLod requires DerivativeGroupQuadsKHR "
2059+
"or DerivativeGroupLinearKHR execution mode for GLCompute, "
2060+
"MeshEXT or TaskEXT execution model");
20552061
}
20562062
return false;
20572063
}
@@ -2320,12 +2326,14 @@ spv_result_t ImagePass(ValidationState_t& _, const Instruction* inst) {
23202326
->RegisterExecutionModelLimitation([opcode](spv::ExecutionModel model,
23212327
std::string* message) {
23222328
if (model != spv::ExecutionModel::Fragment &&
2323-
model != spv::ExecutionModel::GLCompute) {
2329+
model != spv::ExecutionModel::GLCompute &&
2330+
model != spv::ExecutionModel::MeshEXT &&
2331+
model != spv::ExecutionModel::TaskEXT) {
23242332
if (message) {
23252333
*message =
23262334
std::string(
2327-
"ImplicitLod instructions require Fragment or GLCompute "
2328-
"execution model: ") +
2335+
"ImplicitLod instructions require Fragment, GLCompute, "
2336+
"MeshEXT or TaskEXT execution model: ") +
23292337
spvOpcodeString(opcode);
23302338
}
23312339
return false;
@@ -2339,19 +2347,22 @@ spv_result_t ImagePass(ValidationState_t& _, const Instruction* inst) {
23392347
const auto* models = state.GetExecutionModels(entry_point->id());
23402348
const auto* modes = state.GetExecutionModes(entry_point->id());
23412349
if (models &&
2342-
models->find(spv::ExecutionModel::GLCompute) != models->end() &&
2350+
(models->find(spv::ExecutionModel::GLCompute) != models->end() ||
2351+
models->find(spv::ExecutionModel::MeshEXT) != models->end() ||
2352+
models->find(spv::ExecutionModel::TaskEXT) != models->end()) &&
23432353
(!modes ||
2344-
(modes->find(spv::ExecutionMode::DerivativeGroupLinearNV) ==
2354+
(modes->find(spv::ExecutionMode::DerivativeGroupLinearKHR) ==
23452355
modes->end() &&
2346-
modes->find(spv::ExecutionMode::DerivativeGroupQuadsNV) ==
2356+
modes->find(spv::ExecutionMode::DerivativeGroupQuadsKHR) ==
23472357
modes->end()))) {
23482358
if (message) {
2349-
*message =
2350-
std::string(
2351-
"ImplicitLod instructions require DerivativeGroupQuadsNV "
2352-
"or DerivativeGroupLinearNV execution mode for GLCompute "
2353-
"execution model: ") +
2354-
spvOpcodeString(opcode);
2359+
*message = std::string(
2360+
"ImplicitLod instructions require "
2361+
"DerivativeGroupQuadsKHR "
2362+
"or DerivativeGroupLinearKHR execution mode for "
2363+
"GLCompute, "
2364+
"MeshEXT or TaskEXT execution model: ") +
2365+
spvOpcodeString(opcode);
23552366
}
23562367
return false;
23572368
}

test/val/val_derivatives_test.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -156,8 +156,8 @@ TEST_F(ValidateDerivatives, OpDPdxWrongExecutionModel) {
156156
CompileSuccessfully(GenerateShaderCode(body, "", "Vertex").c_str());
157157
ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
158158
EXPECT_THAT(getDiagnosticString(),
159-
HasSubstr("Derivative instructions require Fragment or GLCompute "
160-
"execution model: DPdx"));
159+
HasSubstr("Derivative instructions require Fragment, GLCompute, "
160+
"MeshEXT or TaskEXT execution model: DPdx"));
161161
}
162162

163163
TEST_F(ValidateDerivatives, NoExecutionModeGLCompute) {
@@ -181,8 +181,9 @@ OpFunctionEnd
181181
EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
182182
EXPECT_THAT(getDiagnosticString(),
183183
HasSubstr("Derivative instructions require "
184-
"DerivativeGroupQuadsNV or DerivativeGroupLinearNV "
185-
"execution mode for GLCompute execution model"));
184+
"DerivativeGroupQuadsKHR or DerivativeGroupLinearKHR "
185+
"execution mode for GLCompute, MeshEXT or TaskEXT "
186+
"execution model"));
186187
}
187188

188189
using ValidateHalfDerivatives = spvtest::ValidateBase<std::string>;

test/val/val_image_test.cpp

Lines changed: 25 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -4780,7 +4780,8 @@ TEST_F(ValidateImage, QueryLodWrongExecutionModel) {
47804780
EXPECT_THAT(
47814781
getDiagnosticString(),
47824782
HasSubstr(
4783-
"OpImageQueryLod requires Fragment or GLCompute execution model"));
4783+
"OpImageQueryLod requires Fragment, GLCompute, MeshEXT or TaskEXT "
4784+
"execution model"));
47844785
}
47854786

47864787
TEST_F(ValidateImage, QueryLodWrongExecutionModelWithFunc) {
@@ -4801,7 +4802,8 @@ OpFunctionEnd
48014802
EXPECT_THAT(
48024803
getDiagnosticString(),
48034804
HasSubstr(
4804-
"OpImageQueryLod requires Fragment or GLCompute execution model"));
4805+
"OpImageQueryLod requires Fragment, GLCompute, MeshEXT or TaskEXT "
4806+
"execution model"));
48054807
}
48064808

48074809
TEST_F(ValidateImage, QueryLodComputeShaderDerivatives) {
@@ -4813,12 +4815,12 @@ TEST_F(ValidateImage, QueryLodComputeShaderDerivatives) {
48134815
)";
48144816

48154817
const std::string extra = R"(
4816-
OpCapability ComputeDerivativeGroupLinearNV
4817-
OpExtension "SPV_NV_compute_shader_derivatives"
4818+
OpCapability ComputeDerivativeGroupLinearKHR
4819+
OpExtension "SPV_KHR_compute_shader_derivatives"
48184820
)";
48194821
const std::string mode = R"(
48204822
OpExecutionMode %main LocalSize 8 8 1
4821-
OpExecutionMode %main DerivativeGroupLinearNV
4823+
OpExecutionMode %main DerivativeGroupLinearKHR
48224824
)";
48234825
CompileSuccessfully(
48244826
GenerateShaderCode(body, extra, "GLCompute", mode).c_str());
@@ -4930,8 +4932,8 @@ TEST_F(ValidateImage, QueryLodComputeShaderDerivativesMissingMode) {
49304932
)";
49314933

49324934
const std::string extra = R"(
4933-
OpCapability ComputeDerivativeGroupLinearNV
4934-
OpExtension "SPV_NV_compute_shader_derivatives"
4935+
OpCapability ComputeDerivativeGroupLinearKHR
4936+
OpExtension "SPV_KHR_compute_shader_derivatives"
49354937
)";
49364938
const std::string mode = R"(
49374939
OpExecutionMode %main LocalSize 8 8 1
@@ -4940,9 +4942,9 @@ OpExecutionMode %main LocalSize 8 8 1
49404942
GenerateShaderCode(body, extra, "GLCompute", mode).c_str());
49414943
ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
49424944
EXPECT_THAT(getDiagnosticString(),
4943-
HasSubstr("OpImageQueryLod requires DerivativeGroupQuadsNV or "
4944-
"DerivativeGroupLinearNV execution mode for GLCompute "
4945-
"execution model"));
4945+
HasSubstr("OpImageQueryLod requires DerivativeGroupQuadsKHR or "
4946+
"DerivativeGroupLinearKHR execution mode for "
4947+
"GLCompute, MeshEXT or TaskEXT execution model"));
49464948
}
49474949

49484950
TEST_F(ValidateImage, ImplicitLodWrongExecutionModel) {
@@ -4956,8 +4958,8 @@ TEST_F(ValidateImage, ImplicitLodWrongExecutionModel) {
49564958
CompileSuccessfully(GenerateShaderCode(body, "", "Vertex").c_str());
49574959
ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
49584960
EXPECT_THAT(getDiagnosticString(),
4959-
HasSubstr("ImplicitLod instructions require Fragment or "
4960-
"GLCompute execution model"));
4961+
HasSubstr("ImplicitLod instructions require Fragment, "
4962+
"GLCompute, MeshEXT or TaskEXT execution model"));
49614963
}
49624964

49634965
TEST_F(ValidateImage, ImplicitLodComputeShaderDerivatives) {
@@ -4969,12 +4971,12 @@ TEST_F(ValidateImage, ImplicitLodComputeShaderDerivatives) {
49694971
)";
49704972

49714973
const std::string extra = R"(
4972-
OpCapability ComputeDerivativeGroupLinearNV
4973-
OpExtension "SPV_NV_compute_shader_derivatives"
4974+
OpCapability ComputeDerivativeGroupLinearKHR
4975+
OpExtension "SPV_KHR_compute_shader_derivatives"
49744976
)";
49754977
const std::string mode = R"(
49764978
OpExecutionMode %main LocalSize 8 8 1
4977-
OpExecutionMode %main DerivativeGroupLinearNV
4979+
OpExecutionMode %main DerivativeGroupLinearKHR
49784980
)";
49794981
CompileSuccessfully(
49804982
GenerateShaderCode(body, extra, "GLCompute", mode).c_str());
@@ -4990,8 +4992,8 @@ TEST_F(ValidateImage, ImplicitLodComputeShaderDerivativesMissingMode) {
49904992
)";
49914993

49924994
const std::string extra = R"(
4993-
OpCapability ComputeDerivativeGroupLinearNV
4994-
OpExtension "SPV_NV_compute_shader_derivatives"
4995+
OpCapability ComputeDerivativeGroupLinearKHR
4996+
OpExtension "SPV_KHR_compute_shader_derivatives"
49954997
)";
49964998
const std::string mode = R"(
49974999
OpExecutionMode %main LocalSize 8 8 1
@@ -5001,9 +5003,9 @@ OpExecutionMode %main LocalSize 8 8 1
50015003
ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
50025004
EXPECT_THAT(
50035005
getDiagnosticString(),
5004-
HasSubstr("ImplicitLod instructions require DerivativeGroupQuadsNV or "
5005-
"DerivativeGroupLinearNV execution mode for GLCompute "
5006-
"execution model"));
5006+
HasSubstr("ImplicitLod instructions require DerivativeGroupQuadsKHR or "
5007+
"DerivativeGroupLinearKHR execution mode for GLCompute, "
5008+
"MeshEXT or TaskEXT execution model"));
50075009
}
50085010

50095011
TEST_F(ValidateImage, ReadSubpassDataWrongExecutionModel) {
@@ -6505,8 +6507,9 @@ OpFunctionEnd
65056507
EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
65066508
EXPECT_THAT(getDiagnosticString(),
65076509
HasSubstr("ImplicitLod instructions require "
6508-
"DerivativeGroupQuadsNV or DerivativeGroupLinearNV "
6509-
"execution mode for GLCompute execution model"));
6510+
"DerivativeGroupQuadsKHR or DerivativeGroupLinearKHR "
6511+
"execution mode for GLCompute, MeshEXT or TaskEXT "
6512+
"execution model"));
65106513
}
65116514

65126515
TEST_F(ValidateImage, TypeSampledImageNotBufferPost1p6) {

0 commit comments

Comments
 (0)