Skip to content

Commit 6556410

Browse files
[SPIR-V] Implement QuadAny and QuadAll (microsoft#7266)
If `"SPV_KHR_quad_control"` can be used, uses `OpGroupNonUniformQuadAnyKHR` and `OpGroupNonUniformQuadAllKHR`. If not, falls back to constructing the value using `OpGroupNonUniformQuadSwap`. Fixes microsoft#7247
1 parent 3b1a29b commit 6556410

File tree

12 files changed

+130
-9
lines changed

12 files changed

+130
-9
lines changed

docs/SPIR-V.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,7 @@ Supported extensions
320320
* SPV_KHR_maximal_reconvergence
321321
* SPV_KHR_float_controls
322322
* SPV_NV_shader_subgroup_partitioned
323+
* SPV_KHR_quad_control
323324

324325
Vulkan specific attributes
325326
--------------------------
@@ -4008,6 +4009,8 @@ Quad ``QuadReadAcrossX()`` ``OpGroupNonUniformQuadSwap``
40084009
Quad ``QuadReadAcrossY()`` ``OpGroupNonUniformQuadSwap``
40094010
Quad ``QuadReadAcrossDiagonal()`` ``OpGroupNonUniformQuadSwap``
40104011
Quad ``QuadReadLaneAt()`` ``OpGroupNonUniformQuadBroadcast``
4012+
Quad ``QuadAny()`` ``OpGroupNonUniformQuadAnyKHR``
4013+
Quad ``QuadAll()`` ``OpGroupNonUniformQuadAllKHR``
40114014
N/A ``WaveMatch()`` ``OpGroupNonUniformPartitionNV``
40124015
Multiprefix ``WaveMultiPrefixSum()`` ``OpGroupNonUniform*Add`` ``PartitionedExclusiveScanNV``
40134016
Multiprefix ``WaveMultiPrefixProduct()`` ``OpGroupNonUniform*Mul`` ``PartitionedExclusiveScanNV``
@@ -4016,6 +4019,11 @@ Multiprefix ``WaveMultiPrefixBitOr()`` ``OpGroupNonUniformLogicalOr`` `
40164019
Multiprefix ``WaveMultiPrefixBitXor()`` ``OpGroupNonUniformLogicalXor`` ``PartitionedExclusiveScanNV``
40174020
============= ============================ =================================== ==============================
40184021

4022+
``QuadAny`` and ``QuadAll`` will use the ``OpGroupNonUniformQuadAnyKHR`` and
4023+
``OpGroupNonUniformQuadAllKHR`` instructions if the ``SPV_KHR_quad_control``
4024+
extension is enabled. If it is not, they will fall back to constructing the
4025+
value using multiple calls to ``OpGroupNonUniformQuadBroadcast``.
4026+
40194027
The Implicit ``vk`` Namespace
40204028
=============================
40214029

tools/clang/include/clang/SPIRV/FeatureManager.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ enum class Extension {
6464
KHR_maximal_reconvergence,
6565
KHR_float_controls,
6666
NV_shader_subgroup_partitioned,
67+
KHR_quad_control,
6768
Unknown,
6869
};
6970

tools/clang/include/clang/SPIRV/SpirvBuilder.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,7 @@ class SpirvBuilder {
242242
/// \brief Creates an operation with the given OpGroupNonUniform* SPIR-V
243243
/// opcode.
244244
SpirvGroupNonUniformOp *createGroupNonUniformOp(
245-
spv::Op op, QualType resultType, spv::Scope execScope,
245+
spv::Op op, QualType resultType, llvm::Optional<spv::Scope> execScope,
246246
llvm::ArrayRef<SpirvInstruction *> operands, SourceLocation,
247247
llvm::Optional<spv::GroupOperation> groupOp = llvm::None);
248248

tools/clang/include/clang/SPIRV/SpirvInstruction.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1566,7 +1566,8 @@ class SpirvFunctionCall : public SpirvInstruction {
15661566
/// \brief OpGroupNonUniform* instructions
15671567
class SpirvGroupNonUniformOp : public SpirvInstruction {
15681568
public:
1569-
SpirvGroupNonUniformOp(spv::Op opcode, QualType resultType, spv::Scope scope,
1569+
SpirvGroupNonUniformOp(spv::Op opcode, QualType resultType,
1570+
llvm::Optional<spv::Scope> scope,
15701571
llvm::ArrayRef<SpirvInstruction *> operands,
15711572
SourceLocation loc,
15721573
llvm::Optional<spv::GroupOperation> group);
@@ -1580,7 +1581,8 @@ class SpirvGroupNonUniformOp : public SpirvInstruction {
15801581

15811582
bool invokeVisitor(Visitor *v) override;
15821583

1583-
spv::Scope getExecutionScope() const { return execScope; }
1584+
bool hasExecutionScope() const { return execScope.hasValue(); }
1585+
spv::Scope getExecutionScope() const { return execScope.getValue(); }
15841586

15851587
llvm::ArrayRef<SpirvInstruction *> getOperands() const { return operands; }
15861588

@@ -1598,7 +1600,7 @@ class SpirvGroupNonUniformOp : public SpirvInstruction {
15981600
}
15991601

16001602
private:
1601-
spv::Scope execScope;
1603+
llvm::Optional<spv::Scope> execScope;
16021604
llvm::SmallVector<SpirvInstruction *, 4> operands;
16031605
llvm::Optional<spv::GroupOperation> groupOp;
16041606
};

tools/clang/lib/SPIRV/CapabilityVisitor.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -887,6 +887,9 @@ bool CapabilityVisitor::visit(SpirvModule *, Visitor::Phase phase) {
887887

888888
addCapability(spv::Capability::InterpolationFunction);
889889

890+
addExtensionAndCapabilitiesIfEnabled(Extension::KHR_quad_control,
891+
{spv::Capability::QuadControlKHR});
892+
890893
return true;
891894
}
892895

tools/clang/lib/SPIRV/EmitVisitor.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1134,9 +1134,10 @@ bool EmitVisitor::visit(SpirvGroupNonUniformOp *inst) {
11341134
initInstruction(inst);
11351135
curInst.push_back(inst->getResultTypeId());
11361136
curInst.push_back(getOrAssignResultId<SpirvInstruction>(inst));
1137-
curInst.push_back(typeHandler.getOrCreateConstantInt(
1138-
llvm::APInt(32, static_cast<uint32_t>(inst->getExecutionScope())),
1139-
context.getUIntType(32), /* isSpecConst */ false));
1137+
if (inst->hasExecutionScope())
1138+
curInst.push_back(typeHandler.getOrCreateConstantInt(
1139+
llvm::APInt(32, static_cast<uint32_t>(inst->getExecutionScope())),
1140+
context.getUIntType(32), /* isSpecConst */ false));
11401141
if (inst->hasGroupOp())
11411142
curInst.push_back(static_cast<uint32_t>(inst->getGroupOp()));
11421143
for (auto *operand : inst->getOperands())

tools/clang/lib/SPIRV/FeatureManager.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,7 @@ Extension FeatureManager::getExtensionSymbol(llvm::StringRef name) {
226226
.Case("SPV_KHR_float_controls", Extension::KHR_float_controls)
227227
.Case("SPV_NV_shader_subgroup_partitioned",
228228
Extension::NV_shader_subgroup_partitioned)
229+
.Case("SPV_KHR_quad_control", Extension::KHR_quad_control)
229230
.Default(Extension::Unknown);
230231
}
231232

@@ -297,6 +298,8 @@ const char *FeatureManager::getExtensionName(Extension symbol) {
297298
return "SPV_KHR_float_controls";
298299
case Extension::NV_shader_subgroup_partitioned:
299300
return "SPV_NV_shader_subgroup_partitioned";
301+
case Extension::KHR_quad_control:
302+
return "SPV_KHR_quad_control";
300303
default:
301304
break;
302305
}

tools/clang/lib/SPIRV/SpirvBuilder.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -453,7 +453,7 @@ SpirvSpecConstantBinaryOp *SpirvBuilder::createSpecConstantBinaryOp(
453453
}
454454

455455
SpirvGroupNonUniformOp *SpirvBuilder::createGroupNonUniformOp(
456-
spv::Op op, QualType resultType, spv::Scope execScope,
456+
spv::Op op, QualType resultType, llvm::Optional<spv::Scope> execScope,
457457
llvm::ArrayRef<SpirvInstruction *> operands, SourceLocation loc,
458458
llvm::Optional<spv::GroupOperation> groupOp) {
459459
assert(insertPoint && "null insert point");

tools/clang/lib/SPIRV/SpirvEmitter.cpp

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9271,6 +9271,10 @@ SpirvEmitter::processIntrinsicCallExpr(const CallExpr *callExpr) {
92719271
case hlsl::IntrinsicOp::IOP_QuadReadLaneAt:
92729272
retVal = processWaveQuadWideShuffle(callExpr, hlslOpcode);
92739273
break;
9274+
case hlsl::IntrinsicOp::IOP_QuadAny:
9275+
case hlsl::IntrinsicOp::IOP_QuadAll:
9276+
retVal = processWaveQuadAnyAll(callExpr, hlslOpcode);
9277+
break;
92749278
case hlsl::IntrinsicOp::IOP_abort:
92759279
case hlsl::IntrinsicOp::IOP_GetRenderTargetSampleCount:
92769280
case hlsl::IntrinsicOp::IOP_GetRenderTargetSamplePosition: {
@@ -10233,6 +10237,53 @@ SpirvEmitter::processWaveQuadWideShuffle(const CallExpr *callExpr,
1023310237
opcode, retType, spv::Scope::Subgroup, {value, target}, srcLoc);
1023410238
}
1023510239

10240+
SpirvInstruction *SpirvEmitter::processWaveQuadAnyAll(const CallExpr *callExpr,
10241+
hlsl::IntrinsicOp op) {
10242+
// Signatures:
10243+
// bool QuadAny(bool localValue)
10244+
// bool QuadAll(bool localValue)
10245+
assert(callExpr->getNumArgs() == 1);
10246+
assert(op == hlsl::IntrinsicOp::IOP_QuadAny ||
10247+
op == hlsl::IntrinsicOp::IOP_QuadAll);
10248+
featureManager.requestTargetEnv(SPV_ENV_VULKAN_1_1, "Wave Operation",
10249+
callExpr->getExprLoc());
10250+
10251+
auto *predicate = doExpr(callExpr->getArg(0));
10252+
const auto srcLoc = callExpr->getExprLoc();
10253+
10254+
if (!featureManager.isExtensionEnabled(Extension::KHR_quad_control)) {
10255+
// We can't use QuadAny/QuadAll, so implement them using QuadSwap. We
10256+
// will read the value at each quad invocation, then combine them.
10257+
10258+
spv::Op reducer = op == hlsl::IntrinsicOp::IOP_QuadAny
10259+
? spv::Op::OpLogicalOr
10260+
: spv::Op::OpLogicalAnd;
10261+
10262+
SpirvInstruction *result = predicate;
10263+
10264+
for (size_t i = 0; i < 3; i++) {
10265+
SpirvInstruction *invocationValue = spvBuilder.createGroupNonUniformOp(
10266+
spv::Op::OpGroupNonUniformQuadSwap, astContext.BoolTy,
10267+
spv::Scope::Subgroup,
10268+
{predicate, spvBuilder.getConstantInt(astContext.UnsignedIntTy,
10269+
llvm::APInt(32, i))},
10270+
srcLoc);
10271+
result = spvBuilder.createBinaryOp(reducer, astContext.BoolTy, result,
10272+
invocationValue, srcLoc);
10273+
}
10274+
10275+
return result;
10276+
}
10277+
10278+
spv::Op opcode = op == hlsl::IntrinsicOp::IOP_QuadAny
10279+
? spv::Op::OpGroupNonUniformQuadAnyKHR
10280+
: spv::Op::OpGroupNonUniformQuadAllKHR;
10281+
10282+
return spvBuilder.createGroupNonUniformOp(opcode, astContext.BoolTy,
10283+
llvm::Optional<spv::Scope>(),
10284+
{predicate}, srcLoc);
10285+
}
10286+
1023610287
SpirvInstruction *
1023710288
SpirvEmitter::processWaveActiveAllEqual(const CallExpr *callExpr) {
1023810289
assert(callExpr->getNumArgs() == 1);

tools/clang/lib/SPIRV/SpirvEmitter.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -670,6 +670,10 @@ class SpirvEmitter : public ASTConsumer {
670670
SpirvInstruction *processWaveQuadWideShuffle(const CallExpr *,
671671
hlsl::IntrinsicOp op);
672672

673+
/// Processes SM6.7 quad any/all.
674+
SpirvInstruction *processWaveQuadAnyAll(const CallExpr *,
675+
hlsl::IntrinsicOp op);
676+
673677
/// Generates the Spir-V instructions needed to implement the given call to
674678
/// WaveActiveAllEqual. Returns a pointer to the instruction that produces the
675679
/// final result.

0 commit comments

Comments
 (0)