|
9 | 9 | #include "DXILValidateMetadata.h" |
10 | 10 | #include "DXILTranslateMetadata.h" |
11 | 11 | #include "DirectX.h" |
| 12 | +#include "llvm/ADT/STLExtras.h" |
| 13 | +#include "llvm/ADT/StringRef.h" |
12 | 14 | #include "llvm/ADT/Twine.h" |
13 | 15 | #include "llvm/Analysis/DXILMetadataAnalysis.h" |
14 | 16 | #include "llvm/IR/BasicBlock.h" |
| 17 | +#include "llvm/IR/Constants.h" |
15 | 18 | #include "llvm/IR/DiagnosticInfo.h" |
16 | 19 | #include "llvm/IR/DiagnosticPrinter.h" |
| 20 | +#include "llvm/IR/LLVMContext.h" |
17 | 21 | #include "llvm/IR/Metadata.h" |
18 | 22 | #include "llvm/IR/Module.h" |
19 | 23 | #include "llvm/InitializePasses.h" |
| 24 | +#include "llvm/Support/Casting.h" |
20 | 25 | #include "llvm/Support/ErrorHandling.h" |
21 | 26 |
|
22 | 27 | using namespace llvm; |
@@ -50,10 +55,96 @@ static bool reportError(Module &M, Twine Message, |
50 | 55 | return true; |
51 | 56 | } |
52 | 57 |
|
| 58 | +static bool reportLoopError(Module &M, Twine Message, |
| 59 | + DiagnosticSeverity Severity = DS_Error) { |
| 60 | + return reportError(M, Twine("Invalid \"llvm.loop\" metadata: ") + Message, |
| 61 | + Severity); |
| 62 | +} |
| 63 | + |
53 | 64 | } // namespace |
54 | 65 |
|
| 66 | +static void validateLoopMetadata(Module &M, MDNode *LoopMD) { |
| 67 | + // DXIL only accepts the following loop hints: |
| 68 | + // llvm.loop.unroll.disable, llvm.loop.unroll.full, llvm.loop.unroll.count |
| 69 | + std::array<StringLiteral, 3> ValidHintNames = {"llvm.loop.unroll.count", |
| 70 | + "llvm.loop.unroll.disable", |
| 71 | + "llvm.loop.unroll.full"}; |
| 72 | + |
| 73 | + // llvm.loop metadata must have it's first operand be a self-reference, so we |
| 74 | + // require at least 1 operand. |
| 75 | + // |
| 76 | + // It only makes sense to specify up to 1 of the hints on a branch, so we can |
| 77 | + // have at most 2 operands. |
| 78 | + |
| 79 | + if (LoopMD->getNumOperands() != 1 && LoopMD->getNumOperands() != 2) { |
| 80 | + reportLoopError(M, "Requires exactly 1 or 2 operands"); |
| 81 | + return; |
| 82 | + } |
| 83 | + |
| 84 | + if (LoopMD != LoopMD->getOperand(0)) { |
| 85 | + reportLoopError(M, "First operand must be a self-reference"); |
| 86 | + return; |
| 87 | + } |
| 88 | + |
| 89 | + // A node only containing a self-reference is a valid use to denote a loop |
| 90 | + if (LoopMD->getNumOperands() == 1) |
| 91 | + return; |
| 92 | + |
| 93 | + LoopMD = dyn_cast<MDNode>(LoopMD->getOperand(1)); |
| 94 | + if (!LoopMD) { |
| 95 | + reportLoopError(M, "Second operand must be a metadata node"); |
| 96 | + return; |
| 97 | + } |
| 98 | + |
| 99 | + if (LoopMD->getNumOperands() != 1 && LoopMD->getNumOperands() != 2) { |
| 100 | + reportLoopError(M, "Requires exactly 1 or 2 operands"); |
| 101 | + return; |
| 102 | + } |
| 103 | + |
| 104 | + // It is valid to have a chain of self-referential loop metadata nodes so if |
| 105 | + // we have another self-reference, recurse. |
| 106 | + // |
| 107 | + // Eg: |
| 108 | + // !0 = !{!0, !1} |
| 109 | + // !1 = !{!1, !2} |
| 110 | + // !2 = !{"llvm.loop.unroll.disable"} |
| 111 | + if (LoopMD == LoopMD->getOperand(0)) |
| 112 | + return validateLoopMetadata(M, LoopMD); |
| 113 | + |
| 114 | + // Otherwise, we are at our base hint metadata node |
| 115 | + auto *HintStr = dyn_cast<MDString>(LoopMD->getOperand(0)); |
| 116 | + if (!HintStr || !llvm::is_contained(ValidHintNames, HintStr->getString())) { |
| 117 | + reportLoopError(M, |
| 118 | + "First operand must be a valid \"llvm.loop.unroll\" hint"); |
| 119 | + return; |
| 120 | + } |
| 121 | + |
| 122 | + // Ensure count node is a constant integer value |
| 123 | + auto ValidCountNode = [](MDNode *HintMD) -> bool { |
| 124 | + if (HintMD->getNumOperands() == 2) |
| 125 | + if (auto *CountMD = dyn_cast<ConstantAsMetadata>(HintMD->getOperand(1))) |
| 126 | + if (isa<ConstantInt>(CountMD->getValue())) |
| 127 | + return true; |
| 128 | + return false; |
| 129 | + }; |
| 130 | + |
| 131 | + if (HintStr->getString() == "llvm.loop.unroll.count" && |
| 132 | + !ValidCountNode(LoopMD)) { |
| 133 | + reportLoopError(M, "Second operand of \"llvm.loop.unroll.count\" " |
| 134 | + "must be a constant integer"); |
| 135 | + return; |
| 136 | + } |
| 137 | +} |
| 138 | + |
55 | 139 | static void validateInstructionMetadata(Module &M) { |
56 | | - llvm::errs() << "hello from new pass!\n"; |
| 140 | + unsigned char MDLoopKind = M.getContext().getMDKindID("llvm.loop"); |
| 141 | + |
| 142 | + for (Function &F : M) |
| 143 | + for (BasicBlock &BB : F) |
| 144 | + for (Instruction &I : BB) { |
| 145 | + if (MDNode *LoopMD = I.getMetadata(MDLoopKind)) |
| 146 | + validateLoopMetadata(M, LoopMD); |
| 147 | + } |
57 | 148 | } |
58 | 149 |
|
59 | 150 | static void validateGlobalMetadata(Module &M, |
@@ -104,6 +195,7 @@ class DXILValidateMetadataLegacy : public ModulePass { |
104 | 195 | dxil::ModuleMetadataInfo MMDI = |
105 | 196 | getAnalysis<DXILMetadataAnalysisWrapperPass>().getModuleMetadata(); |
106 | 197 | validateGlobalMetadata(M, MMDI); |
| 198 | + validateInstructionMetadata(M); |
107 | 199 | return true; |
108 | 200 | } |
109 | 201 | }; |
|
0 commit comments