1313
1414#include " DXILShaderFlags.h"
1515#include " DirectX.h"
16+ #include " llvm/Analysis/DXILMetadataAnalysis.h"
17+ #include " llvm/Analysis/DXILResource.h"
1618#include " llvm/IR/Instruction.h"
1719#include " llvm/IR/Module.h"
20+ #include " llvm/Support/DXILABI.h"
1821#include " llvm/Support/FormatVariadic.h"
1922
2023using namespace llvm ;
@@ -36,8 +39,42 @@ static void updateFlags(ComputedShaderFlags &Flags, const Instruction &I) {
3639 }
3740}
3841
39- ComputedShaderFlags ComputedShaderFlags::computeFlags (Module &M) {
42+ static void updateResourceFlags (ComputedShaderFlags &Flags, Module &M,
43+ ModuleAnalysisManager &AM) {
44+ const DXILResourceMap &DRM = AM.getResult <DXILResourceAnalysis>(M);
45+ if (DRM.empty ())
46+ return ;
47+
48+ const dxil::ModuleMetadataInfo &MMDI = AM.getResult <DXILMetadataAnalysis>(M);
49+ VersionTuple SM = MMDI.ShaderModelVersion ;
50+ Triple::EnvironmentType SP = MMDI.ShaderProfile ;
51+
52+ // StructuredBuffer
53+ // for (const ResourceInfo &RI : DRM.srvs()) {
54+ // if (RI.getResourceKind() ==
55+ // ResourceKind::RawBuffer) {
56+ // Flags.EnableRawAndStructuredBuffers = true;
57+ // Flags.ComputeShadersPlusRawAndStructuredBuffers = (SM.getMajor() == 4);
58+ // break;
59+ // }
60+ // }
61+
62+ // RWBuffer
63+ for (const ResourceInfo &RI : DRM.uavs ()) {
64+ if (RI.getResourceKind () == ResourceKind::TypedBuffer) {
65+ Flags.EnableRawAndStructuredBuffers = true ;
66+ Flags.ComputeShadersPlusRawAndStructuredBuffers =
67+ (SP == Triple::EnvironmentType::Compute && SM.getMajor () == 4 );
68+ break ;
69+ }
70+ }
71+ }
72+
73+ ComputedShaderFlags
74+ ComputedShaderFlags::computeFlags (Module &M, ModuleAnalysisManager &AM) {
4075 ComputedShaderFlags Flags;
76+ updateResourceFlags (Flags, M, AM);
77+
4178 for (const auto &F : M)
4279 for (const auto &BB : F)
4380 for (const auto &I : BB)
@@ -67,7 +104,7 @@ AnalysisKey ShaderFlagsAnalysis::Key;
67104
68105ComputedShaderFlags ShaderFlagsAnalysis::run (Module &M,
69106 ModuleAnalysisManager &AM) {
70- return ComputedShaderFlags::computeFlags (M);
107+ return ComputedShaderFlags::computeFlags (M, AM );
71108}
72109
73110PreservedAnalyses ShaderFlagsAnalysisPrinter::run (Module &M,
0 commit comments