diff --git a/docs/SPIR-V.rst b/docs/SPIR-V.rst index a695e5854d..d9281956d3 100644 --- a/docs/SPIR-V.rst +++ b/docs/SPIR-V.rst @@ -285,6 +285,21 @@ Right now the following ```` are supported: Please see Vulkan spec. `15.9. Built-In Variables `_ for detailed explanation of these builtins. +Helper Lane Support +~~~~~~~~~~~~~~~~~~~ + +Shader Model 6.7 introduces the ``[WaveOpsIncludeHelperLanes]`` attribute. When this +attribute is applied to a shader entry point, the SPIR-V backend will: + +1. Add the ``SPV_KHR_maximal_reconvergence`` and ``SPV_KHR_quad_control`` + extensions to the module. +2. Add the ``QuadControlKHR`` capability. +3. Add the ``MaximallyReconvergesKHR`` and ``RequireFullQuadsKHR`` execution modes + to the entry point. + +This ensures that helper lanes are included in wave operations, which is the +behavior required by the HLSL specification. + Supported extensions ~~~~~~~~~~~~~~~~~~~~ diff --git a/tools/clang/lib/SPIRV/SpirvEmitter.cpp b/tools/clang/lib/SPIRV/SpirvEmitter.cpp index f231625001..6b67719b4e 100644 --- a/tools/clang/lib/SPIRV/SpirvEmitter.cpp +++ b/tools/clang/lib/SPIRV/SpirvEmitter.cpp @@ -862,6 +862,27 @@ void SpirvEmitter::HandleTranslationUnit(ASTContext &context) { SourceLocation()); } + for (uint32_t i = 0; i < workQueue.size(); ++i) { + const FunctionInfo *entryInfo = workQueue[i]; + if (entryInfo->isEntryFunction) { + const auto *funcDecl = entryInfo->funcDecl; + if (funcDecl->hasAttr()) { + spvBuilder.requireExtension("SPV_KHR_maximal_reconvergence", + funcDecl->getLocation()); + spvBuilder.requireExtension("SPV_KHR_quad_control", + funcDecl->getLocation()); + spvBuilder.requireCapability(spv::Capability::QuadControlKHR, + funcDecl->getLocation()); + spvBuilder.addExecutionMode(entryInfo->entryFunction, + spv::ExecutionMode::MaximallyReconvergesKHR, + {}, funcDecl->getLocation()); + spvBuilder.addExecutionMode(entryInfo->entryFunction, + spv::ExecutionMode::RequireFullQuadsKHR, {}, + funcDecl->getLocation()); + } + } + } + // 
For Vulkan 1.2 and later, add SignedZeroInfNanPreserve when -Gis is // provided to preserve NaN/Inf and signed zeros. if (spirvOptions.IEEEStrict) { diff --git a/tools/clang/test/CodeGenSPIRV/wave-ops-include-helper-lanes-lib.hlsl b/tools/clang/test/CodeGenSPIRV/wave-ops-include-helper-lanes-lib.hlsl new file mode 100644 index 0000000000..e9d0fe187e --- /dev/null +++ b/tools/clang/test/CodeGenSPIRV/wave-ops-include-helper-lanes-lib.hlsl @@ -0,0 +1,35 @@ +// RUN: %dxc -T lib_6_7 -spirv %s | FileCheck %s + +// CHECK: OpCapability QuadControlKHR +// CHECK-DAG: OpExtension "SPV_KHR_maximal_reconvergence" +// CHECK-DAG: OpExtension "SPV_KHR_quad_control" + +// CHECK: OpEntryPoint Fragment %ps_main1 "ps_main1" +// CHECK: OpEntryPoint Fragment %ps_main2 "ps_main2" +// CHECK: OpEntryPoint Fragment %ps_main3 "ps_main3" + +// CHECK-DAG: OpExecutionMode %ps_main1 MaximallyReconvergesKHR +// CHECK-DAG: OpExecutionMode %ps_main1 RequireFullQuadsKHR + +// CHECK-NOT: OpExecutionMode %ps_main2 MaximallyReconvergesKHR +// CHECK-NOT: OpExecutionMode %ps_main2 RequireFullQuadsKHR + +// CHECK-DAG: OpExecutionMode %ps_main3 MaximallyReconvergesKHR +// CHECK-DAG: OpExecutionMode %ps_main3 RequireFullQuadsKHR + +[WaveOpsIncludeHelperLanes] +[shader("pixel")] +void ps_main1() : SV_Target0 +{ +} + +[shader("pixel")] +void ps_main2() : SV_Target0 +{ +} + +[WaveOpsIncludeHelperLanes] +[shader("pixel")] +void ps_main3() : SV_Target0 +{ +} diff --git a/tools/clang/test/CodeGenSPIRV/wave-ops-include-helper-lanes.hlsl b/tools/clang/test/CodeGenSPIRV/wave-ops-include-helper-lanes.hlsl new file mode 100644 index 0000000000..8d79bd6a28 --- /dev/null +++ b/tools/clang/test/CodeGenSPIRV/wave-ops-include-helper-lanes.hlsl @@ -0,0 +1,14 @@ +// RUN: %dxc -T ps_6_7 -E main -spirv %s | FileCheck %s + +// CHECK: OpCapability QuadControlKHR +// CHECK-DAG: OpExtension "SPV_KHR_maximal_reconvergence" +// CHECK-DAG: OpExtension "SPV_KHR_quad_control" + +// CHECK: OpExecutionMode %main MaximallyReconvergesKHR 
+// CHECK: OpExecutionMode %main RequireFullQuadsKHR + +[WaveOpsIncludeHelperLanes] +float4 main(float4 pos : SV_Position) : SV_Target +{ + return pos; +}