Skip to content

Commit 78ac1bc

Browse files
committed
[HLSL] Implement SpirvType and SpirvOpaqueType
This implements the design proposed by [Representing SpirvType in Clang's Type System](llvm/wg-hlsl#181). It creates `HLSLInlineSpirvType` as a new `Type` subclass, and `__hlsl_spirv_type` as a new builtin type template to create such a type. This new type is lowered to the `spirv.Type` target extension type, as described in [Target Extension Types for Inline SPIR-V and Decorated Types](https://github.com/llvm/wg-hlsl/blob/main/proposals/0017-inline-spirv-and-decorated-types.md).
1 parent a2ca2f3 commit 78ac1bc

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

51 files changed

+1052
-76
lines changed

clang/include/clang-c/Index.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3034,7 +3034,8 @@ enum CXTypeKind {
30343034

30353035
/* HLSL Types */
30363036
CXType_HLSLResource = 179,
3037-
CXType_HLSLAttributedResource = 180
3037+
CXType_HLSLAttributedResource = 180,
3038+
CXType_HLSLInlineSpirv = 181
30383039
};
30393040

30403041
/**

clang/include/clang/AST/ASTContext.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,7 @@ class ASTContext : public RefCountedBase<ASTContext> {
260260
DependentBitIntTypes;
261261
mutable llvm::FoldingSet<BTFTagAttributedType> BTFTagAttributedTypes;
262262
llvm::FoldingSet<HLSLAttributedResourceType> HLSLAttributedResourceTypes;
263+
llvm::FoldingSet<HLSLInlineSpirvType> HLSLInlineSpirvTypes;
263264

264265
mutable llvm::FoldingSet<CountAttributedType> CountAttributedTypes;
265266

@@ -1795,6 +1796,10 @@ class ASTContext : public RefCountedBase<ASTContext> {
17951796
QualType Wrapped, QualType Contained,
17961797
const HLSLAttributedResourceType::Attributes &Attrs);
17971798

1799+
QualType getHLSLInlineSpirvType(uint32_t Opcode, uint32_t Size,
1800+
uint32_t Alignment,
1801+
ArrayRef<SpirvOperand> Operands);
1802+
17981803
QualType
17991804
getSubstTemplateTypeParmType(QualType Replacement, Decl *AssociatedDecl,
18001805
unsigned Index,

clang/include/clang/AST/ASTNodeTraverser.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -450,6 +450,24 @@ class ASTNodeTraverser
450450
if (!Contained.isNull())
451451
Visit(Contained);
452452
}
453+
void VisitHLSLInlineSpirvType(const HLSLInlineSpirvType *T) {
454+
for (auto &Operand : T->getOperands()) {
455+
using SpirvOperandKind = SpirvOperand::SpirvOperandKind;
456+
457+
switch (Operand.getKind()) {
458+
case SpirvOperandKind::kConstantId:
459+
case SpirvOperandKind::kLiteral:
460+
break;
461+
462+
case SpirvOperandKind::kTypeId:
463+
Visit(Operand.getResultType());
464+
break;
465+
466+
default:
467+
llvm_unreachable("Invalid SpirvOperand kind!");
468+
}
469+
}
470+
}
453471
void VisitSubstTemplateTypeParmType(const SubstTemplateTypeParmType *) {}
454472
void
455473
VisitSubstTemplateTypeParmPackType(const SubstTemplateTypeParmPackType *T) {

clang/include/clang/AST/PropertiesBase.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,7 @@ def UInt64 : CountPropertyType<"uint64_t">;
147147
def UnaryTypeTransformKind : EnumPropertyType<"UnaryTransformType::UTTKind">;
148148
def VectorKind : EnumPropertyType<"VectorKind">;
149149
def TypeCoupledDeclRefInfo : PropertyType;
150+
def HLSLSpirvOperand : PropertyType<"SpirvOperand"> { let PassByReference = 1; }
150151

151152
def ExceptionSpecInfo : PropertyType<"FunctionProtoType::ExceptionSpecInfo"> {
152153
let BufferElementTypes = [ QualType ];

clang/include/clang/AST/RecursiveASTVisitor.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1154,6 +1154,14 @@ DEF_TRAVERSE_TYPE(BTFTagAttributedType,
11541154
DEF_TRAVERSE_TYPE(HLSLAttributedResourceType,
11551155
{ TRY_TO(TraverseType(T->getWrappedType())); })
11561156

1157+
DEF_TRAVERSE_TYPE(HLSLInlineSpirvType, {
1158+
for (auto &Operand : T->getOperands()) {
1159+
if (Operand.isConstant() || Operand.isType()) {
1160+
TRY_TO(TraverseType(Operand.getResultType()));
1161+
}
1162+
}
1163+
})
1164+
11571165
DEF_TRAVERSE_TYPE(ParenType, { TRY_TO(TraverseType(T->getInnerType())); })
11581166

11591167
DEF_TRAVERSE_TYPE(MacroQualifiedType,
@@ -1457,6 +1465,9 @@ DEF_TRAVERSE_TYPELOC(BTFTagAttributedType,
14571465
DEF_TRAVERSE_TYPELOC(HLSLAttributedResourceType,
14581466
{ TRY_TO(TraverseTypeLoc(TL.getWrappedLoc())); })
14591467

1468+
DEF_TRAVERSE_TYPELOC(HLSLInlineSpirvType,
1469+
{ TRY_TO(TraverseType(TL.getType())); })
1470+
14601471
DEF_TRAVERSE_TYPELOC(ElaboratedType, {
14611472
if (TL.getQualifierLoc()) {
14621473
TRY_TO(TraverseNestedNameSpecifierLoc(TL.getQualifierLoc()));

clang/include/clang/AST/Type.h

Lines changed: 141 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2652,6 +2652,7 @@ class alignas(TypeAlignment) Type : public ExtQualsTypeCommonBase {
26522652
bool isHLSLSpecificType() const; // Any HLSL specific type
26532653
bool isHLSLBuiltinIntangibleType() const; // Any HLSL builtin intangible type
26542654
bool isHLSLAttributedResourceType() const;
2655+
bool isHLSLInlineSpirvType() const;
26552656
bool isHLSLResourceRecord() const;
26562657
bool isHLSLIntangibleType()
26572658
const; // Any HLSL intangible type (builtin, array, class)
@@ -6330,6 +6331,140 @@ class HLSLAttributedResourceType : public Type, public llvm::FoldingSetNode {
63306331
findHandleTypeOnResource(const Type *RT);
63316332
};
63326333

6334+
/// Instances of this class represent operands to a SPIR-V type instruction.
6335+
class SpirvOperand {
6336+
public:
6337+
enum SpirvOperandKind : unsigned char {
6338+
kInvalid, ///< Uninitialized.
6339+
kConstantId, ///< Integral value to represent as a SPIR-V OpConstant
6340+
///< instruction ID.
6341+
kLiteral, ///< Integral value to represent as an immediate literal.
6342+
kTypeId, ///< Type to represent as a SPIR-V type ID.
6343+
6344+
kMax,
6345+
};
6346+
6347+
private:
6348+
SpirvOperandKind Kind = kInvalid;
6349+
6350+
QualType ResultType;
6351+
llvm::APInt Value; // Signedness of constants is represented by ResultType.
6352+
6353+
public:
6354+
SpirvOperand() : Kind(kInvalid), ResultType() {}
6355+
6356+
SpirvOperand(SpirvOperandKind Kind, QualType ResultType, llvm::APInt Value)
6357+
: Kind(Kind), ResultType(ResultType), Value(Value) {}
6358+
6359+
SpirvOperand(const SpirvOperand &Other) { *this = Other; }
6360+
~SpirvOperand() {}
6361+
6362+
SpirvOperand &operator=(const SpirvOperand &Other) {
6363+
this->Kind = Other.Kind;
6364+
this->ResultType = Other.ResultType;
6365+
this->Value = Other.Value;
6366+
return *this;
6367+
}
6368+
6369+
bool operator==(const SpirvOperand &Other) const {
6370+
return Kind == Other.Kind && ResultType == Other.ResultType &&
6371+
Value == Other.Value;
6372+
}
6373+
6374+
bool operator!=(const SpirvOperand &Other) const { return !(*this == Other); }
6375+
6376+
SpirvOperandKind getKind() const { return Kind; }
6377+
6378+
bool isValid() const { return Kind != kInvalid && Kind < kMax; }
6379+
bool isConstant() const { return Kind == kConstantId; }
6380+
bool isLiteral() const { return Kind == kLiteral; }
6381+
bool isType() const { return Kind == kTypeId; }
6382+
6383+
llvm::APInt getValue() const {
6384+
assert((isConstant() || isLiteral()) &&
6385+
"This is not an operand with a value!");
6386+
return Value;
6387+
}
6388+
6389+
QualType getResultType() const {
6390+
assert((isConstant() || isType()) &&
6391+
"This is not an operand with a result type!");
6392+
return ResultType;
6393+
}
6394+
6395+
static SpirvOperand createConstant(QualType ResultType, llvm::APInt Val) {
6396+
return SpirvOperand(kConstantId, ResultType, Val);
6397+
}
6398+
6399+
static SpirvOperand createLiteral(llvm::APInt Val) {
6400+
return SpirvOperand(kLiteral, QualType(), Val);
6401+
}
6402+
6403+
static SpirvOperand createType(QualType T) {
6404+
return SpirvOperand(kTypeId, T, llvm::APSInt());
6405+
}
6406+
6407+
void Profile(llvm::FoldingSetNodeID &ID) const {
6408+
ID.AddInteger(Kind);
6409+
ID.AddPointer(ResultType.getAsOpaquePtr());
6410+
Value.Profile(ID);
6411+
}
6412+
};
6413+
6414+
/// Represents an arbitrary, user-specified SPIR-V type instruction.
6415+
class HLSLInlineSpirvType final
6416+
: public Type,
6417+
public llvm::FoldingSetNode,
6418+
private llvm::TrailingObjects<HLSLInlineSpirvType, SpirvOperand> {
6419+
friend class ASTContext; // ASTContext creates these
6420+
friend TrailingObjects;
6421+
6422+
private:
6423+
uint32_t Opcode;
6424+
uint32_t Size;
6425+
uint32_t Alignment;
6426+
size_t NumOperands;
6427+
6428+
HLSLInlineSpirvType(uint32_t Opcode, uint32_t Size, uint32_t Alignment,
6429+
ArrayRef<SpirvOperand> Operands)
6430+
: Type(HLSLInlineSpirv, QualType(), TypeDependence::None), Opcode(Opcode),
6431+
Size(Size), Alignment(Alignment), NumOperands(Operands.size()) {
6432+
for (size_t I = 0; I < NumOperands; I++) {
6433+
getTrailingObjects<SpirvOperand>()[I] = Operands[I];
6434+
}
6435+
}
6436+
6437+
public:
6438+
uint32_t getOpcode() const { return Opcode; }
6439+
uint32_t getSize() const { return Size; }
6440+
uint32_t getAlignment() const { return Alignment; }
6441+
ArrayRef<SpirvOperand> getOperands() const {
6442+
return {getTrailingObjects<SpirvOperand>(), NumOperands};
6443+
}
6444+
6445+
bool isSugared() const { return false; }
6446+
QualType desugar() const { return QualType(this, 0); }
6447+
6448+
void Profile(llvm::FoldingSetNodeID &ID) {
6449+
Profile(ID, Opcode, Size, Alignment, getOperands());
6450+
}
6451+
6452+
static void Profile(llvm::FoldingSetNodeID &ID, uint32_t Opcode,
6453+
uint32_t Size, uint32_t Alignment,
6454+
ArrayRef<SpirvOperand> Operands) {
6455+
ID.AddInteger(Opcode);
6456+
ID.AddInteger(Size);
6457+
ID.AddInteger(Alignment);
6458+
for (auto &Operand : Operands) {
6459+
Operand.Profile(ID);
6460+
}
6461+
}
6462+
6463+
static bool classof(const Type *T) {
6464+
return T->getTypeClass() == HLSLInlineSpirv;
6465+
}
6466+
};
6467+
63336468
class TemplateTypeParmType : public Type, public llvm::FoldingSetNode {
63346469
friend class ASTContext; // ASTContext creates these
63356470

@@ -8458,13 +8593,18 @@ inline bool Type::isHLSLBuiltinIntangibleType() const {
84588593
}
84598594

84608595
inline bool Type::isHLSLSpecificType() const {
8461-
return isHLSLBuiltinIntangibleType() || isHLSLAttributedResourceType();
8596+
return isHLSLBuiltinIntangibleType() || isHLSLAttributedResourceType() ||
8597+
isHLSLInlineSpirvType();
84628598
}
84638599

84648600
inline bool Type::isHLSLAttributedResourceType() const {
84658601
return isa<HLSLAttributedResourceType>(this);
84668602
}
84678603

8604+
inline bool Type::isHLSLInlineSpirvType() const {
8605+
return isa<HLSLInlineSpirvType>(this);
8606+
}
8607+
84688608
inline bool Type::isTemplateTypeParmType() const {
84698609
return isa<TemplateTypeParmType>(CanonicalType);
84708610
}

clang/include/clang/AST/TypeLoc.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -973,6 +973,25 @@ class HLSLAttributedResourceTypeLoc
973973
}
974974
};
975975

976+
struct HLSLInlineSpirvTypeLocInfo {
977+
SourceLocation Loc;
978+
}; // Nothing.
979+
980+
class HLSLInlineSpirvTypeLoc
981+
: public ConcreteTypeLoc<UnqualTypeLoc, HLSLInlineSpirvTypeLoc,
982+
HLSLInlineSpirvType, HLSLInlineSpirvTypeLocInfo> {
983+
public:
984+
SourceLocation getSpirvTypeLoc() const { return getLocalData()->Loc; }
985+
void setSpirvTypeLoc(SourceLocation loc) const { getLocalData()->Loc = loc; }
986+
987+
SourceRange getLocalSourceRange() const {
988+
return SourceRange(getSpirvTypeLoc(), getSpirvTypeLoc());
989+
}
990+
void initializeLocal(ASTContext &Context, SourceLocation loc) {
991+
setSpirvTypeLoc(loc);
992+
}
993+
};
994+
976995
struct ObjCObjectTypeLocInfo {
977996
SourceLocation TypeArgsLAngleLoc;
978997
SourceLocation TypeArgsRAngleLoc;

clang/include/clang/AST/TypeProperties.td

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -719,6 +719,24 @@ let Class = HLSLAttributedResourceType in {
719719
}]>;
720720
}
721721

722+
let Class = HLSLInlineSpirvType in {
723+
def : Property<"opcode", UInt32> {
724+
let Read = [{ node->getOpcode() }];
725+
}
726+
def : Property<"size", UInt32> {
727+
let Read = [{ node->getSize() }];
728+
}
729+
def : Property<"alignment", UInt32> {
730+
let Read = [{ node->getAlignment() }];
731+
}
732+
def : Property<"operands", Array<HLSLSpirvOperand>> {
733+
let Read = [{ node->getOperands() }];
734+
}
735+
def : Creator<[{
736+
return ctx.getHLSLInlineSpirvType(opcode, size, alignment, operands);
737+
}]>;
738+
}
739+
722740
let Class = DependentAddressSpaceType in {
723741
def : Property<"pointeeType", QualType> {
724742
let Read = [{ node->getPointeeType() }];

clang/include/clang/Basic/BuiltinTemplates.td

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,25 +28,37 @@ class BuiltinNTTP<string type_name> : TemplateArg<""> {
2828
}
2929

3030
def SizeT : BuiltinNTTP<"size_t"> {}
31+
def Uint32T: BuiltinNTTP<"uint32_t"> {}
3132

3233
class BuiltinTemplate<list<TemplateArg> template_head> {
3334
list<TemplateArg> TemplateHead = template_head;
3435
}
3536

37+
class CPlusPlusBuiltinTemplate<list<TemplateArg> template_head> : BuiltinTemplate<template_head>;
38+
39+
class HLSLBuiltinTemplate<list<TemplateArg> template_head> : BuiltinTemplate<template_head>;
40+
3641
// template <template <class T, T... Ints> IntSeq, class T, T N>
37-
def __make_integer_seq : BuiltinTemplate<
42+
def __make_integer_seq : CPlusPlusBuiltinTemplate<
3843
[Template<[Class<"T">, NTTP<"T", "Ints", /*is_variadic=*/1>], "IntSeq">, Class<"T">, NTTP<"T", "N">]>;
3944

