Skip to content

Commit d60dffe

Browse files
[SPIR-V] Implement SpirvType and SpirvOpaqueType (microsoft#6156)
Implements hlsl-specs proposal 0011, adding `vk::SpirvType` and `vk::SpirvOpaqueType` templates which allow users to define and use SPIR-V level types.
1 parent dc84d72 commit d60dffe

23 files changed

+533
-54
lines changed

tools/clang/include/clang/AST/HlslBuiltinTypeDeclBuilder.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@ class BuiltinTypeDeclBuilder final {
3434

3535
clang::TemplateTypeParmDecl *
3636
addTypeTemplateParam(llvm::StringRef name,
37-
clang::TypeSourceInfo *defaultValue = nullptr);
37+
clang::TypeSourceInfo *defaultValue = nullptr,
38+
bool parameterPack = false);
3839
clang::TemplateTypeParmDecl *
3940
addTypeTemplateParam(llvm::StringRef name, clang::QualType defaultValue);
4041
clang::NonTypeTemplateParmDecl *

tools/clang/include/clang/AST/HlslTypes.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -399,6 +399,16 @@ DeclareNodeOrRecordType(clang::ASTContext &Ctx, DXIL::NodeIOKind Type,
399399
bool HasGetMethods = false, bool IsArray = false,
400400
bool IsCompleteType = false);
401401

402+
#ifdef ENABLE_SPIRV_CODEGEN
403+
clang::CXXRecordDecl *DeclareInlineSpirvType(clang::ASTContext &context,
404+
clang::DeclContext *declContext,
405+
llvm::StringRef typeName,
406+
bool opaque);
407+
clang::CXXRecordDecl *DeclareVkIntegralConstant(
408+
clang::ASTContext &context, clang::DeclContext *declContext,
409+
llvm::StringRef typeName, clang::ClassTemplateDecl **templateDecl);
410+
#endif
411+
402412
clang::CXXRecordDecl *DeclareNodeOutputArray(clang::ASTContext &Ctx,
403413
DXIL::NodeIOKind Type,
404414
clang::CXXRecordDecl *OutputType,

tools/clang/include/clang/SPIRV/SpirvContext.h

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -286,9 +286,12 @@ class SpirvContext {
286286

287287
const RayQueryTypeKHR *getRayQueryTypeKHR() const { return rayQueryTypeKHR; }
288288

289-
const SpirvIntrinsicType *
290-
getSpirvIntrinsicType(unsigned typeId, unsigned typeOpCode,
291-
llvm::ArrayRef<SpvIntrinsicTypeOperand> operands);
289+
const SpirvIntrinsicType *getOrCreateSpirvIntrinsicType(
290+
unsigned typeId, unsigned typeOpCode,
291+
llvm::ArrayRef<SpvIntrinsicTypeOperand> operands);
292+
293+
const SpirvIntrinsicType *getOrCreateSpirvIntrinsicType(
294+
unsigned typeOpCode, llvm::ArrayRef<SpvIntrinsicTypeOperand> operands);
292295

293296
SpirvIntrinsicType *getCreatedSpirvIntrinsicType(unsigned typeId);
294297

@@ -471,7 +474,8 @@ class SpirvContext {
471474
llvm::DenseMap<const SpirvType *, SCToPtrTyMap> pointerTypes;
472475
llvm::SmallVector<const HybridPointerType *, 8> hybridPointerTypes;
473476
llvm::DenseSet<FunctionType *, FunctionTypeMapInfo> functionTypes;
474-
llvm::DenseMap<unsigned, SpirvIntrinsicType *> spirvIntrinsicTypes;
477+
llvm::DenseMap<unsigned, SpirvIntrinsicType *> spirvIntrinsicTypesById;
478+
llvm::SmallVector<const SpirvIntrinsicType *, 8> spirvIntrinsicTypes;
475479
const AccelerationStructureTypeNV *accelerationStructureTypeNV;
476480
const RayQueryTypeKHR *rayQueryTypeKHR;
477481

tools/clang/include/clang/SPIRV/SpirvInstruction.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1193,6 +1193,8 @@ class SpirvConstant : public SpirvInstruction {
11931193
inst->getKind() <= IK_ConstantNull;
11941194
}
11951195

1196+
bool operator==(const SpirvConstant &that) const;
1197+
11961198
bool isSpecConstant() const;
11971199
void setLiteral(bool literal = true) { literalConstant = literal; }
11981200
bool isLiteral() { return literalConstant; }

tools/clang/include/clang/SPIRV/SpirvType.h

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -429,15 +429,16 @@ class RayQueryTypeKHR : public SpirvType {
429429

430430
class SpirvInstruction;
431431
struct SpvIntrinsicTypeOperand {
432-
SpvIntrinsicTypeOperand(SpirvType *type_operand)
432+
SpvIntrinsicTypeOperand(const SpirvType *type_operand)
433433
: operand_as_type(type_operand), isTypeOperand(true) {}
434434
SpvIntrinsicTypeOperand(SpirvInstruction *inst_operand)
435435
: operand_as_inst(inst_operand), isTypeOperand(false) {}
436+
bool operator==(const SpvIntrinsicTypeOperand &that) const;
436437
union {
437-
SpirvType *operand_as_type;
438+
const SpirvType *operand_as_type;
438439
SpirvInstruction *operand_as_inst;
439440
};
440-
bool isTypeOperand;
441+
const bool isTypeOperand;
441442
};
442443

443444
class SpirvIntrinsicType : public SpirvType {
@@ -453,6 +454,12 @@ class SpirvIntrinsicType : public SpirvType {
453454
return operands;
454455
}
455456

457+
bool operator==(const SpirvIntrinsicType &that) const {
458+
return typeOpCode == that.typeOpCode &&
459+
operands.size() == that.operands.size() &&
460+
std::equal(operands.begin(), operands.end(), that.operands.begin());
461+
}
462+
456463
private:
457464
unsigned typeOpCode;
458465
llvm::SmallVector<SpvIntrinsicTypeOperand, 3> operands;

tools/clang/lib/AST/ASTContextHLSL.cpp

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1241,6 +1241,43 @@ CXXRecordDecl *hlsl::DeclareNodeOrRecordType(
12411241
return Builder.getRecordDecl();
12421242
}
12431243

1244+
#ifdef ENABLE_SPIRV_CODEGEN
1245+
CXXRecordDecl *hlsl::DeclareInlineSpirvType(clang::ASTContext &context,
1246+
clang::DeclContext *declContext,
1247+
llvm::StringRef typeName,
1248+
bool opaque) {
1249+
// template<uint opcode, int size, int alignment> vk::SpirvType { ... }
1250+
// template<uint opcode> vk::SpirvOpaqueType { ... }
1251+
BuiltinTypeDeclBuilder typeDeclBuilder(declContext, typeName,
1252+
clang::TagTypeKind::TTK_Class);
1253+
typeDeclBuilder.addIntegerTemplateParam("opcode", context.UnsignedIntTy);
1254+
if (!opaque) {
1255+
typeDeclBuilder.addIntegerTemplateParam("size", context.UnsignedIntTy);
1256+
typeDeclBuilder.addIntegerTemplateParam("alignment", context.UnsignedIntTy);
1257+
}
1258+
typeDeclBuilder.addTypeTemplateParam("operands", nullptr, true);
1259+
typeDeclBuilder.startDefinition();
1260+
typeDeclBuilder.addField(
1261+
"h", context.UnsignedIntTy); // Add an 'h' field to hold the handle.
1262+
return typeDeclBuilder.getRecordDecl();
1263+
}
1264+
1265+
CXXRecordDecl *hlsl::DeclareVkIntegralConstant(
1266+
clang::ASTContext &context, clang::DeclContext *declContext,
1267+
llvm::StringRef typeName, ClassTemplateDecl **templateDecl) {
1268+
// template<typename T, T v> vk::integral_constant { ... }
1269+
BuiltinTypeDeclBuilder typeDeclBuilder(declContext, typeName,
1270+
clang::TagTypeKind::TTK_Class);
1271+
typeDeclBuilder.addTypeTemplateParam("T");
1272+
typeDeclBuilder.addIntegerTemplateParam("v", context.UnsignedIntTy);
1273+
typeDeclBuilder.startDefinition();
1274+
typeDeclBuilder.addField(
1275+
"h", context.UnsignedIntTy); // Add an 'h' field to hold the handle.
1276+
*templateDecl = typeDeclBuilder.getTemplateDecl();
1277+
return typeDeclBuilder.getRecordDecl();
1278+
}
1279+
#endif
1280+
12441281
CXXRecordDecl *hlsl::DeclareNodeOutputArray(clang::ASTContext &Ctx,
12451282
DXIL::NodeIOKind Type,
12461283
CXXRecordDecl *OutputType,

tools/clang/lib/AST/HlslBuiltinTypeDeclBuilder.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,8 @@ BuiltinTypeDeclBuilder::BuiltinTypeDeclBuilder(DeclContext *declContext,
3333
m_recordDecl->setImplicit(true);
3434
}
3535

36-
TemplateTypeParmDecl *
37-
BuiltinTypeDeclBuilder::addTypeTemplateParam(StringRef name,
38-
TypeSourceInfo *defaultValue) {
36+
TemplateTypeParmDecl *BuiltinTypeDeclBuilder::addTypeTemplateParam(
37+
StringRef name, TypeSourceInfo *defaultValue, bool parameterPack) {
3938
DXASSERT_NOMSG(!m_recordDecl->isBeingDefined() &&
4039
!m_recordDecl->isCompleteDefinition());
4140

@@ -45,7 +44,7 @@ BuiltinTypeDeclBuilder::addTypeTemplateParam(StringRef name,
4544
astContext, m_recordDecl->getDeclContext(), NoLoc, NoLoc,
4645
/* TemplateDepth */ 0, index,
4746
&astContext.Idents.get(name, tok::TokenKind::identifier),
48-
/* Typename */ false, /* ParameterPack */ false);
47+
/* Typename */ false, parameterPack);
4948
if (defaultValue != nullptr)
5049
decl->setDefaultArgument(defaultValue);
5150
m_templateParams.emplace_back(decl);

tools/clang/lib/SPIRV/AlignmentSizeCalculator.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
#include "AlignmentSizeCalculator.h"
1111
#include "clang/AST/Attr.h"
12+
#include "clang/AST/DeclTemplate.h"
1213

1314
namespace {
1415

@@ -264,6 +265,21 @@ std::pair<uint32_t, uint32_t> AlignmentSizeCalculator::getAlignmentAndSize(
264265
return getAlignmentAndSize(desugaredType, rule, isRowMajor, stride);
265266
}
266267

268+
const auto *recordType = type->getAs<RecordType>();
269+
if (recordType != nullptr) {
270+
const llvm::StringRef name = recordType->getDecl()->getName();
271+
272+
if (isTypeInVkNamespace(recordType) && name == "SpirvType") {
273+
const ClassTemplateSpecializationDecl *templateDecl =
274+
cast<ClassTemplateSpecializationDecl>(recordType->getDecl());
275+
const uint64_t size =
276+
templateDecl->getTemplateArgs()[1].getAsIntegral().getZExtValue();
277+
const uint64_t alignment =
278+
templateDecl->getTemplateArgs()[2].getAsIntegral().getZExtValue();
279+
return {alignment, size};
280+
}
281+
}
282+
267283
if (isEnumType(type))
268284
type = astContext.IntTy;
269285

tools/clang/lib/SPIRV/ConstEvaluator.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,12 @@ class ConstEvaluator {
3535
SpirvConstant *translateAPFloat(llvm::APFloat floatValue, QualType targetType,
3636
bool isSpecConstantMode);
3737

38+
/// Translates the given frontend APValue into its SPIR-V equivalent for the
39+
/// given targetType.
40+
SpirvConstant *translateAPValue(const APValue &value,
41+
const QualType targetType,
42+
bool isSpecConstantMode);
43+
3844
/// Tries to evaluate the given APInt as a 32-bit integer. If the evaluation
3945
/// can be performed without loss, it returns the <result-id> of the SPIR-V
4046
/// constant for that value.
@@ -52,12 +58,6 @@ class ConstEvaluator {
5258
bool isSpecConstantMode);
5359

5460
private:
55-
/// Translates the given frontend APValue into its SPIR-V equivalent for the
56-
/// given targetType.
57-
SpirvConstant *translateAPValue(const APValue &value,
58-
const QualType targetType,
59-
bool isSpecConstantMode);
60-
6161
/// Emits error to the diagnostic engine associated with the AST context.
6262
template <unsigned N>
6363
DiagnosticBuilder emitError(const char (&message)[N],

tools/clang/lib/SPIRV/EmitVisitor.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2577,7 +2577,11 @@ uint32_t EmitTypeHandler::emitType(const SpirvType *type) {
25772577
for (const SpvIntrinsicTypeOperand &operand :
25782578
spvIntrinsicType->getOperands()) {
25792579
if (operand.isTypeOperand) {
2580-
curTypeInst.push_back(emitType(operand.operand_as_type));
2580+
// calling emitType recursively will potentially replace the contents of
2581+
// curTypeInst, so we need to save them and restore after the call
2582+
std::vector<uint32_t> outerTypeInst = curTypeInst;
2583+
outerTypeInst.push_back(emitType(operand.operand_as_type));
2584+
curTypeInst = outerTypeInst;
25812585
} else {
25822586
auto *literal = dyn_cast<SpirvConstant>(operand.operand_as_inst);
25832587
if (literal && literal->isLiteral()) {

0 commit comments

Comments
 (0)