1515#include " llvm/ADT/bit.h"
1616#include " llvm/IR/IRBuilder.h"
1717#include " llvm/IR/Metadata.h"
18+ #include " llvm/Support/ScopedPrinter.h"
1819
1920namespace llvm {
2021namespace hlsl {
2122namespace rootsig {
2223
23- static raw_ostream &operator <<(raw_ostream &OS, const Register &Reg) {
24- switch (Reg.ViewType ) {
25- case RegisterType::BReg:
26- OS << " b" ;
27- break ;
28- case RegisterType::TReg:
29- OS << " t" ;
30- break ;
31- case RegisterType::UReg:
32- OS << " u" ;
33- break ;
34- case RegisterType::SReg:
35- OS << " s" ;
36- break ;
37- }
38- OS << Reg.Number ;
39- return OS;
24+ template <typename T>
25+ static std::optional<StringRef> getEnumName (const T Value,
26+ ArrayRef<EnumEntry<T>> Enums) {
27+ for (const auto &EnumItem : Enums)
28+ if (EnumItem.Value == Value)
29+ return EnumItem.Name ;
30+ return std::nullopt ;
4031}
4132
42- static raw_ostream &operator <<(raw_ostream &OS,
43- const ShaderVisibility &Visibility) {
44- switch (Visibility) {
45- case ShaderVisibility::All:
46- OS << " All" ;
47- break ;
48- case ShaderVisibility::Vertex:
49- OS << " Vertex" ;
50- break ;
51- case ShaderVisibility::Hull:
52- OS << " Hull" ;
53- break ;
54- case ShaderVisibility::Domain:
55- OS << " Domain" ;
56- break ;
57- case ShaderVisibility::Geometry:
58- OS << " Geometry" ;
59- break ;
60- case ShaderVisibility::Pixel:
61- OS << " Pixel" ;
62- break ;
63- case ShaderVisibility::Amplification:
64- OS << " Amplification" ;
65- break ;
66- case ShaderVisibility::Mesh:
67- OS << " Mesh" ;
68- break ;
69- }
70-
71- return OS;
72- }
73-
74- static raw_ostream &operator <<(raw_ostream &OS, const ClauseType &Type) {
75- switch (Type) {
76- case ClauseType::CBuffer:
77- OS << " CBV" ;
78- break ;
79- case ClauseType::SRV:
80- OS << " SRV" ;
81- break ;
82- case ClauseType::UAV:
83- OS << " UAV" ;
84- break ;
85- case ClauseType::Sampler:
86- OS << " Sampler" ;
87- break ;
88- }
89-
33+ template <typename T>
34+ static raw_ostream &printEnum (raw_ostream &OS, const T Value,
35+ ArrayRef<EnumEntry<T>> Enums) {
36+ auto MaybeName = getEnumName (Value, Enums);
37+ if (MaybeName)
38+ OS << *MaybeName;
9039 return OS;
9140}
9241
93- static raw_ostream &operator <<(raw_ostream &OS,
94- const DescriptorRangeFlags &Flags) {
42+ template <typename T>
43+ static raw_ostream &printFlags (raw_ostream &OS, const T Value,
44+ ArrayRef<EnumEntry<T>> Flags) {
9545 bool FlagSet = false ;
96- unsigned Remaining = llvm::to_underlying (Flags );
46+ unsigned Remaining = llvm::to_underlying (Value );
9747 while (Remaining) {
9848 unsigned Bit = 1u << llvm::countr_zero (Remaining);
9949 if (Remaining & Bit) {
10050 if (FlagSet)
10151 OS << " | " ;
10252
103- switch (static_cast <DescriptorRangeFlags>(Bit)) {
104- case DescriptorRangeFlags::DescriptorsVolatile:
105- OS << " DescriptorsVolatile" ;
106- break ;
107- case DescriptorRangeFlags::DataVolatile:
108- OS << " DataVolatile" ;
109- break ;
110- case DescriptorRangeFlags::DataStaticWhileSetAtExecute:
111- OS << " DataStaticWhileSetAtExecute" ;
112- break ;
113- case DescriptorRangeFlags::DataStatic:
114- OS << " DataStatic" ;
115- break ;
116- case DescriptorRangeFlags::DescriptorsStaticKeepingBufferBoundsChecks:
117- OS << " DescriptorsStaticKeepingBufferBoundsChecks" ;
118- break ;
119- default :
53+ auto MaybeFlag = getEnumName (T (Bit), Flags);
54+ if (MaybeFlag)
55+ OS << *MaybeFlag;
56+ else
12057 OS << " invalid: " << Bit;
121- break ;
122- }
12358
12459 FlagSet = true ;
12560 }
@@ -128,6 +63,68 @@ static raw_ostream &operator<<(raw_ostream &OS,
12863
12964 if (!FlagSet)
13065 OS << " None" ;
66+ return OS;
67+ }
68+
69+ static const EnumEntry<RegisterType> RegisterNames[] = {
70+ {" b" , RegisterType::BReg},
71+ {" t" , RegisterType::TReg},
72+ {" u" , RegisterType::UReg},
73+ {" s" , RegisterType::SReg},
74+ };
75+
76+ static raw_ostream &operator <<(raw_ostream &OS, const Register &Reg) {
77+ printEnum (OS, Reg.ViewType , ArrayRef (RegisterNames));
78+ OS << Reg.Number ;
79+
80+ return OS;
81+ }
82+
83+ static const EnumEntry<ShaderVisibility> VisibilityNames[] = {
84+ {" All" , ShaderVisibility::All},
85+ {" Vertex" , ShaderVisibility::Vertex},
86+ {" Hull" , ShaderVisibility::Hull},
87+ {" Domain" , ShaderVisibility::Domain},
88+ {" Geometry" , ShaderVisibility::Geometry},
89+ {" Pixel" , ShaderVisibility::Pixel},
90+ {" Amplification" , ShaderVisibility::Amplification},
91+ {" Mesh" , ShaderVisibility::Mesh},
92+ };
93+
94+ static raw_ostream &operator <<(raw_ostream &OS,
95+ const ShaderVisibility &Visibility) {
96+ printEnum (OS, Visibility, ArrayRef (VisibilityNames));
97+
98+ return OS;
99+ }
100+
101+ static const EnumEntry<dxil::ResourceClass> ResourceClassNames[] = {
102+ {" CBV" , dxil::ResourceClass::CBuffer},
103+ {" SRV" , dxil::ResourceClass::SRV},
104+ {" UAV" , dxil::ResourceClass::UAV},
105+ {" Sampler" , dxil::ResourceClass::Sampler},
106+ };
107+
108+ static raw_ostream &operator <<(raw_ostream &OS, const ClauseType &Type) {
109+ printEnum (OS, dxil::ResourceClass (llvm::to_underlying (Type)),
110+ ArrayRef (ResourceClassNames));
111+
112+ return OS;
113+ }
114+
115+ static const EnumEntry<DescriptorRangeFlags> DescriptorRangeFlagNames[] = {
116+ {" DescriptorsVolatile" , DescriptorRangeFlags::DescriptorsVolatile},
117+ {" DataVolatile" , DescriptorRangeFlags::DataVolatile},
118+ {" DataStaticWhileSetAtExecute" ,
119+ DescriptorRangeFlags::DataStaticWhileSetAtExecute},
120+ {" DataStatic" , DescriptorRangeFlags::DataStatic},
121+ {" DescriptorsStaticKeepingBufferBoundsChecks" ,
122+ DescriptorRangeFlags::DescriptorsStaticKeepingBufferBoundsChecks},
123+ };
124+
125+ static raw_ostream &operator <<(raw_ostream &OS,
126+ const DescriptorRangeFlags &Flags) {
127+ printFlags (OS, Flags, ArrayRef (DescriptorRangeFlagNames));
131128
132129 return OS;
133130}
@@ -236,12 +233,13 @@ MDNode *MetadataBuilder::BuildRootConstants(const RootConstants &Constants) {
236233
237234MDNode *MetadataBuilder::BuildRootDescriptor (const RootDescriptor &Descriptor) {
238235 IRBuilder<> Builder (Ctx);
239- llvm::SmallString<7 > Name;
240- llvm::raw_svector_ostream OS (Name);
241- OS << " Root" << ClauseType (llvm::to_underlying (Descriptor.Type ));
242-
236+ std::optional<StringRef> TypeName =
237+ getEnumName (dxil::ResourceClass (llvm::to_underlying (Descriptor.Type )),
238+ ArrayRef (ResourceClassNames));
239+ assert (TypeName && " Provided an invalid Resource Class" );
240+ llvm::SmallString<7 > Name ({" Root" , *TypeName});
243241 Metadata *Operands[] = {
244- MDString::get (Ctx, OS. str () ),
242+ MDString::get (Ctx, Name ),
245243 ConstantAsMetadata::get (
246244 Builder.getInt32 (llvm::to_underlying (Descriptor.Visibility ))),
247245 ConstantAsMetadata::get (Builder.getInt32 (Descriptor.Reg .Number )),
@@ -277,19 +275,20 @@ MDNode *MetadataBuilder::BuildDescriptorTable(const DescriptorTable &Table) {
277275MDNode *MetadataBuilder::BuildDescriptorTableClause (
278276 const DescriptorTableClause &Clause) {
279277 IRBuilder<> Builder (Ctx);
280- std::string Name;
281- llvm::raw_string_ostream OS (Name);
282- OS << Clause.Type ;
283- return MDNode::get (
284- Ctx, {
285- MDString::get (Ctx, OS.str ()),
286- ConstantAsMetadata::get (Builder.getInt32 (Clause.NumDescriptors )),
287- ConstantAsMetadata::get (Builder.getInt32 (Clause.Reg .Number )),
288- ConstantAsMetadata::get (Builder.getInt32 (Clause.Space )),
289- ConstantAsMetadata::get (Builder.getInt32 (Clause.Offset )),
290- ConstantAsMetadata::get (
291- Builder.getInt32 (llvm::to_underlying (Clause.Flags ))),
292- });
278+ std::optional<StringRef> Name =
279+ getEnumName (dxil::ResourceClass (llvm::to_underlying (Clause.Type )),
280+ ArrayRef (ResourceClassNames));
281+ assert (Name && " Provided an invalid Resource Class" );
282+ Metadata *Operands[] = {
283+ MDString::get (Ctx, *Name),
284+ ConstantAsMetadata::get (Builder.getInt32 (Clause.NumDescriptors )),
285+ ConstantAsMetadata::get (Builder.getInt32 (Clause.Reg .Number )),
286+ ConstantAsMetadata::get (Builder.getInt32 (Clause.Space )),
287+ ConstantAsMetadata::get (Builder.getInt32 (Clause.Offset )),
288+ ConstantAsMetadata::get (
289+ Builder.getInt32 (llvm::to_underlying (Clause.Flags ))),
290+ };
291+ return MDNode::get (Ctx, Operands);
293292}
294293
295294MDNode *MetadataBuilder::BuildStaticSampler (const StaticSampler &Sampler) {
0 commit comments