4045
// template <size_t, class... T>
41-
def __type_pack_element : BuiltinTemplate<
46+
def __type_pack_element : CPlusPlusBuiltinTemplate<
4247
[SizeT, Class<"T", /*is_variadic=*/1>]>;
4348

4449
// template <template <class... Args> BaseTemplate,
4550
// template <class TypeMember> HasTypeMember,
4651
// class HasNoTypeMember
4752
// class... Ts>
48-
def __builtin_common_type : BuiltinTemplate<
53+
def __builtin_common_type : CPlusPlusBuiltinTemplate<
4954
[Template<[Class<"Args", /*is_variadic=*/1>], "BaseTemplate">,
5055
Template<[Class<"TypeMember">], "HasTypeMember">,
5156
Class<"HasNoTypeMember">,
5257
Class<"Ts", /*is_variadic=*/1>]>;
58+
59+
// template <uint32_t Opcode,
60+
// uint32_t Size,
61+
// uint32_t Alignment,
62+
// typename ...Operands>
63+
def __hlsl_spirv_type : HLSLBuiltinTemplate<
64+
[Uint32T, Uint32T, Uint32T, Class<"Operands", /*is_variadic=*/1>]>;

clang/include/clang/Basic/DiagnosticSemaKinds.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12709,6 +12709,9 @@ def err_hlsl_expect_arg_const_int_one_or_neg_one: Error<
1270912709
def err_invalid_hlsl_resource_type: Error<
1271012710
"invalid __hlsl_resource_t type attributes">;
1271112711

12712+
def err_hlsl_spirv_only: Error<"%0 is only available for the SPIR-V target">;
12713+
def err_hlsl_vk_literal_must_contain_constant: Error<"the argument to vk::Literal must be a vk::integral_constant">;
12714+
1271212715
// Layout randomization diagnostics.
1271312716
def err_non_designated_init_used : Error<
1271412717
"a randomized struct can only be initialized with a designated initializer">;

0 commit comments

Comments
 (0)