-
Notifications
You must be signed in to change notification settings - Fork 15.4k
[DirectX] Infrastructure to collect shader flags for each function #112967
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 9 commits
3da01ee
397f70b
ae373d4
c02053a
47ab4c5
fa8ec60
a4f1e51
a6d84b2
31b0770
3427781
56af02a
f8e501f
c6b3390
70e46e0
07734f7
a0d2a31
2cee00a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -13,36 +13,87 @@ | |||||||||||||||||||||
|
|
||||||||||||||||||||||
| #include "DXILShaderFlags.h" | ||||||||||||||||||||||
| #include "DirectX.h" | ||||||||||||||||||||||
| #include "llvm/ADT/STLExtras.h" | ||||||||||||||||||||||
| #include "llvm/IR/DiagnosticInfo.h" | ||||||||||||||||||||||
| #include "llvm/IR/DiagnosticPrinter.h" | ||||||||||||||||||||||
| #include "llvm/IR/Instruction.h" | ||||||||||||||||||||||
| #include "llvm/IR/Module.h" | ||||||||||||||||||||||
| #include "llvm/Support/Error.h" | ||||||||||||||||||||||
| #include "llvm/Support/FormatVariadic.h" | ||||||||||||||||||||||
| #include "llvm/Support/raw_ostream.h" | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| using namespace llvm; | ||||||||||||||||||||||
| using namespace llvm::dxil; | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| static void updateFlags(ComputedShaderFlags &Flags, const Instruction &I) { | ||||||||||||||||||||||
| Type *Ty = I.getType(); | ||||||||||||||||||||||
| if (Ty->isDoubleTy()) { | ||||||||||||||||||||||
| Flags.Doubles = true; | ||||||||||||||||||||||
| namespace { | ||||||||||||||||||||||
| /// A simple Wrapper DiagnosticInfo that generates Module-level diagnostic | ||||||||||||||||||||||
| /// for Shader Flags Analysis pass | ||||||||||||||||||||||
| class DiagnosticInfoShaderFlags : public DiagnosticInfo { | ||||||||||||||||||||||
| private: | ||||||||||||||||||||||
| const Twine &Msg; | ||||||||||||||||||||||
|
||||||||||||||||||||||
| const Module &Mod; | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| public: | ||||||||||||||||||||||
| /// \p M is the module for which the diagnostic is being emitted. \p Msg is | ||||||||||||||||||||||
| /// the message to show. Note that this class does not copy this message, so | ||||||||||||||||||||||
| /// this reference must be valid for the whole life time of the diagnostic. | ||||||||||||||||||||||
| DiagnosticInfoShaderFlags(const Module &M, const Twine &Msg, | ||||||||||||||||||||||
| DiagnosticSeverity Severity = DS_Error) | ||||||||||||||||||||||
| : DiagnosticInfo(DK_Unsupported, Severity), Msg(Msg), Mod(M) {} | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| void print(DiagnosticPrinter &DP) const override { | ||||||||||||||||||||||
| DP << Mod.getName() << ": " << Msg << '\n'; | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
| }; | ||||||||||||||||||||||
| } // namespace | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| static void updateFlags(ComputedShaderFlags &CSF, const Instruction &I) { | ||||||||||||||||||||||
| if (!CSF.Doubles) { | ||||||||||||||||||||||
| CSF.Doubles = I.getType()->isDoubleTy(); | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
|
||||||||||||||||||||||
| if (!CSF.Doubles) { | |
| CSF.Doubles = I.getType()->isDoubleTy(); | |
| } | |
| if (!CSF.Doubles) | |
| CSF.Doubles = I.getType()->isDoubleTy(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Changed per LLVM coding style.
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| for (Value *Op : I.operands()) { | |
| CSF.Doubles |= Op->getType()->isDoubleTy(); | |
| } | |
| for (Value *Op : I.operands()) | |
| CSF.Doubles |= Op->getType()->isDoubleTy(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Changed per LLVM coding style.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we have an issue for this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we have an issue for this?
damyanp marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: ComputedShaderFlags has a default constructor to zero itself out, the empty initializer list is unnecessary.
| ComputedShaderFlags CSF{}; | |
| ComputedShaderFlags CSF; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: ComputedShaderFlags has a default constructor to zero itself out, the empty initializer list is unnecessary.
Changed.
damyanp marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
damyanp marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| if (Iter == FunctionFlags.end() || Iter->first != Func) { | |
| return createStringError("Shader Flags information of Function '" + | |
| Func->getName() + "' not found"); | |
| } | |
| if (Iter == FunctionFlags.end() || Iter->first != Func) | |
| return createStringError("Shader Flags information of Function '" + | |
| Func->getName() + "' not found"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The only way to have this fail is if we've invalidated the analysis (and failed to tell the pass manager) or we're trying to use it wrong. This should just be an assert.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Changed the if statement assert.
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You could store the const & to avoid the copy here.
damyanp marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
damyanp marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should be done through a call to llvm::handleAllErrors. Something like:
| if (Error E = SFMask.takeError()) { | |
| M.getContext().diagnose( | |
| DiagnosticInfoShaderFlags(M, toString(std::move(E)))); | |
| } | |
| if (!SFMask) | |
| return handleAllErrors(std::move(E), | |
| [&](std::unique_ptr<ErrorInfoBase> EIB) -> Error { | |
| M.getContext().diagnose(errorToDiagnosticInfo(EIB); | |
| return Error::success(); | |
| }); |
This handles arrays of errors so that your function can return more than one error.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should be done through a call to
llvm::handleAllErrors. Something like:This handles arrays of errors so that your function can return more than one error.
Deleted this error-handling code as a result of the assertion added in getShaderFlagsmask().
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -14,6 +14,7 @@ | |
| #ifndef LLVM_TARGET_DIRECTX_DXILSHADERFLAGS_H | ||
| #define LLVM_TARGET_DIRECTX_DXILSHADERFLAGS_H | ||
|
|
||
| #include "llvm/IR/Function.h" | ||
| #include "llvm/IR/PassManager.h" | ||
| #include "llvm/Pass.h" | ||
| #include "llvm/Support/Compiler.h" | ||
|
|
@@ -60,21 +61,47 @@ struct ComputedShaderFlags { | |
| return FeatureFlags; | ||
| } | ||
|
|
||
| static ComputedShaderFlags computeFlags(Module &M); | ||
| uint64_t getModuleFlags() const { | ||
| uint64_t ModuleFlags = 0; | ||
| #define DXIL_MODULE_FLAG(DxilModuleBit, FlagName, Str) \ | ||
| ModuleFlags |= FlagName ? getMask(DxilModuleBit) : 0ull; | ||
|
||
| #include "llvm/BinaryFormat/DXContainerConstants.def" | ||
| return ModuleFlags; | ||
| } | ||
|
|
||
| void print(raw_ostream &OS = dbgs()) const; | ||
| LLVM_DUMP_METHOD void dump() const { print(); } | ||
| }; | ||
|
|
||
| struct DXILModuleShaderFlagsInfo { | ||
|
||
| Expected<const ComputedShaderFlags &> | ||
| getShaderFlagsMask(const Function *Func) const; | ||
| bool hasShaderFlagsMask(const Function *Func) const; | ||
| const ComputedShaderFlags &getModuleFlags() const; | ||
| const SmallVector<std::pair<Function const *, ComputedShaderFlags>> & | ||
| getFunctionFlags() const; | ||
| void insertInorderFunctionFlags(const Function *, ComputedShaderFlags); | ||
|
|
||
| private: | ||
| // Shader Flag mask representing module-level properties. These are | ||
| // represented using the macro DXIL_MODULE_FLAG | ||
| ComputedShaderFlags ModuleFlags; | ||
damyanp marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| // Vector of Function-Shader Flag mask pairs representing properties of each | ||
| // of the functions in the module. Shader Flags of each function are those | ||
| // represented using the macro SHADER_FEATURE_FLAG. | ||
| SmallVector<std::pair<Function const *, ComputedShaderFlags>> FunctionFlags; | ||
| }; | ||
|
|
||
| class ShaderFlagsAnalysis : public AnalysisInfoMixin<ShaderFlagsAnalysis> { | ||
| friend AnalysisInfoMixin<ShaderFlagsAnalysis>; | ||
| static AnalysisKey Key; | ||
|
|
||
| public: | ||
| ShaderFlagsAnalysis() = default; | ||
|
|
||
| using Result = ComputedShaderFlags; | ||
| using Result = DXILModuleShaderFlagsInfo; | ||
|
|
||
| ComputedShaderFlags run(Module &M, ModuleAnalysisManager &AM); | ||
| DXILModuleShaderFlagsInfo run(Module &M, ModuleAnalysisManager &AM); | ||
| }; | ||
|
|
||
| /// Printer pass for ShaderFlagsAnalysis results. | ||
|
|
@@ -92,19 +119,16 @@ class ShaderFlagsAnalysisPrinter | |
| /// This is required because the passes that will depend on this are codegen | ||
| /// passes which run through the legacy pass manager. | ||
| class ShaderFlagsAnalysisWrapper : public ModulePass { | ||
| ComputedShaderFlags Flags; | ||
| DXILModuleShaderFlagsInfo MSFI; | ||
|
|
||
| public: | ||
| static char ID; | ||
|
|
||
| ShaderFlagsAnalysisWrapper() : ModulePass(ID) {} | ||
|
|
||
| const ComputedShaderFlags &getShaderFlags() { return Flags; } | ||
| const DXILModuleShaderFlagsInfo &getShaderFlags() { return MSFI; } | ||
|
|
||
| bool runOnModule(Module &M) override { | ||
| Flags = ComputedShaderFlags::computeFlags(M); | ||
| return false; | ||
| } | ||
| bool runOnModule(Module &M) override; | ||
|
|
||
| void getAnalysisUsage(AnalysisUsage &AU) const override { | ||
| AU.setPreservesAll(); | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -25,6 +25,7 @@ | |
| #include "llvm/IR/Module.h" | ||
| #include "llvm/InitializePasses.h" | ||
| #include "llvm/Pass.h" | ||
| #include "llvm/Support/Error.h" | ||
|
||
| #include "llvm/Support/ErrorHandling.h" | ||
| #include "llvm/Support/VersionTuple.h" | ||
| #include "llvm/TargetParser/Triple.h" | ||
|
|
@@ -286,11 +287,6 @@ static MDTuple *emitTopLevelLibraryNode(Module &M, MDNode *RMD, | |
| MDTuple *Properties = nullptr; | ||
| if (ShaderFlags != 0) { | ||
| SmallVector<Metadata *> MDVals; | ||
| // FIXME: ShaderFlagsAnalysis pass needs to collect and provide | ||
| // ShaderFlags for each entry function. Currently, ShaderFlags value | ||
| // provided by ShaderFlagsAnalysis pass is created by walking *all* the | ||
| // function instructions of the module. Is it is correct to use this value | ||
| // for metadata of the empty library entry? | ||
| MDVals.append( | ||
| getTagValueAsMetadata(EntryPropsTag::ShaderFlags, ShaderFlags, Ctx)); | ||
| Properties = MDNode::get(Ctx, MDVals); | ||
|
|
@@ -302,7 +298,7 @@ static MDTuple *emitTopLevelLibraryNode(Module &M, MDNode *RMD, | |
|
|
||
| static void translateMetadata(Module &M, const DXILResourceMap &DRM, | ||
| const Resources &MDResources, | ||
| const ComputedShaderFlags &ShaderFlags, | ||
| const DXILModuleShaderFlagsInfo &ShaderFlags, | ||
| const ModuleMetadataInfo &MMDI) { | ||
| LLVMContext &Ctx = M.getContext(); | ||
| IRBuilder<> IRB(Ctx); | ||
|
|
@@ -318,22 +314,38 @@ static void translateMetadata(Module &M, const DXILResourceMap &DRM, | |
| // See https://github.com/llvm/llvm-project/issues/57928 | ||
| MDTuple *Signatures = nullptr; | ||
|
|
||
| if (MMDI.ShaderProfile == Triple::EnvironmentType::Library) | ||
| if (MMDI.ShaderProfile == Triple::EnvironmentType::Library) { | ||
| // Create a consolidated shader flag mask of all functions in the library | ||
| // to be used as shader flags mask value associated with top-level library | ||
| // entry metadata. | ||
| uint64_t ConsolidatedMask = ShaderFlags.getModuleFlags(); | ||
| for (const auto &FunFlags : ShaderFlags.getFunctionFlags()) { | ||
| ConsolidatedMask |= FunFlags.second; | ||
| } | ||
damyanp marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| EntryFnMDNodes.emplace_back( | ||
| emitTopLevelLibraryNode(M, ResourceMD, ShaderFlags)); | ||
| else if (MMDI.EntryPropertyVec.size() > 1) { | ||
| emitTopLevelLibraryNode(M, ResourceMD, ConsolidatedMask)); | ||
| } else if (MMDI.EntryPropertyVec.size() > 1) { | ||
| M.getContext().diagnose(DiagnosticInfoTranslateMD( | ||
| M, "Non-library shader: One and only one entry expected")); | ||
| } | ||
|
|
||
| for (const EntryProperties &EntryProp : MMDI.EntryPropertyVec) { | ||
| // FIXME: ShaderFlagsAnalysis pass needs to collect and provide | ||
| // ShaderFlags for each entry function. For now, assume shader flags value | ||
| // of entry functions being compiled for lib_* shader profile viz., | ||
| // EntryPro.Entry is 0. | ||
| uint64_t EntryShaderFlags = | ||
| (MMDI.ShaderProfile == Triple::EnvironmentType::Library) ? 0 | ||
| : ShaderFlags; | ||
| Expected<const ComputedShaderFlags &> ECSF = | ||
| ShaderFlags.getShaderFlagsMask(EntryProp.Entry); | ||
| if (Error E = ECSF.takeError()) { | ||
| M.getContext().diagnose( | ||
| DiagnosticInfoTranslateMD(M, toString(std::move(E)))); | ||
| } | ||
|
|
||
| // If ShaderProfile is Library, mask is already consolidated in the | ||
| // top-level library node. Hence it is not emitted. | ||
| uint64_t EntryShaderFlags = 0; | ||
| if (MMDI.ShaderProfile != Triple::EnvironmentType::Library) { | ||
| // TODO: Create a consolidated shader flag mask of all the entry | ||
| // functions and its callees. The following is correct only if | ||
| // EntryProp.Entry has no call instructions. | ||
| EntryShaderFlags = *ECSF | ShaderFlags.getModuleFlags(); | ||
| } | ||
| if (MMDI.ShaderProfile != Triple::EnvironmentType::Library) { | ||
| if (EntryProp.ShaderStage != MMDI.ShaderProfile) { | ||
| M.getContext().diagnose(DiagnosticInfoTranslateMD( | ||
|
|
@@ -361,7 +373,7 @@ PreservedAnalyses DXILTranslateMetadata::run(Module &M, | |
| ModuleAnalysisManager &MAM) { | ||
| const DXILResourceMap &DRM = MAM.getResult<DXILResourceAnalysis>(M); | ||
| const dxil::Resources &MDResources = MAM.getResult<DXILResourceMDAnalysis>(M); | ||
| const ComputedShaderFlags &ShaderFlags = | ||
| const DXILModuleShaderFlagsInfo &ShaderFlags = | ||
| MAM.getResult<ShaderFlagsAnalysis>(M); | ||
| const dxil::ModuleMetadataInfo MMDI = MAM.getResult<DXILMetadataAnalysis>(M); | ||
|
|
||
|
|
@@ -393,7 +405,7 @@ class DXILTranslateMetadataLegacy : public ModulePass { | |
| getAnalysis<DXILResourceWrapperPass>().getResourceMap(); | ||
| const dxil::Resources &MDResources = | ||
| getAnalysis<DXILResourceMDWrapper>().getDXILResource(); | ||
| const ComputedShaderFlags &ShaderFlags = | ||
| const DXILModuleShaderFlagsInfo &ShaderFlags = | ||
| getAnalysis<ShaderFlagsAnalysisWrapper>().getShaderFlags(); | ||
| dxil::ModuleMetadataInfo MMDI = | ||
| getAnalysis<DXILMetadataAnalysisWrapperPass>().getModuleMetadata(); | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,19 @@ | ||||||||||||
| ; RUN: llc %s --filetype=obj -o - | obj2yaml | FileCheck %s | ||||||||||||
|
|
||||||||||||
| target triple = "dxil-pc-shadermodel6.7-library" | ||||||||||||
| define double @div(double %a, double %b) #0 { | ||||||||||||
| %res = fdiv double %a, %b | ||||||||||||
| ret double %res | ||||||||||||
| } | ||||||||||||
|
|
||||||||||||
| attributes #0 = { convergent norecurse nounwind "hlsl.export"} | ||||||||||||
|
|
||||||||||||
| ; CHECK: - Name: SFI0 | ||||||||||||
| ; CHECK-NEXT: Size: 8 | ||||||||||||
| ; CHECK-NEXT: Flags: | ||||||||||||
| ; CHECK-NEXT: Doubles: true | ||||||||||||
| ; CHECK-NOT: {{[A-Za-z]+: +true}} | ||||||||||||
| ; CHECK: DX11_1_DoubleExtensions: true | ||||||||||||
|
||||||||||||
| ; CHECK-NEXT: Doubles: true | |
| ; CHECK-NOT: {{[A-Za-z]+: +true}} | |
| ; CHECK: DX11_1_DoubleExtensions: true | |
| ; CHECK: Doubles: true | |
| ; CHECK: DX11_1_DoubleExtensions: true |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Changed.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we have an issue tracking this?