2626#include " llvm/Pass.h"
2727#include " llvm/Support/Error.h"
2828#include " llvm/Support/ErrorHandling.h"
29+ #include < cstdint>
2930#include < optional>
31+ #include < utility>
3032
3133using namespace llvm ;
3234using namespace llvm ::dxil;
@@ -37,20 +39,20 @@ static bool reportError(LLVMContext *Ctx, Twine Message,
3739 return true ;
3840}
3941
40- static bool parseRootFlags (LLVMContext *Ctx, ModuleRootSignature * MRS,
42+ static bool parseRootFlags (LLVMContext *Ctx, ModuleRootSignature & MRS,
4143 MDNode *RootFlagNode) {
4244
4345 if (RootFlagNode->getNumOperands () != 2 )
4446 return reportError (Ctx, " Invalid format for RootFlag Element" );
4547
4648 auto *Flag = mdconst::extract<ConstantInt>(RootFlagNode->getOperand (1 ));
47- MRS-> Flags = Flag->getZExtValue ();
49+ MRS. Flags = Flag->getZExtValue ();
4850
4951 return false ;
5052}
5153
5254static bool parseRootSignatureElement (LLVMContext *Ctx,
53- ModuleRootSignature * MRS,
55+ ModuleRootSignature & MRS,
5456 MDNode *Element) {
5557 MDString *ElementText = cast<MDString>(Element->getOperand (0 ));
5658 if (ElementText == nullptr )
@@ -73,8 +75,7 @@ static bool parseRootSignatureElement(LLVMContext *Ctx,
7375 llvm_unreachable (" Root signature element kind not expected." );
7476}
7577
76- static bool parse (LLVMContext *Ctx, ModuleRootSignature *MRS, NamedMDNode *Root,
77- const Function *EntryFunction) {
78+ static bool parse (LLVMContext *Ctx, ModuleRootSignature &MRS, MDNode *Node) {
7879 bool HasError = false ;
7980
8081 /* * Root Signature are specified as following in the metadata:
@@ -89,15 +90,46 @@ static bool parse(LLVMContext *Ctx, ModuleRootSignature *MRS, NamedMDNode *Root,
8990 signature pair.
9091 */
9192
92- for (const MDNode *Node : Root->operands ()) {
93- if (Node->getNumOperands () != 2 ) {
94- HasError = reportError (
95- Ctx, " Invalid format for Root Signature Definition. Pairs "
96- " of function, root signature expected." );
93+ // Get the Root Signature Description from the function signature pair.
94+
95+ // Loop through the Root Elements of the root signature.
96+ for (const auto &Operand : Node->operands ()) {
97+ MDNode *Element = dyn_cast<MDNode>(Operand);
98+ if (Element == nullptr )
99+ return reportError (Ctx, " Missing Root Element Metadata Node." );
100+
101+ HasError = HasError || parseRootSignatureElement (Ctx, MRS, Element);
102+ }
103+
104+ return HasError;
105+ }
106+
107+ static bool validate (LLVMContext *Ctx, const ModuleRootSignature &MRS) {
108+ if (!dxbc::RootSignatureValidations::isValidRootFlag (MRS.Flags )) {
109+ return reportError (Ctx, " Invalid Root Signature flag value" );
110+ }
111+ return false ;
112+ }
113+
114+ static SmallDenseMap<const Function *, ModuleRootSignature>
115+ analyzeModule (Module &M) {
116+
117+ LLVMContext *Ctx = &M.getContext ();
118+
119+ SmallDenseMap<const Function *, ModuleRootSignature> MRSMap;
120+
121+ NamedMDNode *RootSignatureNode = M.getNamedMetadata (" dx.rootsignatures" );
122+ if (RootSignatureNode == nullptr )
123+ return MRSMap;
124+
125+ for (const auto &RSDefNode : RootSignatureNode->operands ()) {
126+ if (RSDefNode->getNumOperands () != 2 ) {
127+ reportError (Ctx, " Invalid format for Root Signature Definition. Pairs "
128+ " of function, root signature expected." );
97129 continue ;
98130 }
99131
100- const MDOperand &FunctionPointerMdNode = Node ->getOperand (0 );
132+ const MDOperand &FunctionPointerMdNode = RSDefNode ->getOperand (0 );
101133 if (FunctionPointerMdNode == nullptr ) {
102134 // Function was pruned during compilation.
103135 continue ;
@@ -106,97 +138,76 @@ static bool parse(LLVMContext *Ctx, ModuleRootSignature *MRS, NamedMDNode *Root,
106138 ValueAsMetadata *VAM =
107139 llvm::dyn_cast<ValueAsMetadata>(FunctionPointerMdNode.get ());
108140 if (VAM == nullptr ) {
109- HasError =
110- reportError (Ctx, " First element of root signature is not a value" );
141+ reportError (Ctx, " First element of root signature is not a value" );
111142 continue ;
112143 }
113144
114145 Function *F = dyn_cast<Function>(VAM->getValue ());
115146 if (F == nullptr ) {
116- HasError =
117- reportError (Ctx, " First element of root signature is not a function" );
147+ reportError (Ctx, " First element of root signature is not a function" );
118148 continue ;
119149 }
120150
121- if (F != EntryFunction)
122- continue ;
151+ MDNode *RootElementListNode =
152+ dyn_cast<MDNode>(RSDefNode-> getOperand ( 1 ). get ()) ;
123153
124- // Get the Root Signature Description from the function signature pair.
125- MDNode *RS = dyn_cast<MDNode>(Node->getOperand (1 ).get ());
126-
127- if (RS == nullptr ) {
154+ if (RootElementListNode == nullptr ) {
128155 reportError (Ctx, " Missing Root Element List Metadata node." );
129- continue ;
130156 }
131157
132- // Loop through the Root Elements of the root signature.
133- for (const auto &Operand : RS->operands ()) {
134- MDNode *Element = dyn_cast<MDNode>(Operand);
135- if (Element == nullptr )
136- return reportError (Ctx, " Missing Root Element Metadata Node." );
158+ ModuleRootSignature MRS;
137159
138- HasError = HasError || parseRootSignatureElement (Ctx, MRS, Element);
160+ if (parse (Ctx, MRS, RootElementListNode) || validate (Ctx, MRS)) {
161+ return MRSMap;
139162 }
140- }
141- return HasError;
142- }
143163
144- static bool validate (LLVMContext *Ctx, ModuleRootSignature *MRS) {
145- if (!dxbc::RootSignatureValidations::isValidRootFlag (MRS->Flags )) {
146- return reportError (Ctx, " Invalid Root Signature flag value" );
164+ MRSMap.insert (std::make_pair (F, MRS));
147165 }
148- return false ;
149- }
150166
151- static const Function *getEntryFunction (Module &M, ModuleMetadataInfo MMI) {
152-
153- LLVMContext *Ctx = &M.getContext ();
154- if (MMI.EntryPropertyVec .size () != 1 ) {
155- reportError (Ctx, " More than one entry function defined." );
156- // needed to stop compilation
157- report_fatal_error (" Invalid Root Signature Definition" , false );
158- return nullptr ;
159- }
160- return MMI.EntryPropertyVec [0 ].Entry ;
167+ return MRSMap;
161168}
162169
163- std::optional<ModuleRootSignature>
164- ModuleRootSignature::analyzeModule (Module &M, const Function *F) {
165-
166- LLVMContext *Ctx = &M.getContext ();
170+ AnalysisKey RootSignatureAnalysis::Key;
167171
168- ModuleRootSignature MRS;
172+ SmallDenseMap<const Function *, ModuleRootSignature>
173+ RootSignatureAnalysis::run (Module &M, ModuleAnalysisManager &AM) {
174+ return analyzeModule (M);
175+ }
169176
170- NamedMDNode *RootSignatureNode = M.getNamedMetadata (" dx.rootsignatures" );
171- if (RootSignatureNode == nullptr )
172- return std::nullopt ;
177+ // ===----------------------------------------------------------------------===//
173178
174- if (parse (Ctx, &MRS, RootSignatureNode, F) || validate (Ctx, &MRS)) {
175- // needed to stop compilation
176- report_fatal_error (" Invalid Root Signature Definition" , false );
177- return std::nullopt ;
179+ static void printSpaces (raw_ostream &Stream, unsigned int Count) {
180+ for (unsigned int I = 0 ; I < Count; ++I) {
181+ Stream << ' ' ;
178182 }
179-
180- return MRS;
181183}
182184
183- AnalysisKey RootSignatureAnalysis::Key;
185+ PreservedAnalyses RootSignatureAnalysisPrinter::run (Module &M,
186+ ModuleAnalysisManager &AM) {
187+
188+ SmallDenseMap<const Function *, ModuleRootSignature> &MRSMap =
189+ AM.getResult <RootSignatureAnalysis>(M);
190+ OS << " Root Signature Definitions"
191+ << " \n " ;
192+ uint8_t Space = 0 ;
193+ for (const auto &P : MRSMap) {
194+ const auto &[Function, MRS] = P;
195+ OS << " Definition for '" << Function->getName () << " ':\n " ;
196+
197+ // start root signature header
198+ Space++;
199+ printSpaces (OS, Space);
200+ OS << " Flags: " << format_hex (MRS.Flags , 8 ) << " :\n " ;
201+ Space--;
202+ // end root signature header
203+ }
184204
185- std::optional<ModuleRootSignature>
186- RootSignatureAnalysis::run (Module &M, ModuleAnalysisManager &AM) {
187- ModuleMetadataInfo MMI = AM.getResult <DXILMetadataAnalysis>(M);
188- if (MMI.ShaderProfile == Triple::Library)
189- return std::nullopt ;
190- return ModuleRootSignature::analyzeModule (M, getEntryFunction (M, MMI));
205+ return PreservedAnalyses::all ();
191206}
192207
193208// ===----------------------------------------------------------------------===//
194209bool RootSignatureAnalysisWrapper::runOnModule (Module &M) {
195- dxil::ModuleMetadataInfo &MMI =
196- getAnalysis<DXILMetadataAnalysisWrapperPass>().getModuleMetadata ();
197- if (MMI.ShaderProfile == Triple::Library)
198- return false ;
199- MRS = ModuleRootSignature::analyzeModule (M, getEntryFunction (M, MMI));
210+ MRS = analyzeModule (M);
200211 return false ;
201212}
202213
@@ -208,8 +219,8 @@ void RootSignatureAnalysisWrapper::getAnalysisUsage(AnalysisUsage &AU) const {
208219char RootSignatureAnalysisWrapper::ID = 0 ;
209220
210221INITIALIZE_PASS_BEGIN (RootSignatureAnalysisWrapper,
211- " dx -root-signature-analysis" ,
222+ " dxil -root-signature-analysis" ,
212223 " DXIL Root Signature Analysis" , true , true )
213- INITIALIZE_PASS_DEPENDENCY(DXILMetadataAnalysisWrapperPass)
214- INITIALIZE_PASS_END(RootSignatureAnalysisWrapper, " dx -root-signature-analysis" ,
224+ INITIALIZE_PASS_END(RootSignatureAnalysisWrapper,
225+ " dxil -root-signature-analysis" ,
215226 " DXIL Root Signature Analysis" , true , true )
0 commit comments