1515#include " llvm/ADT/SmallVector.h"
1616#include " llvm/ADT/StringExtras.h"
1717#include " llvm/ADT/StringRef.h"
18+ #include " llvm/Analysis/DXILMetadataAnalysis.h"
1819#include " llvm/BinaryFormat/DXContainer.h"
1920#include " llvm/CodeGen/Passes.h"
2021#include " llvm/IR/Constants.h"
@@ -57,6 +58,7 @@ class DXContainerGlobals : public llvm::ModulePass {
5758 void getAnalysisUsage (AnalysisUsage &AU) const override {
5859 AU.setPreservesAll ();
5960 AU.addRequired <ShaderFlagsAnalysisWrapper>();
61+ AU.addRequired <DXILMetadataAnalysisWrapperPass>();
6062 }
6163};
6264
@@ -143,23 +145,35 @@ void DXContainerGlobals::addPipelineStateValidationInfo(
143145 SmallString<256 > Data;
144146 raw_svector_ostream OS (Data);
145147 PSVRuntimeInfo PSV;
146- Triple TT (M.getTargetTriple ());
147148 PSV.BaseData .MinimumWaveLaneCount = 0 ;
148149 PSV.BaseData .MaximumWaveLaneCount = std::numeric_limits<uint32_t >::max ();
150+
151+ dxil::ModuleMetadataInfo &MMI =
152+ getAnalysis<DXILMetadataAnalysisWrapperPass>().getModuleMetadata ();
153+ assert (MMI.EntryPropertyVec .size () == 1 ||
154+ MMI.ShaderStage == Triple::Library);
149155 PSV.BaseData .ShaderStage =
150- static_cast <uint8_t >(TT. getEnvironment () - Triple::Pixel);
156+ static_cast <uint8_t >(MMI. ShaderStage - Triple::Pixel);
151157
152158 // Hardcoded values here to unblock loading the shader into D3D.
153159 //
154160 // TODO: Lots more stuff to do here!
155161 //
156162 // See issue https://github.com/llvm/llvm-project/issues/96674.
157- PSV.BaseData .NumThreadsX = 1 ;
158- PSV.BaseData .NumThreadsY = 1 ;
159- PSV.BaseData .NumThreadsZ = 1 ;
160- PSV.EntryName = " main" ;
163+ switch (MMI.ShaderStage ) {
164+ case Triple::Compute:
165+ PSV.BaseData .NumThreadsX = MMI.EntryPropertyVec [0 ].NumThreadsX ;
166+ PSV.BaseData .NumThreadsY = MMI.EntryPropertyVec [0 ].NumThreadsY ;
167+ PSV.BaseData .NumThreadsZ = MMI.EntryPropertyVec [0 ].NumThreadsZ ;
168+ break ;
169+ default :
170+ break ;
171+ }
172+
173+ if (MMI.ShaderStage != Triple::Library)
174+ PSV.EntryName = MMI.EntryPropertyVec [0 ].Entry ->getName ();
161175
162- PSV.finalize (TT. getEnvironment () );
176+ PSV.finalize (MMI. ShaderStage );
163177 PSV.write (OS);
164178 Constant *Constant =
165179 ConstantDataArray::getString (M.getContext (), Data, /* AddNull*/ false );
@@ -170,6 +184,7 @@ char DXContainerGlobals::ID = 0;
170184INITIALIZE_PASS_BEGIN (DXContainerGlobals, " dxil-globals" ,
171185 " DXContainer Global Emitter" , false , true )
172186INITIALIZE_PASS_DEPENDENCY(ShaderFlagsAnalysisWrapper)
187+ INITIALIZE_PASS_DEPENDENCY(DXILMetadataAnalysisWrapperPass)
173188INITIALIZE_PASS_END(DXContainerGlobals, " dxil-globals" ,
174189 " DXContainer Global Emitter" , false , true )
175190
0 commit comments