|
15 | 15 | #include "llvm/ADT/bit.h"
|
16 | 16 | #include "llvm/IR/IRBuilder.h"
|
17 | 17 | #include "llvm/IR/Metadata.h"
|
| 18 | +#include "llvm/Support/ScopedPrinter.h" |
18 | 19 |
|
19 | 20 | namespace llvm {
|
20 | 21 | namespace hlsl {
|
21 | 22 | namespace rootsig {
|
22 | 23 |
|
23 |
| -static raw_ostream &operator<<(raw_ostream &OS, const Register &Reg) { |
24 |
| - switch (Reg.ViewType) { |
25 |
| - case RegisterType::BReg: |
26 |
| - OS << "b"; |
27 |
| - break; |
28 |
| - case RegisterType::TReg: |
29 |
| - OS << "t"; |
30 |
| - break; |
31 |
| - case RegisterType::UReg: |
32 |
| - OS << "u"; |
33 |
| - break; |
34 |
| - case RegisterType::SReg: |
35 |
| - OS << "s"; |
36 |
| - break; |
37 |
| - } |
38 |
| - OS << Reg.Number; |
39 |
| - return OS; |
| 24 | +template <typename T> |
| 25 | +static StringRef getEnumName(const T Value, ArrayRef<EnumEntry<T>> Enums) { |
| 26 | + for (const auto &EnumItem : Enums) |
| 27 | + if (EnumItem.Value == Value) |
| 28 | + return EnumItem.Name; |
| 29 | + return ""; |
40 | 30 | }
|
41 | 31 |
|
42 |
| -static raw_ostream &operator<<(raw_ostream &OS, |
43 |
| - const ShaderVisibility &Visibility) { |
44 |
| - switch (Visibility) { |
45 |
| - case ShaderVisibility::All: |
46 |
| - OS << "All"; |
47 |
| - break; |
48 |
| - case ShaderVisibility::Vertex: |
49 |
| - OS << "Vertex"; |
50 |
| - break; |
51 |
| - case ShaderVisibility::Hull: |
52 |
| - OS << "Hull"; |
53 |
| - break; |
54 |
| - case ShaderVisibility::Domain: |
55 |
| - OS << "Domain"; |
56 |
| - break; |
57 |
| - case ShaderVisibility::Geometry: |
58 |
| - OS << "Geometry"; |
59 |
| - break; |
60 |
| - case ShaderVisibility::Pixel: |
61 |
| - OS << "Pixel"; |
62 |
| - break; |
63 |
| - case ShaderVisibility::Amplification: |
64 |
| - OS << "Amplification"; |
65 |
| - break; |
66 |
| - case ShaderVisibility::Mesh: |
67 |
| - OS << "Mesh"; |
68 |
| - break; |
69 |
| - } |
| 32 | +template <typename T> |
| 33 | +static raw_ostream &printEnum(raw_ostream &OS, const T Value, |
| 34 | + ArrayRef<EnumEntry<T>> Enums) { |
| 35 | + OS << getEnumName(Value, Enums); |
70 | 36 |
|
71 | 37 | return OS;
|
72 | 38 | }
|
73 | 39 |
|
74 |
| -static raw_ostream &operator<<(raw_ostream &OS, const ClauseType &Type) { |
75 |
| - switch (Type) { |
76 |
| - case ClauseType::CBuffer: |
77 |
| - OS << "CBV"; |
78 |
| - break; |
79 |
| - case ClauseType::SRV: |
80 |
| - OS << "SRV"; |
81 |
| - break; |
82 |
| - case ClauseType::UAV: |
83 |
| - OS << "UAV"; |
84 |
| - break; |
85 |
| - case ClauseType::Sampler: |
86 |
| - OS << "Sampler"; |
87 |
| - break; |
88 |
| - } |
89 |
| - |
90 |
| - return OS; |
91 |
| -} |
92 |
| - |
93 |
| -static raw_ostream &operator<<(raw_ostream &OS, |
94 |
| - const DescriptorRangeFlags &Flags) { |
| 40 | +template <typename T> |
| 41 | +static raw_ostream &printFlags(raw_ostream &OS, const T Value, |
| 42 | + ArrayRef<EnumEntry<T>> Flags) { |
95 | 43 | bool FlagSet = false;
|
96 |
| - unsigned Remaining = llvm::to_underlying(Flags); |
| 44 | + unsigned Remaining = llvm::to_underlying(Value); |
97 | 45 | while (Remaining) {
|
98 | 46 | unsigned Bit = 1u << llvm::countr_zero(Remaining);
|
99 | 47 | if (Remaining & Bit) {
|
100 | 48 | if (FlagSet)
|
101 | 49 | OS << " | ";
|
102 | 50 |
|
103 |
| - switch (static_cast<DescriptorRangeFlags>(Bit)) { |
104 |
| - case DescriptorRangeFlags::DescriptorsVolatile: |
105 |
| - OS << "DescriptorsVolatile"; |
106 |
| - break; |
107 |
| - case DescriptorRangeFlags::DataVolatile: |
108 |
| - OS << "DataVolatile"; |
109 |
| - break; |
110 |
| - case DescriptorRangeFlags::DataStaticWhileSetAtExecute: |
111 |
| - OS << "DataStaticWhileSetAtExecute"; |
112 |
| - break; |
113 |
| - case DescriptorRangeFlags::DataStatic: |
114 |
| - OS << "DataStatic"; |
115 |
| - break; |
116 |
| - case DescriptorRangeFlags::DescriptorsStaticKeepingBufferBoundsChecks: |
117 |
| - OS << "DescriptorsStaticKeepingBufferBoundsChecks"; |
118 |
| - break; |
119 |
| - default: |
| 51 | + bool Found = false; |
| 52 | + for (const auto &FlagItem : Flags) |
| 53 | + if (FlagItem.Value == T(Bit)) { |
| 54 | + OS << FlagItem.Name; |
| 55 | + Found = true; |
| 56 | + break; |
| 57 | + } |
| 58 | + if (!Found) |
120 | 59 | OS << "invalid: " << Bit;
|
121 |
| - break; |
122 |
| - } |
123 |
| - |
124 | 60 | FlagSet = true;
|
125 | 61 | }
|
126 | 62 | Remaining &= ~Bit;
|
127 | 63 | }
|
128 | 64 |
|
129 | 65 | if (!FlagSet)
|
130 | 66 | OS << "None";
|
| 67 | + return OS; |
| 68 | +} |
| 69 | + |
| 70 | +static const EnumEntry<RegisterType> RegisterNames[] = { |
| 71 | + {"b", RegisterType::BReg}, |
| 72 | + {"t", RegisterType::TReg}, |
| 73 | + {"u", RegisterType::UReg}, |
| 74 | + {"s", RegisterType::SReg}, |
| 75 | +}; |
| 76 | + |
| 77 | +static raw_ostream &operator<<(raw_ostream &OS, const Register &Reg) { |
| 78 | + printEnum(OS, Reg.ViewType, ArrayRef(RegisterNames)); |
| 79 | + OS << Reg.Number; |
| 80 | + |
| 81 | + return OS; |
| 82 | +} |
| 83 | + |
| 84 | +static const EnumEntry<ShaderVisibility> VisibilityNames[] = { |
| 85 | + {"All", ShaderVisibility::All}, |
| 86 | + {"Vertex", ShaderVisibility::Vertex}, |
| 87 | + {"Hull", ShaderVisibility::Hull}, |
| 88 | + {"Domain", ShaderVisibility::Domain}, |
| 89 | + {"Geometry", ShaderVisibility::Geometry}, |
| 90 | + {"Pixel", ShaderVisibility::Pixel}, |
| 91 | + {"Amplification", ShaderVisibility::Amplification}, |
| 92 | + {"Mesh", ShaderVisibility::Mesh}, |
| 93 | +}; |
| 94 | + |
| 95 | +static raw_ostream &operator<<(raw_ostream &OS, |
| 96 | + const ShaderVisibility &Visibility) { |
| 97 | + printEnum(OS, Visibility, ArrayRef(VisibilityNames)); |
| 98 | + |
| 99 | + return OS; |
| 100 | +} |
| 101 | + |
| 102 | +static const EnumEntry<dxil::ResourceClass> ResourceClassNames[] = { |
| 103 | + {"CBV", dxil::ResourceClass::CBuffer}, |
| 104 | + {"SRV", dxil::ResourceClass::SRV}, |
| 105 | + {"UAV", dxil::ResourceClass::UAV}, |
| 106 | + {"Sampler", dxil::ResourceClass::Sampler}, |
| 107 | +}; |
| 108 | + |
| 109 | +static raw_ostream &operator<<(raw_ostream &OS, const ClauseType &Type) { |
| 110 | + printEnum(OS, dxil::ResourceClass(llvm::to_underlying(Type)), |
| 111 | + ArrayRef(ResourceClassNames)); |
| 112 | + |
| 113 | + return OS; |
| 114 | +} |
| 115 | + |
| 116 | +static const EnumEntry<DescriptorRangeFlags> DescriptorRangeFlagNames[] = { |
| 117 | + {"DescriptorsVolatile", DescriptorRangeFlags::DescriptorsVolatile}, |
| 118 | + {"DataVolatile", DescriptorRangeFlags::DataVolatile}, |
| 119 | + {"DataStaticWhileSetAtExecute", |
| 120 | + DescriptorRangeFlags::DataStaticWhileSetAtExecute}, |
| 121 | + {"DataStatic", DescriptorRangeFlags::DataStatic}, |
| 122 | + {"DescriptorsStaticKeepingBufferBoundsChecks", |
| 123 | + DescriptorRangeFlags::DescriptorsStaticKeepingBufferBoundsChecks}, |
| 124 | +}; |
| 125 | + |
| 126 | +static raw_ostream &operator<<(raw_ostream &OS, |
| 127 | + const DescriptorRangeFlags &Flags) { |
| 128 | + printFlags(OS, Flags, ArrayRef(DescriptorRangeFlagNames)); |
131 | 129 |
|
132 | 130 | return OS;
|
133 | 131 | }
|
|
0 commit comments