|
15 | 15 | #include "llvm/ADT/bit.h" |
16 | 16 | #include "llvm/IR/IRBuilder.h" |
17 | 17 | #include "llvm/IR/Metadata.h" |
| 18 | +#include "llvm/Support/ScopedPrinter.h" |
18 | 19 |
|
19 | 20 | namespace llvm { |
20 | 21 | namespace hlsl { |
21 | 22 | namespace rootsig { |
22 | 23 |
|
23 | | -static raw_ostream &operator<<(raw_ostream &OS, const Register &Reg) { |
24 | | - switch (Reg.ViewType) { |
25 | | - case RegisterType::BReg: |
26 | | - OS << "b"; |
27 | | - break; |
28 | | - case RegisterType::TReg: |
29 | | - OS << "t"; |
30 | | - break; |
31 | | - case RegisterType::UReg: |
32 | | - OS << "u"; |
33 | | - break; |
34 | | - case RegisterType::SReg: |
35 | | - OS << "s"; |
36 | | - break; |
37 | | - } |
38 | | - OS << Reg.Number; |
39 | | - return OS; |
| 24 | +template <typename T> |
| 25 | +static StringRef getEnumName(const T Value, ArrayRef<EnumEntry<T>> Enums) { |
| 26 | + for (const auto &EnumItem : Enums) |
| 27 | + if (EnumItem.Value == Value) |
| 28 | + return EnumItem.Name; |
| 29 | + return ""; |
40 | 30 | } |
41 | 31 |
|
42 | | -static raw_ostream &operator<<(raw_ostream &OS, |
43 | | - const ShaderVisibility &Visibility) { |
44 | | - switch (Visibility) { |
45 | | - case ShaderVisibility::All: |
46 | | - OS << "All"; |
47 | | - break; |
48 | | - case ShaderVisibility::Vertex: |
49 | | - OS << "Vertex"; |
50 | | - break; |
51 | | - case ShaderVisibility::Hull: |
52 | | - OS << "Hull"; |
53 | | - break; |
54 | | - case ShaderVisibility::Domain: |
55 | | - OS << "Domain"; |
56 | | - break; |
57 | | - case ShaderVisibility::Geometry: |
58 | | - OS << "Geometry"; |
59 | | - break; |
60 | | - case ShaderVisibility::Pixel: |
61 | | - OS << "Pixel"; |
62 | | - break; |
63 | | - case ShaderVisibility::Amplification: |
64 | | - OS << "Amplification"; |
65 | | - break; |
66 | | - case ShaderVisibility::Mesh: |
67 | | - OS << "Mesh"; |
68 | | - break; |
69 | | - } |
| 32 | +template <typename T> |
| 33 | +static raw_ostream &printEnum(raw_ostream &OS, const T Value, |
| 34 | + ArrayRef<EnumEntry<T>> Enums) { |
| 35 | + OS << getEnumName(Value, Enums); |
70 | 36 |
|
71 | 37 | return OS; |
72 | 38 | } |
73 | 39 |
|
74 | | -static raw_ostream &operator<<(raw_ostream &OS, const ClauseType &Type) { |
75 | | - switch (Type) { |
76 | | - case ClauseType::CBuffer: |
77 | | - OS << "CBV"; |
78 | | - break; |
79 | | - case ClauseType::SRV: |
80 | | - OS << "SRV"; |
81 | | - break; |
82 | | - case ClauseType::UAV: |
83 | | - OS << "UAV"; |
84 | | - break; |
85 | | - case ClauseType::Sampler: |
86 | | - OS << "Sampler"; |
87 | | - break; |
88 | | - } |
89 | | - |
90 | | - return OS; |
91 | | -} |
92 | | - |
93 | | -static raw_ostream &operator<<(raw_ostream &OS, |
94 | | - const DescriptorRangeFlags &Flags) { |
| 40 | +template <typename T> |
| 41 | +static raw_ostream &printFlags(raw_ostream &OS, const T Value, |
| 42 | + ArrayRef<EnumEntry<T>> Flags) { |
95 | 43 | bool FlagSet = false; |
96 | | - unsigned Remaining = llvm::to_underlying(Flags); |
| 44 | + unsigned Remaining = llvm::to_underlying(Value); |
97 | 45 | while (Remaining) { |
98 | 46 | unsigned Bit = 1u << llvm::countr_zero(Remaining); |
99 | 47 | if (Remaining & Bit) { |
100 | 48 | if (FlagSet) |
101 | 49 | OS << " | "; |
102 | 50 |
|
103 | | - switch (static_cast<DescriptorRangeFlags>(Bit)) { |
104 | | - case DescriptorRangeFlags::DescriptorsVolatile: |
105 | | - OS << "DescriptorsVolatile"; |
106 | | - break; |
107 | | - case DescriptorRangeFlags::DataVolatile: |
108 | | - OS << "DataVolatile"; |
109 | | - break; |
110 | | - case DescriptorRangeFlags::DataStaticWhileSetAtExecute: |
111 | | - OS << "DataStaticWhileSetAtExecute"; |
112 | | - break; |
113 | | - case DescriptorRangeFlags::DataStatic: |
114 | | - OS << "DataStatic"; |
115 | | - break; |
116 | | - case DescriptorRangeFlags::DescriptorsStaticKeepingBufferBoundsChecks: |
117 | | - OS << "DescriptorsStaticKeepingBufferBoundsChecks"; |
118 | | - break; |
119 | | - default: |
| 51 | + bool Found = false; |
| 52 | + for (const auto &FlagItem : Flags) |
| 53 | + if (FlagItem.Value == T(Bit)) { |
| 54 | + OS << FlagItem.Name; |
| 55 | + Found = true; |
| 56 | + break; |
| 57 | + } |
| 58 | + if (!Found) |
120 | 59 | OS << "invalid: " << Bit; |
121 | | - break; |
122 | | - } |
123 | | - |
124 | 60 | FlagSet = true; |
125 | 61 | } |
126 | 62 | Remaining &= ~Bit; |
127 | 63 | } |
128 | 64 |
|
129 | 65 | if (!FlagSet) |
130 | 66 | OS << "None"; |
| 67 | + return OS; |
| 68 | +} |
| 69 | + |
| 70 | +static const EnumEntry<RegisterType> RegisterNames[] = { |
| 71 | + {"b", RegisterType::BReg}, |
| 72 | + {"t", RegisterType::TReg}, |
| 73 | + {"u", RegisterType::UReg}, |
| 74 | + {"s", RegisterType::SReg}, |
| 75 | +}; |
| 76 | + |
| 77 | +static raw_ostream &operator<<(raw_ostream &OS, const Register &Reg) { |
| 78 | + printEnum(OS, Reg.ViewType, ArrayRef(RegisterNames)); |
| 79 | + OS << Reg.Number; |
| 80 | + |
| 81 | + return OS; |
| 82 | +} |
| 83 | + |
| 84 | +static const EnumEntry<ShaderVisibility> VisibilityNames[] = { |
| 85 | + {"All", ShaderVisibility::All}, |
| 86 | + {"Vertex", ShaderVisibility::Vertex}, |
| 87 | + {"Hull", ShaderVisibility::Hull}, |
| 88 | + {"Domain", ShaderVisibility::Domain}, |
| 89 | + {"Geometry", ShaderVisibility::Geometry}, |
| 90 | + {"Pixel", ShaderVisibility::Pixel}, |
| 91 | + {"Amplification", ShaderVisibility::Amplification}, |
| 92 | + {"Mesh", ShaderVisibility::Mesh}, |
| 93 | +}; |
| 94 | + |
| 95 | +static raw_ostream &operator<<(raw_ostream &OS, |
| 96 | + const ShaderVisibility &Visibility) { |
| 97 | + printEnum(OS, Visibility, ArrayRef(VisibilityNames)); |
| 98 | + |
| 99 | + return OS; |
| 100 | +} |
| 101 | + |
| 102 | +static const EnumEntry<dxil::ResourceClass> ResourceClassNames[] = { |
| 103 | + {"CBV", dxil::ResourceClass::CBuffer}, |
| 104 | + {"SRV", dxil::ResourceClass::SRV}, |
| 105 | + {"UAV", dxil::ResourceClass::UAV}, |
| 106 | + {"Sampler", dxil::ResourceClass::Sampler}, |
| 107 | +}; |
| 108 | + |
| 109 | +static raw_ostream &operator<<(raw_ostream &OS, const ClauseType &Type) { |
| 110 | + printEnum(OS, dxil::ResourceClass(llvm::to_underlying(Type)), |
| 111 | + ArrayRef(ResourceClassNames)); |
| 112 | + |
| 113 | + return OS; |
| 114 | +} |
| 115 | + |
| 116 | +static const EnumEntry<DescriptorRangeFlags> DescriptorRangeFlagNames[] = { |
| 117 | + {"DescriptorsVolatile", DescriptorRangeFlags::DescriptorsVolatile}, |
| 118 | + {"DataVolatile", DescriptorRangeFlags::DataVolatile}, |
| 119 | + {"DataStaticWhileSetAtExecute", |
| 120 | + DescriptorRangeFlags::DataStaticWhileSetAtExecute}, |
| 121 | + {"DataStatic", DescriptorRangeFlags::DataStatic}, |
| 122 | + {"DescriptorsStaticKeepingBufferBoundsChecks", |
| 123 | + DescriptorRangeFlags::DescriptorsStaticKeepingBufferBoundsChecks}, |
| 124 | +}; |
| 125 | + |
| 126 | +static raw_ostream &operator<<(raw_ostream &OS, |
| 127 | + const DescriptorRangeFlags &Flags) { |
| 128 | + printFlags(OS, Flags, ArrayRef(DescriptorRangeFlagNames)); |
131 | 129 |
|
132 | 130 | return OS; |
133 | 131 | } |
|
0 commit comments