Skip to content

Commit 143da22

Browse files
authored
[SPIRV] Fix declaration order of values in decorations (#7672)
The value that decorates another value must be declared before the value it decorates. See KhronosGroup/SPIRV-Tools#6227.
1 parent bd18344 commit 143da22

26 files changed

+227
-179
lines changed

tools/clang/include/clang/SPIRV/SpirvBuilder.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -678,12 +678,23 @@ class SpirvBuilder {
678678
bool isPrecise, bool isNointerp, llvm::StringRef name = "",
679679
llvm::Optional<SpirvInstruction *> init = llvm::None,
680680
SourceLocation loc = {});
681+
682+
// Adds a variable to the module.
681683
SpirvVariable *
682684
addModuleVar(const SpirvType *valueType, spv::StorageClass storageClass,
683685
bool isPrecise, bool isNointerp, llvm::StringRef name = "",
684686
llvm::Optional<SpirvInstruction *> init = llvm::None,
685687
SourceLocation loc = {});
686688

689+
// Adds a variable to the module. It will be placed in the variable list
690+
// before `pos`.
691+
SpirvVariable *
692+
addModuleVar(const SpirvType *valueType, spv::StorageClass storageClass,
693+
bool isPrecise, bool isNointerp, SpirvInstruction *before,
694+
llvm::StringRef name = "",
695+
llvm::Optional<SpirvInstruction *> init = llvm::None,
696+
SourceLocation loc = {});
697+
687698
/// \brief Decorates the given target with the given location.
688699
void decorateLocation(SpirvInstruction *target, uint32_t location);
689700

tools/clang/include/clang/SPIRV/SpirvModule.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,10 @@ class SpirvModule {
139139
// Adds a variable to the module.
140140
void addVariable(SpirvVariable *);
141141

142+
// Adds a variable to the module immediately before `pos`.
143+
// If `pos` is not found, `var` is added at the end of the variable list.
144+
void addVariable(SpirvVariable *var, SpirvInstruction *pos);
145+
142146
// Adds a decoration to the module.
143147
void addDecoration(SpirvDecoration *);
144148

tools/clang/lib/SPIRV/DeclResultIdMapper.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1861,7 +1861,7 @@ void DeclResultIdMapper::createCounterVar(
18611861
}
18621862

18631863
SpirvVariable *counterInstr = spvBuilder.addModuleVar(
1864-
counterType, sc, /*isPrecise*/ false, false, counterName);
1864+
counterType, sc, /*isPrecise*/ false, false, declInstr, counterName);
18651865

18661866
if (!isAlias) {
18671867
// Non-alias counter variables should be put in to resourceVars so that

tools/clang/lib/SPIRV/EmitVisitor.cpp

Lines changed: 78 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -2644,80 +2644,15 @@ uint32_t EmitTypeHandler::emitType(const SpirvType *type) {
26442644
// NodePayloadArray types
26452645
else if (const auto *npaType = dyn_cast<NodePayloadArrayType>(type)) {
26462646
const uint32_t elemTypeId = emitType(npaType->getElementType());
2647+
2648+
// Output the decorations for the type first. This will create other values
2649+
// that are on the decorations, and they must appear before the type.
2650+
emitDecorationsForNodePayloadArrayTypes(npaType, id);
2651+
26472652
initTypeInstruction(spv::Op::OpTypeNodePayloadArrayAMDX);
26482653
curTypeInst.push_back(id);
26492654
curTypeInst.push_back(elemTypeId);
26502655
finalizeTypeInstruction();
2651-
2652-
// Emit decorations
2653-
const ParmVarDecl *nodeDecl = npaType->getNodeDecl();
2654-
if (hlsl::IsHLSLNodeOutputType(nodeDecl->getType())) {
2655-
StringRef name = nodeDecl->getName();
2656-
unsigned index = 0;
2657-
if (auto nodeID = nodeDecl->getAttr<HLSLNodeIdAttr>()) {
2658-
name = nodeID->getName();
2659-
index = nodeID->getArrayIndex();
2660-
}
2661-
2662-
auto *str = new (context) SpirvConstantString(name);
2663-
uint32_t nodeName = getOrCreateConstantString(str);
2664-
emitDecoration(id, spv::Decoration::PayloadNodeNameAMDX, {nodeName},
2665-
llvm::None, true);
2666-
if (index) {
2667-
uint32_t baseIndex = getOrCreateConstantInt(
2668-
llvm::APInt(32, index), context.getUIntType(32), false);
2669-
emitDecoration(id, spv::Decoration::PayloadNodeBaseIndexAMDX,
2670-
{baseIndex}, llvm::None, true);
2671-
}
2672-
}
2673-
2674-
uint32_t maxRecords;
2675-
if (const auto *attr = nodeDecl->getAttr<HLSLMaxRecordsAttr>()) {
2676-
maxRecords = getOrCreateConstantInt(llvm::APInt(32, attr->getMaxCount()),
2677-
context.getUIntType(32), false);
2678-
} else {
2679-
maxRecords = getOrCreateConstantInt(llvm::APInt(32, 1),
2680-
context.getUIntType(32), false);
2681-
}
2682-
emitDecoration(id, spv::Decoration::NodeMaxPayloadsAMDX, {maxRecords},
2683-
llvm::None, true);
2684-
2685-
if (const auto *attr = nodeDecl->getAttr<HLSLMaxRecordsSharedWithAttr>()) {
2686-
const DeclContext *dc = nodeDecl->getParentFunctionOrMethod();
2687-
if (const auto *funDecl = dyn_cast_or_null<FunctionDecl>(dc)) {
2688-
IdentifierInfo *ii = attr->getName();
2689-
bool alreadyExists = false;
2690-
for (auto *paramDecl : funDecl->params()) {
2691-
if (paramDecl->getIdentifier() == ii) {
2692-
assert(paramDecl != nodeDecl);
2693-
auto otherType = context.getNodeDeclPayloadType(paramDecl);
2694-
const uint32_t otherId =
2695-
getResultIdForType(otherType, &alreadyExists);
2696-
assert(alreadyExists && "forward references not allowed in "
2697-
"MaxRecordsSharedWith attribute");
2698-
emitDecoration(id, spv::Decoration::NodeSharesPayloadLimitsWithAMDX,
2699-
{otherId}, llvm::None, true);
2700-
break;
2701-
}
2702-
}
2703-
assert(alreadyExists &&
2704-
"invalid reference in MaxRecordsSharedWith attribute");
2705-
}
2706-
}
2707-
if (const auto *attr = nodeDecl->getAttr<HLSLAllowSparseNodesAttr>()) {
2708-
emitDecoration(id, spv::Decoration::PayloadNodeSparseArrayAMDX, {},
2709-
llvm::None);
2710-
}
2711-
if (const auto *attr = nodeDecl->getAttr<HLSLUnboundedSparseNodesAttr>()) {
2712-
emitDecoration(id, spv::Decoration::PayloadNodeSparseArrayAMDX, {},
2713-
llvm::None);
2714-
}
2715-
if (const auto *attr = nodeDecl->getAttr<HLSLNodeArraySizeAttr>()) {
2716-
uint32_t arraySize = getOrCreateConstantInt(
2717-
llvm::APInt(32, attr->getCount()), context.getUIntType(32), false);
2718-
emitDecoration(id, spv::Decoration::PayloadNodeArraySizeAMDX, {arraySize},
2719-
llvm::None, true);
2720-
}
27212656
}
27222657
// Structure types
27232658
else if (const auto *structType = dyn_cast<StructType>(type)) {
@@ -2998,5 +2933,78 @@ void EmitTypeHandler::emitNameForType(llvm::StringRef name,
29982933
nameInstr.end());
29992934
}
30002935

2936+
void EmitTypeHandler::emitDecorationsForNodePayloadArrayTypes(
2937+
const NodePayloadArrayType *npaType, uint32_t id) {
2938+
// Emit decorations
2939+
const ParmVarDecl *nodeDecl = npaType->getNodeDecl();
2940+
if (hlsl::IsHLSLNodeOutputType(nodeDecl->getType())) {
2941+
StringRef name = nodeDecl->getName();
2942+
unsigned index = 0;
2943+
if (auto nodeID = nodeDecl->getAttr<HLSLNodeIdAttr>()) {
2944+
name = nodeID->getName();
2945+
index = nodeID->getArrayIndex();
2946+
}
2947+
2948+
auto *str = new (context) SpirvConstantString(name);
2949+
uint32_t nodeName = getOrCreateConstantString(str);
2950+
emitDecoration(id, spv::Decoration::PayloadNodeNameAMDX, {nodeName},
2951+
llvm::None, true);
2952+
if (index) {
2953+
uint32_t baseIndex = getOrCreateConstantInt(
2954+
llvm::APInt(32, index), context.getUIntType(32), false);
2955+
emitDecoration(id, spv::Decoration::PayloadNodeBaseIndexAMDX, {baseIndex},
2956+
llvm::None, true);
2957+
}
2958+
}
2959+
2960+
uint32_t maxRecords;
2961+
if (const auto *attr = nodeDecl->getAttr<HLSLMaxRecordsAttr>()) {
2962+
maxRecords = getOrCreateConstantInt(llvm::APInt(32, attr->getMaxCount()),
2963+
context.getUIntType(32), false);
2964+
} else {
2965+
maxRecords = getOrCreateConstantInt(llvm::APInt(32, 1),
2966+
context.getUIntType(32), false);
2967+
}
2968+
emitDecoration(id, spv::Decoration::NodeMaxPayloadsAMDX, {maxRecords},
2969+
llvm::None, true);
2970+
2971+
if (const auto *attr = nodeDecl->getAttr<HLSLMaxRecordsSharedWithAttr>()) {
2972+
const DeclContext *dc = nodeDecl->getParentFunctionOrMethod();
2973+
if (const auto *funDecl = dyn_cast_or_null<FunctionDecl>(dc)) {
2974+
IdentifierInfo *ii = attr->getName();
2975+
bool alreadyExists = false;
2976+
for (auto *paramDecl : funDecl->params()) {
2977+
if (paramDecl->getIdentifier() == ii) {
2978+
assert(paramDecl != nodeDecl);
2979+
auto otherType = context.getNodeDeclPayloadType(paramDecl);
2980+
const uint32_t otherId =
2981+
getResultIdForType(otherType, &alreadyExists);
2982+
assert(alreadyExists && "forward references not allowed in "
2983+
"MaxRecordsSharedWith attribute");
2984+
emitDecoration(id, spv::Decoration::NodeSharesPayloadLimitsWithAMDX,
2985+
{otherId}, llvm::None, true);
2986+
break;
2987+
}
2988+
}
2989+
assert(alreadyExists &&
2990+
"invalid reference in MaxRecordsSharedWith attribute");
2991+
}
2992+
}
2993+
if (const auto *attr = nodeDecl->getAttr<HLSLAllowSparseNodesAttr>()) {
2994+
emitDecoration(id, spv::Decoration::PayloadNodeSparseArrayAMDX, {},
2995+
llvm::None);
2996+
}
2997+
if (const auto *attr = nodeDecl->getAttr<HLSLUnboundedSparseNodesAttr>()) {
2998+
emitDecoration(id, spv::Decoration::PayloadNodeSparseArrayAMDX, {},
2999+
llvm::None);
3000+
}
3001+
if (const auto *attr = nodeDecl->getAttr<HLSLNodeArraySizeAttr>()) {
3002+
uint32_t arraySize = getOrCreateConstantInt(
3003+
llvm::APInt(32, attr->getCount()), context.getUIntType(32), false);
3004+
emitDecoration(id, spv::Decoration::PayloadNodeArraySizeAMDX, {arraySize},
3005+
llvm::None, true);
3006+
}
3007+
}
3008+
30013009
} // end namespace spirv
30023010
} // end namespace clang

tools/clang/lib/SPIRV/EmitVisitor.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,10 @@ class EmitTypeHandler {
135135
void emitNameForType(llvm::StringRef name, uint32_t targetTypeId,
136136
llvm::Optional<uint32_t> memberIndex = llvm::None);
137137

138+
void
139+
emitDecorationsForNodePayloadArrayTypes(const NodePayloadArrayType *npaType,
140+
uint32_t id);
141+
138142
// There is no guarantee that an instruction or a function or a basic block
139143
// has been assigned result-id. This method returns the result-id for the
140144
// given object. If a result-id has not been assigned yet, it'll assign

tools/clang/lib/SPIRV/SpirvBuilder.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1675,6 +1675,21 @@ SpirvVariable *SpirvBuilder::addModuleVar(
16751675
return var;
16761676
}
16771677

1678+
SpirvVariable *SpirvBuilder::addModuleVar(
1679+
const SpirvType *type, spv::StorageClass storageClass, bool isPrecise,
1680+
bool isNointerp, SpirvInstruction *pos, llvm::StringRef name,
1681+
llvm::Optional<SpirvInstruction *> init, SourceLocation loc) {
1682+
assert(storageClass != spv::StorageClass::Function);
1683+
// Note: We store the underlying type in the variable, *not* the pointer type.
1684+
auto *var = new (context)
1685+
SpirvVariable(type, loc, storageClass, isPrecise, isNointerp,
1686+
init.hasValue() ? init.getValue() : nullptr);
1687+
var->setResultType(type);
1688+
var->setDebugName(name);
1689+
mod->addVariable(var, pos);
1690+
return var;
1691+
}
1692+
16781693
void SpirvBuilder::decorateLocation(SpirvInstruction *target,
16791694
uint32_t location) {
16801695
auto *decor =

tools/clang/lib/SPIRV/SpirvModule.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -341,6 +341,12 @@ void SpirvModule::addVariable(SpirvVariable *var) {
341341
variables.push_back(var);
342342
}
343343

344+
void SpirvModule::addVariable(SpirvVariable *var, SpirvInstruction *pos) {
345+
assert(var && "cannot add null variable to the module");
346+
auto location = std::find(variables.begin(), variables.end(), pos);
347+
variables.insert(location, var);
348+
}
349+
344350
void SpirvModule::addDecoration(SpirvDecoration *decor) {
345351
assert(decor && "cannot add null decoration to the module");
346352
decorations.insert(decor);

tools/clang/test/CodeGenSPIRV/spirv.legal.counter.nested-struct.hlsl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
// RUN: %dxc -T vs_6_0 -E main -fcgl %s -spirv | FileCheck %s
22

3-
// CHECK: %counter_var_t_1_0 = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
43
// CHECK: %counter_var_rw = OpVariable %_ptr_Uniform_type_ACSBuffer_counter Uniform
4+
// CHECK: %counter_var_t_1_0 = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
55
// CHECK: %counter_var_s_0 = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
66

77
RWStructuredBuffer<uint> rw : register(u0);

tools/clang/test/CodeGenSPIRV/spirv.legal.sbuffer.counter.hlsl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ RWStructuredBuffer<S1> selectRWSBuffer(RWStructuredBuffer<S1> paramRWSBu
2626
AppendStructuredBuffer<S2> selectASBuffer(AppendStructuredBuffer<S2> paramASBuffer, bool selector);
2727
ConsumeStructuredBuffer<S3> selectCSBuffer(ConsumeStructuredBuffer<S3> paramCSBuffer, bool selector);
2828

29+
// CHECK: %counter_var_globalRWSBuffer = OpVariable %_ptr_Uniform_type_ACSBuffer_counter Uniform
2930
RWStructuredBuffer<S1> globalRWSBuffer;
3031
// CHECK: %counter_var_globalASBuffer = OpVariable %_ptr_Uniform_type_ACSBuffer_counter Uniform
3132
AppendStructuredBuffer<S2> globalASBuffer;
@@ -38,7 +39,6 @@ static RWStructuredBuffer<S1> staticgRWSBuffer = globalRWSBuffer;
3839
// CHECK: %counter_var_staticgASBuffer = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
3940
static AppendStructuredBuffer<S2> staticgASBuffer = globalASBuffer;
4041
// CHECK: %counter_var_staticgCSBuffer = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
41-
// CHECK: %counter_var_globalRWSBuffer = OpVariable %_ptr_Uniform_type_ACSBuffer_counter Uniform
4242
static ConsumeStructuredBuffer<S3> staticgCSBuffer = globalCSBuffer;
4343

4444
// Counter variables for function returns, function parameters, and local variables have an extra level of pointer.

0 commit comments

Comments
 (0)