Skip to content
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
223 changes: 156 additions & 67 deletions llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,10 @@ using namespace llvm;
using namespace llvm::dxil;

namespace {
/// A simple Wrapper DiagnosticInfo that generates Module-level diagnostic
/// for TranslateMetadata pass
class DiagnosticInfoTranslateMD : public DiagnosticInfo {

/// A simple wrapper of DiagnosticInfo that generates module-level diagnostic
/// for the DXILValidateMetadata pass
class DiagnosticInfoValidateMD : public DiagnosticInfo {
private:
const Twine &Msg;
const Module &Mod;
Expand All @@ -47,16 +48,26 @@ class DiagnosticInfoTranslateMD : public DiagnosticInfo {
/// \p M is the module for which the diagnostic is being emitted. \p Msg is
/// the message to show. Note that this class does not copy this message, so
/// this reference must be valid for the whole life time of the diagnostic.
DiagnosticInfoTranslateMD(const Module &M,
const Twine &Msg LLVM_LIFETIME_BOUND,
DiagnosticSeverity Severity = DS_Error)
DiagnosticInfoValidateMD(const Module &M,
const Twine &Msg LLVM_LIFETIME_BOUND,
DiagnosticSeverity Severity = DS_Error)
: DiagnosticInfo(DK_Unsupported, Severity), Msg(Msg), Mod(M) {}

void print(DiagnosticPrinter &DP) const override {
DP << Mod.getName() << ": " << Msg << '\n';
}
};

static void reportError(Module &M, Twine Message,
DiagnosticSeverity Severity = DS_Error) {
M.getContext().diagnose(DiagnosticInfoValidateMD(M, Message, Severity));
}

static void reportLoopError(Module &M, Twine Message,
DiagnosticSeverity Severity = DS_Error) {
reportError(M, Twine("Invalid \"llvm.loop\" metadata: ") + Message, Severity);
}

enum class EntryPropsTag {
ShaderFlags = 0,
GSState,
Expand Down Expand Up @@ -314,25 +325,122 @@ static void translateBranchMetadata(Module &M, Instruction *BBTerminatorInst) {
BBTerminatorInst->setMetadata("hlsl.controlflow.hint", nullptr);
}

static std::array<unsigned, 6> getCompatibleInstructionMDs(llvm::Module &M) {
// Determines if the metadata node will be compatible with DXIL's loop metadata
// representation.
//
// Reports an error for compatible metadata that is ill-formed.
static bool isLoopMDCompatible(Module &M, Metadata *MD) {
// DXIL only accepts the following loop hints:
std::array<StringLiteral, 3> ValidHintNames = {"llvm.loop.unroll.count",
"llvm.loop.unroll.disable",
"llvm.loop.unroll.full"};

MDNode *HintMD = dyn_cast<MDNode>(MD);
if (!HintMD || HintMD->getNumOperands() == 0)
return false;

auto *HintStr = dyn_cast<MDString>(HintMD->getOperand(0));
if (!HintStr)
return false;

if (!llvm::is_contained(ValidHintNames, HintStr->getString()))
return false;

auto ValidCountNode = [](MDNode *CountMD) -> bool {
if (CountMD->getNumOperands() == 2)
if (auto *Count = dyn_cast<ConstantAsMetadata>(CountMD->getOperand(1)))
if (isa<ConstantInt>(Count->getValue()))
return true;
return false;
};

if (HintStr->getString() == "llvm.loop.unroll.count") {
if (!ValidCountNode(HintMD)) {
reportLoopError(M, "\"llvm.loop.unroll.count\" must have 2 operands and "
"the second must be a constant integer");
return false;
}
} else if (HintMD->getNumOperands() != 1) {
reportLoopError(
M, "\"llvm.loop.unroll.disable\" and \"llvm.loop.unroll.full\" "
"must be provided as a single operand");
return false;
}

return true;
}

static void translateLoopMetadata(Module &M, Instruction *I, MDNode *BaseMD) {
// A distinct node has the self-referential form: !0 = !{ !0, ... }
auto IsDistinctNode = [](MDNode *Node) -> bool {
return Node && Node->getNumOperands() != 0 && Node == Node->getOperand(0);
};

// Set metadata to null to remove empty/ill-formed metadata from instruction
if (BaseMD->getNumOperands() == 0 || !IsDistinctNode(BaseMD))
return I->setMetadata("llvm.loop", nullptr);

// It is valid to have a chain of self-refential loop metadata nodes, as
// below. We will collapse these into just one when we reconstruct the
// metadata.
//
// Eg:
// !0 = !{!0, !1}
// !1 = !{!1, !2}
// !2 = !{!"llvm.loop.unroll.disable"}
//
// So, traverse down a potential self-referential chain
while (1 < BaseMD->getNumOperands() &&
IsDistinctNode(dyn_cast<MDNode>(BaseMD->getOperand(1))))
BaseMD = dyn_cast<MDNode>(BaseMD->getOperand(1));

// To reconstruct a distinct node we create a temporary node that we will
// then update to create a self-reference.
llvm::TempMDTuple TempNode = llvm::MDNode::getTemporary(M.getContext(), {});
SmallVector<Metadata *> CompatibleOperands = {TempNode.get()};

// Iterate and reconstruct the metadata nodes that contains any hints,
// stripping any unrecognized metadata.
ArrayRef<MDOperand> Operands = BaseMD->operands();
for (auto &Op : Operands.drop_front())
if (isLoopMDCompatible(M, Op.get()))
CompatibleOperands.push_back(Op.get());

if (2 < CompatibleOperands.size())
reportLoopError(M, "Provided conflicting hints");

MDNode *CompatibleLoopMD = MDNode::get(M.getContext(), CompatibleOperands);
TempNode->replaceAllUsesWith(CompatibleLoopMD);

I->setMetadata("llvm.loop", CompatibleLoopMD);
}

using InstructionMDList = std::array<unsigned, 7>;

static InstructionMDList getCompatibleInstructionMDs(llvm::Module &M) {
return {
M.getMDKindID("dx.nonuniform"), M.getMDKindID("dx.controlflow.hints"),
M.getMDKindID("dx.precise"), llvm::LLVMContext::MD_range,
llvm::LLVMContext::MD_alias_scope, llvm::LLVMContext::MD_noalias};
llvm::LLVMContext::MD_alias_scope, llvm::LLVMContext::MD_noalias,
M.getMDKindID("llvm.loop")};
}

static void translateInstructionMetadata(Module &M) {
// construct allowlist of valid metadata node kinds
std::array<unsigned, 6> DXILCompatibleMDs = getCompatibleInstructionMDs(M);
InstructionMDList DXILCompatibleMDs = getCompatibleInstructionMDs(M);
unsigned char MDLoopKind = M.getContext().getMDKindID("llvm.loop");

for (Function &F : M) {
for (BasicBlock &BB : F) {
// This needs to be done first so that "hlsl.controlflow.hints" isn't
// removed in the whitelist below
// removed in the allow-list below
if (auto *I = BB.getTerminator())
translateBranchMetadata(M, I);

for (auto &I : make_early_inc_range(BB)) {
if (isa<BranchInst>(I))
if (MDNode *LoopMD = I.getMetadata(MDLoopKind))
translateLoopMetadata(M, &I, LoopMD);
I.dropUnknownNonDebugMetadata(DXILCompatibleMDs);
}
}
Expand Down Expand Up @@ -389,31 +497,23 @@ static void translateGlobalMetadata(Module &M, DXILResourceMap &DRM,
uint64_t CombinedMask = ShaderFlags.getCombinedFlags();
EntryFnMDNodes.emplace_back(
emitTopLevelLibraryNode(M, ResourceMD, CombinedMask));
} else if (MMDI.EntryPropertyVec.size() > 1) {
M.getContext().diagnose(DiagnosticInfoTranslateMD(
M, "Non-library shader: One and only one entry expected"));
}
} else if (1 < MMDI.EntryPropertyVec.size())
reportError(M, "Non-library shader: One and only one entry expected");

for (const EntryProperties &EntryProp : MMDI.EntryPropertyVec) {
const ComputedShaderFlags &EntrySFMask =
ShaderFlags.getFunctionFlags(EntryProp.Entry);

// If ShaderProfile is Library, mask is already consolidated in the
// top-level library node. Hence it is not emitted.
uint64_t EntryShaderFlags = 0;
if (MMDI.ShaderProfile != Triple::EnvironmentType::Library) {
EntryShaderFlags = EntrySFMask;
if (EntryProp.ShaderStage != MMDI.ShaderProfile) {
M.getContext().diagnose(DiagnosticInfoTranslateMD(
M,
"Shader stage '" +
Twine(getShortShaderStage(EntryProp.ShaderStage) +
"' for entry '" + Twine(EntryProp.Entry->getName()) +
"' different from specified target profile '" +
Twine(Triple::getEnvironmentTypeName(MMDI.ShaderProfile) +
"'"))));
}
EntryShaderFlags = ShaderFlags.getFunctionFlags(EntryProp.Entry);
if (EntryProp.ShaderStage != MMDI.ShaderProfile)
reportError(
M, "Shader stage '" +
Twine(getShortShaderStage(EntryProp.ShaderStage)) +
"' for entry '" + Twine(EntryProp.Entry->getName()) +
"' different from specified target profile '" +
Twine(Triple::getEnvironmentTypeName(MMDI.ShaderProfile) +
"'"));
}

EntryFnMDNodes.emplace_back(emitEntryMD(EntryProp, Signatures, ResourceMD,
EntryShaderFlags,
MMDI.ShaderProfile));
Expand Down Expand Up @@ -454,45 +554,34 @@ PreservedAnalyses DXILTranslateMetadata::run(Module &M,
return PreservedAnalyses::all();
}

namespace {
class DXILTranslateMetadataLegacy : public ModulePass {
public:
static char ID; // Pass identification, replacement for typeid
explicit DXILTranslateMetadataLegacy() : ModulePass(ID) {}

StringRef getPassName() const override { return "DXIL Translate Metadata"; }

void getAnalysisUsage(AnalysisUsage &AU) const override {
AU.addRequired<DXILResourceTypeWrapperPass>();
AU.addRequired<DXILResourceWrapperPass>();
AU.addRequired<ShaderFlagsAnalysisWrapper>();
AU.addRequired<DXILMetadataAnalysisWrapperPass>();
AU.addRequired<RootSignatureAnalysisWrapper>();

AU.addPreserved<DXILMetadataAnalysisWrapperPass>();
AU.addPreserved<DXILResourceBindingWrapperPass>();
AU.addPreserved<DXILResourceWrapperPass>();
AU.addPreserved<RootSignatureAnalysisWrapper>();
AU.addPreserved<ShaderFlagsAnalysisWrapper>();
}
void DXILTranslateMetadataLegacy::getAnalysisUsage(AnalysisUsage &AU) const {
AU.addRequired<DXILResourceTypeWrapperPass>();
AU.addRequired<DXILResourceWrapperPass>();
AU.addRequired<ShaderFlagsAnalysisWrapper>();
AU.addRequired<DXILMetadataAnalysisWrapperPass>();
AU.addRequired<RootSignatureAnalysisWrapper>();

AU.addPreserved<DXILMetadataAnalysisWrapperPass>();
AU.addPreserved<DXILResourceBindingWrapperPass>();
AU.addPreserved<DXILResourceWrapperPass>();
AU.addPreserved<RootSignatureAnalysisWrapper>();
AU.addPreserved<ShaderFlagsAnalysisWrapper>();
}

bool runOnModule(Module &M) override {
DXILResourceMap &DRM =
getAnalysis<DXILResourceWrapperPass>().getResourceMap();
DXILResourceTypeMap &DRTM =
getAnalysis<DXILResourceTypeWrapperPass>().getResourceTypeMap();
const ModuleShaderFlags &ShaderFlags =
getAnalysis<ShaderFlagsAnalysisWrapper>().getShaderFlags();
dxil::ModuleMetadataInfo MMDI =
getAnalysis<DXILMetadataAnalysisWrapperPass>().getModuleMetadata();

translateGlobalMetadata(M, DRM, DRTM, ShaderFlags, MMDI);
translateInstructionMetadata(M);
return true;
}
};
bool DXILTranslateMetadataLegacy::runOnModule(Module &M) {
DXILResourceMap &DRM =
getAnalysis<DXILResourceWrapperPass>().getResourceMap();
DXILResourceTypeMap &DRTM =
getAnalysis<DXILResourceTypeWrapperPass>().getResourceTypeMap();
const ModuleShaderFlags &ShaderFlags =
getAnalysis<ShaderFlagsAnalysisWrapper>().getShaderFlags();
dxil::ModuleMetadataInfo MMDI =
getAnalysis<DXILMetadataAnalysisWrapperPass>().getModuleMetadata();

} // namespace
translateGlobalMetadata(M, DRM, DRTM, ShaderFlags, MMDI);
translateInstructionMetadata(M);
return true;
}

char DXILTranslateMetadataLegacy::ID = 0;

Expand Down
17 changes: 17 additions & 0 deletions llvm/lib/Target/DirectX/DXILTranslateMetadata.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#define LLVM_TARGET_DIRECTX_DXILTRANSLATEMETADATA_H

#include "llvm/IR/PassManager.h"
#include "llvm/Pass.h"

namespace llvm {

Expand All @@ -20,6 +21,22 @@ class DXILTranslateMetadata : public PassInfoMixin<DXILTranslateMetadata> {
PreservedAnalyses run(Module &M, ModuleAnalysisManager &);
};

/// Wrapper pass for the legacy pass manager.
///
/// This is required because the passes that will depend on this are codegen
/// passes which run through the legacy pass manager.
class DXILTranslateMetadataLegacy : public ModulePass {
public:
static char ID; // Pass identification, replacement for typeid
explicit DXILTranslateMetadataLegacy() : ModulePass(ID) {}

StringRef getPassName() const override { return "DXIL Translate Metadata"; }

void getAnalysisUsage(AnalysisUsage &AU) const override;

bool runOnModule(Module &M) override;
};

} // namespace llvm

#endif // LLVM_TARGET_DIRECTX_DXILTRANSLATEMETADATA_H
Loading