Skip to content

Commit 759e9e1

Browse files
authored
Allow cap and extention attribute on member variables and type aliases (microsoft#6742)
This implements the changes proposed in microsoft/hlsl-specs#270.
1 parent 0703d9e commit 759e9e1

12 files changed

+127
-34
lines changed

tools/clang/include/clang/Basic/Attr.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1317,7 +1317,7 @@ def VKBinding : InheritableAttr {
13171317

13181318
def VKCapabilityExt : InheritableAttr {
13191319
let Spellings = [CXX11<"vk", "ext_capability">];
1320-
let Subjects = SubjectList<[Function, Var], ErrorDiag>;
1320+
let Subjects = SubjectList<[Function, Var, ParmVar, Field, TypedefName], ErrorDiag>;
13211321
let Args = [IntArgument<"capability">];
13221322
let LangOpts = [SPIRV];
13231323
let Documentation = [Undocumented];
@@ -1357,7 +1357,7 @@ def VKDecorateStringExt : InheritableAttr {
13571357

13581358
def VKExtensionExt : InheritableAttr {
13591359
let Spellings = [CXX11<"vk", "ext_extension">];
1360-
let Subjects = SubjectList<[Function, Var], ErrorDiag>;
1360+
let Subjects = SubjectList<[Function, Var, ParmVar, Field, TypedefName], ErrorDiag>;
13611361
let Args = [StringArgument<"name">];
13621362
let LangOpts = [SPIRV];
13631363
let Documentation = [Undocumented];

tools/clang/include/clang/SPIRV/FeatureManager.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ class FeatureManager {
7676
/// Allows all extensions to be used in CodeGen.
7777
void allowAllKnownExtensions();
7878

79-
/// Rqeusts the given extension for translating the given target feature at
79+
/// Requests the given extension for translating the given target feature at
8080
/// the given source location. Emits an error if the given extension is not
8181
/// permitted to use.
8282
bool requestExtension(Extension, llvm::StringRef target, SourceLocation);

tools/clang/include/clang/SPIRV/SpirvBuilder.h

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -764,13 +764,8 @@ class SpirvBuilder {
764764
SpirvInstruction *v);
765765
SpirvInstruction *getPerVertexStgInput(SpirvInstruction *k);
766766

767-
public:
768767
std::vector<uint32_t> takeModule();
769768

770-
protected:
771-
/// Only friend classes are allowed to add capability/extension to the module
772-
/// under construction.
773-
774769
/// \brief Adds the given capability to the module under construction due to
775770
/// the feature used at the given source location.
776771
inline void requireCapability(spv::Capability, SourceLocation loc = {});

tools/clang/lib/SPIRV/DeclResultIdMapper.cpp

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -968,13 +968,6 @@ DeclResultIdMapper::getDeclSpirvInfo(const ValueDecl *decl) const {
968968
SpirvInstruction *DeclResultIdMapper::getDeclEvalInfo(const ValueDecl *decl,
969969
SourceLocation loc,
970970
SourceRange range) {
971-
if (decl->hasAttr<VKExtensionExtAttr>() ||
972-
decl->hasAttr<VKCapabilityExtAttr>()) {
973-
theEmitter.createSpirvIntrInstExt(decl->getAttrs(), QualType(),
974-
/* spvArgs */ {}, /* isInst */ false,
975-
loc);
976-
}
977-
978971
if (auto *builtinAttr = decl->getAttr<VKExtBuiltinInputAttr>()) {
979972
return getBuiltinVar(spv::BuiltIn(builtinAttr->getBuiltInID()),
980973
decl->getType(), spv::StorageClass::Input, loc);
@@ -1640,6 +1633,18 @@ DeclResultIdMapper::createShaderRecordBuffer(const HLSLBufferDecl *decl,
16401633
return bufferVar;
16411634
}
16421635

1636+
void DeclResultIdMapper::recordsSpirvTypeAlias(const Decl *decl) {
1637+
auto *typedefDecl = dyn_cast<TypedefNameDecl>(decl);
1638+
if (!typedefDecl)
1639+
return;
1640+
1641+
if (!typedefDecl->hasAttr<VKCapabilityExtAttr>() &&
1642+
!typedefDecl->hasAttr<VKExtensionExtAttr>())
1643+
return;
1644+
1645+
typeAliasesWithAttributes.push_back(typedefDecl);
1646+
}
1647+
16431648
void DeclResultIdMapper::createGlobalsCBuffer(const VarDecl *var) {
16441649
if (astDecls.count(var) != 0)
16451650
return;
@@ -4776,5 +4781,21 @@ void DeclResultIdMapper::storeOutStageVarsToStorage(
47764781
}
47774782
}
47784783

4784+
void DeclResultIdMapper::registerCapabilitiesAndExtensionsForType(
4785+
const TypedefType *type) {
4786+
for (const auto *decl : typeAliasesWithAttributes) {
4787+
if (type == decl->getTypeForDecl()) {
4788+
for (auto *attribute : decl->specific_attrs<VKExtensionExtAttr>()) {
4789+
clang::StringRef extensionName = attribute->getName();
4790+
spvBuilder.requireExtension(extensionName, decl->getLocation());
4791+
}
4792+
for (auto *attribute : decl->specific_attrs<VKCapabilityExtAttr>()) {
4793+
spv::Capability cap = spv::Capability(attribute->getCapability());
4794+
spvBuilder.requireCapability(cap, decl->getLocation());
4795+
}
4796+
}
4797+
}
4798+
}
4799+
47794800
} // end namespace spirv
47804801
} // end namespace clang

tools/clang/lib/SPIRV/DeclResultIdMapper.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,10 @@ class DeclResultIdMapper {
376376
SpirvVariable *createShaderRecordBuffer(const HLSLBufferDecl *decl,
377377
ContextUsageKind kind);
378378

379+
// Records the TypedefDecl or TypeAliasDecl of vk::SpirvType so that any
380+
// required capabilities and extensions can be added if the type is used.
381+
void recordsSpirvTypeAlias(const Decl *decl);
382+
379383
private:
380384
/// The struct containing SPIR-V information of a AST Decl.
381385
struct DeclSpirvInfo {
@@ -569,6 +573,12 @@ class DeclResultIdMapper {
569573

570574
spv::ExecutionMode getInterlockExecutionMode();
571575

576+
/// Records any Spir-V capabilities and extensions for the given type so
577+
/// they will be added to the SPIR-V module. The capabilities and extension
578+
/// required for the type will be sourced from the decls that were recorded
579+
/// using `recordSpirvTypeAlias`.
580+
void registerCapabilitiesAndExtensionsForType(const TypedefType *type);
581+
572582
private:
573583
/// \brief Wrapper method to create a fatal error message and report it
574584
/// in the diagnostic engine associated with this consumer.
@@ -1058,6 +1068,8 @@ class DeclResultIdMapper {
10581068

10591069
uint32_t perspBaryCentricsIndex, noPerspBaryCentricsIndex;
10601070

1071+
llvm::SmallVector<const TypedefNameDecl *, 4> typeAliasesWithAttributes;
1072+
10611073
public:
10621074
/// The gl_PerVertex structs for both input and output
10631075
GlPerVertex glPerVertex;

tools/clang/lib/SPIRV/LowerTypeVisitor.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1085,6 +1085,19 @@ LowerTypeVisitor::lowerStructFields(const RecordDecl *decl,
10851085
field->getBitWidthValue(field->getASTContext());
10861086
}
10871087

1088+
if (field->hasAttrs()) {
1089+
for (auto &attr : field->getAttrs()) {
1090+
if (auto capAttr = dyn_cast<VKCapabilityExtAttr>(attr)) {
1091+
spvBuilder.requireCapability(
1092+
static_cast<spv::Capability>(capAttr->getCapability()),
1093+
capAttr->getLocation());
1094+
} else if (auto extAttr = dyn_cast<VKExtensionExtAttr>(attr)) {
1095+
spvBuilder.requireExtension(extAttr->getName(),
1096+
extAttr->getLocation());
1097+
}
1098+
}
1099+
}
1100+
10881101
fields.push_back(HybridStructType::FieldInfo(
10891102
field->getType(), field->getName(),
10901103
/*vkoffset*/ field->getAttr<VKOffsetAttr>(),

tools/clang/lib/SPIRV/SpirvEmitter.cpp

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -986,8 +986,7 @@ void SpirvEmitter::HandleTranslationUnit(ASTContext &context) {
986986
}
987987

988988
void SpirvEmitter::doDecl(const Decl *decl) {
989-
if (isa<EmptyDecl>(decl) || isa<TypedefDecl>(decl) ||
990-
isa<TypeAliasDecl>(decl) || isa<TypeAliasTemplateDecl>(decl) ||
989+
if (isa<EmptyDecl>(decl) || isa<TypeAliasTemplateDecl>(decl) ||
991990
isa<VarTemplateDecl>(decl))
992991
return;
993992

@@ -1016,6 +1015,8 @@ void SpirvEmitter::doDecl(const Decl *decl) {
10161015
} else if (const auto *classTemplateDecl =
10171016
dyn_cast<ClassTemplateDecl>(decl)) {
10181017
doClassTemplateDecl(classTemplateDecl);
1018+
} else if (isa<TypedefNameDecl>(decl)) {
1019+
declIdMapper.recordsSpirvTypeAlias(decl);
10191020
} else if (isa<FunctionTemplateDecl>(decl)) {
10201021
// nothing to do.
10211022
} else if (isa<UsingDecl>(decl)) {
@@ -1696,6 +1697,27 @@ bool SpirvEmitter::validateVKAttributes(const NamedDecl *decl) {
16961697
return success;
16971698
}
16981699

1700+
void SpirvEmitter::registerCapabilitiesAndExtensionsForVarDecl(
1701+
const VarDecl *varDecl) {
1702+
// First record any extensions that are part of the actual variable
1703+
// declaration.
1704+
for (auto *attribute : varDecl->specific_attrs<VKExtensionExtAttr>()) {
1705+
clang::StringRef extensionName = attribute->getName();
1706+
spvBuilder.requireExtension(extensionName, varDecl->getLocation());
1707+
}
1708+
for (auto *attribute : varDecl->specific_attrs<VKCapabilityExtAttr>()) {
1709+
spv::Capability cap = spv::Capability(attribute->getCapability());
1710+
spvBuilder.requireCapability(cap, varDecl->getLocation());
1711+
}
1712+
1713+
// Now check for any capabilities or extensions that are part of the type.
1714+
const TypedefType *type = dyn_cast<TypedefType>(varDecl->getType());
1715+
if (!type)
1716+
return;
1717+
1718+
declIdMapper.registerCapabilitiesAndExtensionsForType(type);
1719+
}
1720+
16991721
void SpirvEmitter::doHLSLBufferDecl(const HLSLBufferDecl *bufferDecl) {
17001722
// This is a cbuffer/tbuffer decl.
17011723
// Check and emit warnings for member intializers which are not
@@ -1847,6 +1869,8 @@ void SpirvEmitter::doVarDecl(const VarDecl *decl) {
18471869
return;
18481870
}
18491871

1872+
registerCapabilitiesAndExtensionsForVarDecl(decl);
1873+
18501874
// Handle vk::ext_builtin_input and vk::ext_builtin_input by using
18511875
// getBuiltinVar to create the builtin and validate the storage class
18521876
if (decl->hasAttr<VKExtBuiltinInputAttr>()) {

tools/clang/lib/SPIRV/SpirvEmitter.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -412,6 +412,10 @@ class SpirvEmitter : public ASTConsumer {
412412
/// errors are found.
413413
bool validateVKAttributes(const NamedDecl *decl);
414414

415+
/// Records any Spir-V capabilities and extensions for the given varDecl so
416+
/// they will be added to the SPIR-V module.
417+
void registerCapabilitiesAndExtensionsForVarDecl(const VarDecl *varDecl);
418+
415419
private:
416420
/// Converts the given value from the bitwidth of 'fromType' to the bitwidth
417421
/// of 'toType'. If the two have the same bitwidth, returns the value itself.
Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,19 @@
11
// RUN: %dxc -T ps_6_0 -E main -fcgl -Vd %s -spirv | FileCheck %s
22

3-
// CHECK: OpCapability Int8
4-
// CHECK: OpCapability SampleMaskPostDepthCoverage
3+
// CHECK-DAG: OpCapability Int8
4+
// CHECK-DAG: OpCapability SampleMaskPostDepthCoverage
5+
// CHECK-DAG: OpCapability WorkgroupMemoryExplicitLayoutKHR
56

6-
[[vk::ext_capability(/* SampleMaskPostDepthCoverageCapability */ 4447)]]
7-
int val;
87

8+
// Test that the capability on a typedef is added to the module.
99
[[vk::ext_capability(/* Int8 */ 39)]]
10+
typedef vk::SpirvType</* OpTypeInt */ 21, 8, 8, vk::Literal<vk::integral_constant<uint, 8> >, vk::Literal<vk::integral_constant<bool, false> > > uint8_t;
11+
12+
// Test that the capability on a variable is added to the module.
13+
[[vk::ext_capability(/* SampleMaskPostDepthCoverageCapability */ 4447)]]
14+
uint8_t val;
15+
16+
// Test that the capability on the entry point is added to the module.
17+
[[vk::ext_capability(/* WorkgroupMemoryExplicitLayoutKHR */ 4428)]]
1018
void main() {
11-
int local = val;
1219
}
Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,22 @@
11
// RUN: %dxc -T ps_6_0 -E main -fcgl %s -spirv | FileCheck %s
22

3-
// CHECK: OpExtension "entry_point_extension"
4-
// CHECK: OpExtension "another_extension"
5-
// CHECK: OpExtension "some_extension"
3+
// CHECK-DAG: OpExtension "entry_point_extension"
4+
// CHECK-DAG: OpExtension "another_extension"
5+
// CHECK-DAG: OpExtension "some_extension"
6+
// CHECK-DAG: OpExtension "ext_on_field1"
7+
// CHECK-DAG: OpExtension "ext_on_field2"
68

79
[[vk::ext_extension("some_extension"), vk::ext_extension("another_extension")]]
810
int val;
911

12+
struct T
13+
{
14+
[[vk::ext_extension("ext_on_field1"), vk::ext_extension("ext_on_field2")]]
15+
int val;
16+
};
17+
1018
[[vk::ext_extension("entry_point_extension")]]
1119
void main() {
12-
int local = val;
20+
T t;
21+
int local = val+t.val;
1322
}

0 commit comments

Comments
 (0)