Skip to content

Commit df0287c

Browse files
authored
[SPIRV] Implement SM6.6 implicit LOD operations in compute shaders (microsoft#5829)
The SPIR-V backend has not yet implemented the changes in SM6.6 that allow [derivatives in compute, mesh, and amplification shaders](https://microsoft.github.io/DirectX-Specs/d3d/HLSL_SM_6_6_Derivatives.html). This is because there is no KHR extension that enables that capability in SPIR-V. However, we have decided to use [SPV_NV_compute_shader_derivatives](https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/NV/SPV_NV_compute_shader_derivatives.asciidoc) to implement it for compute shaders while we wait for a KHR extension. This change only deals with the texture sample instructions. The changes involve: 1. modifying the code that makes sure these instructions only appear in fragment shaders so that it allows compute shaders as well; 2. adding the extension and capability; 3. setting the correct execution mode on the function when the intrinsics are used in compute shaders.
1 parent 3ac44bc commit df0287c

18 files changed

+309
-9
lines changed

tools/clang/include/clang/SPIRV/FeatureManager.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ enum class Extension {
5858
EXT_shader_image_int64,
5959
KHR_physical_storage_buffer,
6060
KHR_vulkan_memory_model,
61+
NV_compute_shader_derivatives,
6162
Unknown,
6263
};
6364

tools/clang/lib/SPIRV/CapabilityVisitor.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -876,6 +876,13 @@ bool CapabilityVisitor::visit(SpirvModule *, Visitor::Phase phase) {
876876
spv::Capability::FragmentShaderShadingRateInterlockEXT,
877877
});
878878

879+
addExtensionAndCapabilitiesIfEnabled(
880+
Extension::NV_compute_shader_derivatives,
881+
{
882+
spv::Capability::ComputeDerivativeGroupQuadsNV,
883+
spv::Capability::ComputeDerivativeGroupLinearNV,
884+
});
885+
879886
// AccelerationStructureType or RayQueryType can be provided by both
880887
// ray_tracing and ray_query extension. By default, we select ray_query to
881888
// provide it. This is an arbitrary decision. If the user wants avoid one

tools/clang/lib/SPIRV/FeatureManager.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,8 @@ Extension FeatureManager::getExtensionSymbol(llvm::StringRef name) {
199199
.Case("SPV_KHR_physical_storage_buffer",
200200
Extension::KHR_physical_storage_buffer)
201201
.Case("SPV_KHR_vulkan_memory_model", Extension::KHR_vulkan_memory_model)
202+
.Case("SPV_NV_compute_shader_derivatives",
203+
Extension::NV_compute_shader_derivatives)
202204
.Default(Extension::Unknown);
203205
}
204206

@@ -262,6 +264,8 @@ const char *FeatureManager::getExtensionName(Extension symbol) {
262264
return "SPV_KHR_physical_storage_buffer";
263265
case Extension::KHR_vulkan_memory_model:
264266
return "SPV_KHR_vulkan_memory_model";
267+
case Extension::NV_compute_shader_derivatives:
268+
return "SPV_NV_compute_shader_derivatives";
265269
default:
266270
break;
267271
}

tools/clang/lib/SPIRV/SpirvEmitter.cpp

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4084,6 +4084,9 @@ SpirvEmitter::processTextureLevelOfDetail(const CXXMemberCallExpr *expr,
40844084
spvBuilder.createImageQuery(spv::Op::OpImageQueryLod, queryResultType,
40854085
expr->getExprLoc(), sampledImage, coordinate);
40864086

4087+
if (spvContext.isCS()) {
4088+
addDerivativeGroupExecutionMode();
4089+
}
40874090
// The first component of the float2 contains the mipmap array layer.
40884091
// The second component of the float2 represents the unclamped lod.
40894092
return spvBuilder.createCompositeExtract(astContext.FloatTy, query,
@@ -5305,9 +5308,10 @@ SpirvInstruction *SpirvEmitter::createImageSample(
53055308
// Otherwise we use implicit-lod instructions.
53065309
const bool isExplicit = lod || (grad.first && grad.second);
53075310

5308-
// Implicit-lod instructions are only allowed in pixel shader.
5309-
if (!spvContext.isPS() && !isExplicit)
5310-
emitError("sampling with implicit lod is only allowed in fragment shaders",
5311+
// Implicit-lod instructions are only allowed in pixel and compute shaders.
5312+
if (!spvContext.isPS() && !spvContext.isCS() && !isExplicit)
5313+
emitError("sampling with implicit lod is only allowed in fragment and "
5314+
"compute shaders",
53115315
loc);
53125316

53135317
auto *retVal = spvBuilder.createImageSample(
@@ -5384,6 +5388,9 @@ SpirvEmitter::processTextureSampleGather(const CXXMemberCallExpr *expr,
53845388

53855389
const auto retType = expr->getDirectCallee()->getReturnType();
53865390
if (isSample) {
5391+
if (spvContext.isCS()) {
5392+
addDerivativeGroupExecutionMode();
5393+
}
53875394
return createImageSample(retType, imageType, image, sampler, coordinate,
53885395
/*compareVal*/ nullptr, /*bias*/ nullptr,
53895396
/*lod*/ nullptr, std::make_pair(nullptr, nullptr),
@@ -5471,6 +5478,9 @@ SpirvEmitter::processTextureSampleBiasLevel(const CXXMemberCallExpr *expr,
54715478

54725479
const auto retType = expr->getDirectCallee()->getReturnType();
54735480

5481+
if (!lod && spvContext.isCS()) {
5482+
addDerivativeGroupExecutionMode();
5483+
}
54745484
return createImageSample(
54755485
retType, imageType, image, sampler, coordinate,
54765486
/*compareVal*/ nullptr, bias, lod, std::make_pair(nullptr, nullptr),
@@ -5620,6 +5630,10 @@ SpirvEmitter::processTextureSampleCmpCmpLevelZero(const CXXMemberCallExpr *expr,
56205630
const auto retType = expr->getDirectCallee()->getReturnType();
56215631
const auto imageType = imageExpr->getType();
56225632

5633+
if (!lod && spvContext.isCS()) {
5634+
addDerivativeGroupExecutionMode();
5635+
}
5636+
56235637
return createImageSample(
56245638
retType, imageType, image, sampler, coordinate, compareVal,
56255639
/*bias*/ nullptr, lod, std::make_pair(nullptr, nullptr), constOffset,
@@ -14074,6 +14088,33 @@ bool SpirvEmitter::spirvToolsValidate(std::vector<uint32_t> *mod,
1407414088
return tools.Validate(mod->data(), mod->size(), options);
1407514089
}
1407614090

14091+
// Adds a derivative-group execution mode (quads or linear) to the entry
// function, chosen from the parameters of its LocalSize execution mode.
void SpirvEmitter::addDerivativeGroupExecutionMode() {
14092+
assert(spvContext.isCS());
14093+
14094+
// Look up the LocalSize execution mode of the entry function; its parameters
// are read below as the thread counts in the X, Y and Z dimensions.
// NOTE(review): the result of findExecutionMode is dereferenced without a
// null check - confirm LocalSize is always present before this is called.
SpirvExecutionMode *numThreadsEm = spvBuilder.getModule()->findExecutionMode(
14095+
entryFunction, spv::ExecutionMode::LocalSize);
14096+
auto numThreads = numThreadsEm->getParams();
14097+
14098+
// The layout of the quad is determined by the number of threads in each
14099+
// dimension. From the HLSL spec
14100+
// (https://microsoft.github.io/DirectX-Specs/d3d/HLSL_SM_6_6_Derivatives.html):
14101+
//
14102+
// Where numthreads has an X value divisible by 4 and Y and Z are both 1, the
14103+
// quad layouts are determined according to 1D quad rules. Where numthreads X
14104+
// and Y values are divisible by 2, the quad layouts are determined according
14105+
// to 2D quad rules. Using derivative operations in any numthreads
14106+
// configuration not matching either of these is invalid and will produce an
14107+
// error.
14108+
// Default to the 2D quad layout; switch to linear only for the 1D case.
spv::ExecutionMode em = spv::ExecutionMode::DerivativeGroupQuadsNV;
14109+
if (numThreads[0] % 4 == 0 && numThreads[1] == 1 && numThreads[2] == 1) {
14110+
em = spv::ExecutionMode::DerivativeGroupLinearNV;
14111+
} else {
14112+
// NOTE(review): the 2D-quad precondition is enforced only by this assert,
// not by a diagnostic; presumably nothing is reported when asserts are
// disabled - confirm whether an emitError is needed for invalid layouts.
assert(numThreads[0] % 2 == 0 && numThreads[1] % 2 == 0);
14113+
}
14114+
14115+
spvBuilder.addExecutionMode(entryFunction, em, {}, SourceLocation());
14116+
}
14117+
1407714118
bool SpirvEmitter::spirvToolsTrimCapabilities(std::vector<uint32_t> *mod,
1407814119
std::string *messages) {
1407914120
spvtools::Optimizer optimizer(featureManager.getTargetEnv());

tools/clang/lib/SPIRV/SpirvEmitter.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1219,6 +1219,13 @@ class SpirvEmitter : public ASTConsumer {
12191219
/// Returns true on success and false otherwise.
12201220
bool spirvToolsValidate(std::vector<uint32_t> *mod, std::string *messages);
12211221

1222+
/// Adds the appropriate derivative group execution mode to the entry point.
1223+
/// The entry point must already have a LocalSize execution mode, which will
1224+
/// be used to determine which execution mode (quad or linear) is required.
1225+
/// This decision is made according to the rules in
1226+
/// https://microsoft.github.io/DirectX-Specs/d3d/HLSL_SM_6_6_Derivatives.html.
1227+
void addDerivativeGroupExecutionMode();
1228+
12221229
public:
12231230
/// \brief Wrapper method to create a fatal error message and report it
12241231
/// in the diagnostic engine associated with this consumer.
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
// RUN: %dxc -T cs_6_6 -E main -fspv-extension=SPV_NV_compute_shader_derivatives -fcgl %s -spirv 2>&1 | FileCheck %s
2+
3+
// This test checks that the execution mode is not added multiple times. Other
4+
// tests will verify that the code generation is correct.
5+
6+
// CHECK: OpCapability ComputeDerivativeGroupQuadsNV
7+
// CHECK: OpExtension "SPV_NV_compute_shader_derivatives"
8+
// CHECK: OpExecutionMode %main DerivativeGroupQuadsNV
9+
// CHECK-NOT: OpExecutionMode %main DerivativeGroupQuadsNV
10+
11+
SamplerState ss : register(s2);
12+
SamplerComparisonState scs;
13+
14+
RWStructuredBuffer<uint> o;
15+
Texture1D <float> t1;
16+
17+
[numthreads(2,2,1)]
18+
void main(uint3 id : SV_GroupThreadID)
19+
{
20+
uint v = id.x;
21+
o[0] = t1.CalculateLevelOfDetail(ss, 0.5);
22+
o[1] = t1.CalculateLevelOfDetailUnclamped(ss, 0.5);
23+
o[2] = t1.Sample(ss, 1);
24+
o[3] = t1.SampleBias(ss, 1, 0.5);
25+
o[4] = t1.SampleCmp(scs, 1, 0.5);
26+
}
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
// RUN: %dxc -T cs_6_6 -E main -fspv-extension=SPV_NV_compute_shader_derivatives -fcgl %s -spirv 2>&1 | FileCheck %s
2+
3+
// CHECK: OpCapability ComputeDerivativeGroupLinearNV
4+
// CHECK: OpExtension "SPV_NV_compute_shader_derivatives"
5+
// CHECK: OpExecutionMode %main DerivativeGroupLinearNV
6+
7+
SamplerState ss : register(s2);
8+
SamplerComparisonState scs;
9+
10+
RWStructuredBuffer<uint> o;
11+
Texture1D <float> t1;
12+
13+
[numthreads(16,1,1)]
14+
void main(uint3 id : SV_GroupThreadID)
15+
{
16+
//CHECK: [[t1:%[0-9]+]] = OpLoad %type_1d_image %t1
17+
//CHECK-NEXT: [[ss1:%[0-9]+]] = OpLoad %type_sampler %ss
18+
//CHECK-NEXT: [[si1:%[0-9]+]] = OpSampledImage %type_sampled_image [[t1]] [[ss1]]
19+
//CHECK-NEXT: [[query1:%[0-9]+]] = OpImageQueryLod %v2float [[si1]] %float_0_5
20+
//CHECK-NEXT: {{%[0-9]+}} = OpCompositeExtract %float [[query1]] 0
21+
o[0] = t1.CalculateLevelOfDetail(ss, 0.5);
22+
}
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
// RUN: %dxc -T cs_6_6 -E main -fspv-extension=SPV_NV_compute_shader_derivatives -fcgl %s -spirv 2>&1 | FileCheck %s
2+
3+
// CHECK: OpCapability ComputeDerivativeGroupQuadsNV
4+
// CHECK: OpExtension "SPV_NV_compute_shader_derivatives"
5+
// CHECK: OpExecutionMode %main DerivativeGroupQuadsNV
6+
7+
SamplerState ss : register(s2);
8+
SamplerComparisonState scs;
9+
10+
RWStructuredBuffer<uint> o;
11+
Texture1D <float> t1;
12+
13+
[numthreads(8,8,1)]
14+
void main(uint3 id : SV_GroupThreadID)
15+
{
16+
//CHECK: [[t1:%[0-9]+]] = OpLoad %type_1d_image %t1
17+
//CHECK-NEXT: [[ss1:%[0-9]+]] = OpLoad %type_sampler %ss
18+
//CHECK-NEXT: [[si1:%[0-9]+]] = OpSampledImage %type_sampled_image [[t1]] [[ss1]]
19+
//CHECK-NEXT: [[query1:%[0-9]+]] = OpImageQueryLod %v2float [[si1]] %float_0_5
20+
//CHECK-NEXT: {{%[0-9]+}} = OpCompositeExtract %float [[query1]] 0
21+
o[0] = t1.CalculateLevelOfDetail(ss, 0.5);
22+
}
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
// RUN: %dxc -T cs_6_6 -E main -fspv-extension=SPV_NV_compute_shader_derivatives -fcgl %s -spirv 2>&1 | FileCheck %s
2+
3+
// CHECK: OpCapability ComputeDerivativeGroupLinearNV
4+
// CHECK: OpExtension "SPV_NV_compute_shader_derivatives"
5+
// CHECK: OpExecutionMode %main DerivativeGroupLinearNV
6+
7+
SamplerState ss : register(s2);
8+
SamplerComparisonState scs;
9+
10+
RWStructuredBuffer<uint> o;
11+
Texture1D <float> t1;
12+
13+
[numthreads(4,1,1)]
14+
void main(uint3 id : SV_GroupThreadID)
15+
{
16+
//CHECK: [[t1:%[0-9]+]] = OpLoad %type_1d_image %t1
17+
//CHECK-NEXT: [[ss1:%[0-9]+]] = OpLoad %type_sampler %ss
18+
//CHECK-NEXT: [[si1:%[0-9]+]] = OpSampledImage %type_sampled_image [[t1]] [[ss1]]
19+
//CHECK-NEXT: [[query1:%[0-9]+]] = OpImageQueryLod %v2float [[si1]] %float_0_5
20+
//CHECK-NEXT: {{%[0-9]+}} = OpCompositeExtract %float [[query1]] 1
21+
o[0] = t1.CalculateLevelOfDetailUnclamped(ss, 0.5);
22+
}

0 commit comments

Comments
 (0)