diff --git a/llvm/include/llvm/Transforms/Utils/FunctionComparator.h b/llvm/include/llvm/Transforms/Utils/FunctionComparator.h index 19c5f7449f23e..6035692c86218 100644 --- a/llvm/include/llvm/Transforms/Utils/FunctionComparator.h +++ b/llvm/include/llvm/Transforms/Utils/FunctionComparator.h @@ -328,8 +328,9 @@ class FunctionComparator { int cmpOrderings(AtomicOrdering L, AtomicOrdering R) const; int cmpInlineAsm(const InlineAsm *L, const InlineAsm *R) const; int cmpAttrs(const AttributeList L, const AttributeList R) const; - int cmpMDNode(const MDNode *L, const MDNode *R) const; - int cmpMetadata(const Metadata *L, const Metadata *R) const; + int cmpMDNode(const MDNode *L, const MDNode *R, bool InValueContext) const; + int cmpMetadata(const Metadata *L, const Metadata *R, + bool InValueContext) const; int cmpInstMetadata(Instruction const *L, Instruction const *R) const; int cmpOperandBundlesSchema(const CallBase &LCS, const CallBase &RCS) const; diff --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp index 5934f7adffb93..49d78f184191e 100644 --- a/llvm/lib/IR/Verifier.cpp +++ b/llvm/lib/IR/Verifier.cpp @@ -3543,6 +3543,31 @@ void Verifier::visitPHINode(PHINode &PN) { visitInstruction(PN); } +/// Returns true of \p MD is valid for as a metadata argument. It must be on of +/// the following +/// * a MDNode without cycles (expect self-reference in the first operand), +/// * MDString, +/// * ValueAsMetadata. +static bool isValidMetadataArgument(const Metadata *MD, + SmallPtrSetImpl &Seen) { + // Potential cycles are not allowed. + if (!Seen.insert(MD).second) + return false; + + if (auto *Node = dyn_cast(MD)) { + if (Node->getNumOperands() == 0) + return true; + ArrayRef Ops = Node->operands(); + if (Node->getOperand(0) == Node) + Ops = Ops.drop_front(); + return all_of(Ops, [&](const Metadata *MD) { + return MD && isValidMetadataArgument(MD, Seen); + }); + } + + return isa(MD) || isa(MD); +} + void Verifier::visitCallBase(CallBase &Call) { Check(Call.getCalledOperand()->getType()->isPointerTy(), "Called function must be a pointer!", Call); @@ -3562,6 +3587,19 @@ void Verifier::visitCallBase(CallBase &Call) { "Call parameter type does not match function signature!", Call.getArgOperand(i), FTy->getParamType(i), Call); + // Verify metadata arguments. + for (unsigned i = 0, e = FTy->getNumParams(); i != e; ++i) { + auto Arg = Call.getArgOperand(i); + if (!Arg->getType()->isMetadataTy() || isa(Call)) + continue; + SmallPtrSet Seen; + Check(isValidMetadataArgument(cast(Arg)->getMetadata(), + Seen), + "Function arguments must be string metadata, value-as-metadata or an " + "MDNode!", + Call); + } + AttributeList Attrs = Call.getAttributes(); Check(verifyAttributeCount(Attrs, Call.arg_size()), diff --git a/llvm/lib/Transforms/Utils/FunctionComparator.cpp b/llvm/lib/Transforms/Utils/FunctionComparator.cpp index 6d4026e8209de..b9d8762d9d4a6 100644 --- a/llvm/lib/Transforms/Utils/FunctionComparator.cpp +++ b/llvm/lib/Transforms/Utils/FunctionComparator.cpp @@ -185,21 +185,21 @@ int FunctionComparator::cmpAttrs(const AttributeList L, return 0; } -int FunctionComparator::cmpMetadata(const Metadata *L, - const Metadata *R) const { +int FunctionComparator::cmpMetadata(const Metadata *L, const Metadata *R, + bool InValueContext) const { // TODO: the following routine coerce the metadata contents into constants // or MDStrings before comparison. // It ignores any other cases, so that the metadata nodes are considered // equal even though this is not correct. // We should structurally compare the metadata nodes to be perfect here. + if (L == R) + return 0; + auto *MDStringL = dyn_cast(L); auto *MDStringR = dyn_cast(R); - if (MDStringL && MDStringR) { - if (MDStringL == MDStringR) - return 0; + if (MDStringL && MDStringR) return MDStringL->getString().compare(MDStringR->getString()); - } if (MDStringR) return -1; if (MDStringL) @@ -207,16 +207,31 @@ int FunctionComparator::cmpMetadata(const Metadata *L, auto *CL = dyn_cast(L); auto *CR = dyn_cast(R); - if (CL == CR) - return 0; - if (!CL) + if (CL && CR) + return cmpConstants(CL->getValue(), CR->getValue()); + if (CR) return -1; - if (!CR) + if (CL) return 1; - return cmpConstants(CL->getValue(), CR->getValue()); + + auto *NodeL = dyn_cast(L); + auto *NodeR = dyn_cast(R); + if (NodeL && NodeR) { + if (InValueContext) + return cmpMDNode(NodeL, NodeR, InValueContext); + } else { + if (NodeR) + return -1; + if (NodeL) + return 1; + } + assert(!InValueContext && + "all cases must be handled when comparing metadata arguments"); + return 0; } -int FunctionComparator::cmpMDNode(const MDNode *L, const MDNode *R) const { +int FunctionComparator::cmpMDNode(const MDNode *L, const MDNode *R, + bool InValueContext) const { if (L == R) return 0; if (!L) @@ -231,8 +246,20 @@ int FunctionComparator::cmpMDNode(const MDNode *L, const MDNode *R) const { // function semantically. if (int Res = cmpNumbers(L->getNumOperands(), R->getNumOperands())) return Res; - for (size_t I = 0; I < L->getNumOperands(); ++I) - if (int Res = cmpMetadata(L->getOperand(I), R->getOperand(I))) + + size_t StartIdx = 0; + if (L->getNumOperands() > 0) { + if (L->getOperand(0) == L) { + if (R->getOperand(0) != R) + return -1; + StartIdx = 1; + } else if (R->getOperand(0) == R) + return 1; + } + + for (size_t I = StartIdx; I < L->getNumOperands(); ++I) + if (int Res = + cmpMetadata(L->getOperand(I), R->getOperand(I), InValueContext)) return Res; return 0; } @@ -254,7 +281,7 @@ int FunctionComparator::cmpInstMetadata(Instruction const *L, auto const [KeyR, MR] = MDR[I]; if (int Res = cmpNumbers(KeyL, KeyR)) return Res; - if (int Res = cmpMDNode(ML, MR)) + if (int Res = cmpMDNode(ML, MR, false)) return Res; } return 0; @@ -722,7 +749,7 @@ int FunctionComparator::cmpOperations(const Instruction *L, cast(R)->getTailCallKind())) return Res; return cmpMDNode(L->getMetadata(LLVMContext::MD_range), - R->getMetadata(LLVMContext::MD_range)); + R->getMetadata(LLVMContext::MD_range), false); } if (const InsertValueInst *IVI = dyn_cast(L)) { ArrayRef LIndices = IVI->getIndices(); @@ -899,7 +926,7 @@ int FunctionComparator::cmpValues(const Value *L, const Value *R) const { return 0; return cmpMetadata(MetadataValueL->getMetadata(), - MetadataValueR->getMetadata()); + MetadataValueR->getMetadata(), true); } if (MetadataValueL) diff --git a/llvm/test/Transforms/MergeFunc/metadata-call-arguments.ll b/llvm/test/Transforms/MergeFunc/metadata-call-arguments.ll index 28263741f2cde..875160090bed5 100644 --- a/llvm/test/Transforms/MergeFunc/metadata-call-arguments.ll +++ b/llvm/test/Transforms/MergeFunc/metadata-call-arguments.ll @@ -55,48 +55,50 @@ declare i64 @llvm.read_volatile_register.i64(metadata) !5 = !{!"foo", i64 10} !6 = !{!"foo", i64 10} +; CHECK-LABEL: define i64 @call_mdtuple_arg_not_equal_1() { +; CHECK-NEXT: [[TMP1:%.*]] = call i64 @llvm.read_volatile_register.i64(metadata [[META0:![0-9]+]]) +; CHECK-NEXT: ret i64 [[TMP1]] +; +; +; CHECK-LABEL: define i64 @call_mdtuple_arg_not_equal_2() { +; CHECK-NEXT: [[TMP1:%.*]] = call i64 @llvm.read_volatile_register.i64(metadata [[META1:![0-9]+]]) +; CHECK-NEXT: ret i64 [[TMP1]] +; +; ; CHECK-LABEL: define i64 @call_mdtuple_arg_with_cycle_equal_1() { -; CHECK-NEXT: [[R:%.*]] = call i64 @llvm.read_volatile_register.i64(metadata [[META0:![0-9]+]]) +; CHECK-NEXT: [[R:%.*]] = call i64 @llvm.read_volatile_register.i64(metadata [[META2:![0-9]+]]) ; CHECK-NEXT: ret i64 [[R]] ; ; ; CHECK-LABEL: define i64 @call_mdtuple_arg_with_cycle_equal_2() { -; CHECK-NEXT: [[R:%.*]] = call i64 @llvm.read_volatile_register.i64(metadata [[META1:![0-9]+]]) +; CHECK-NEXT: [[R:%.*]] = call i64 @llvm.read_volatile_register.i64(metadata [[META3:![0-9]+]]) ; CHECK-NEXT: ret i64 [[R]] ; ; ; CHECK-LABEL: define i64 @call_mdtuple_arg_with_cycle_not_equal_1() { -; CHECK-NEXT: [[R:%.*]] = call i64 @llvm.read_volatile_register.i64(metadata [[META1]]) +; CHECK-NEXT: [[R:%.*]] = call i64 @llvm.read_volatile_register.i64(metadata [[META3]]) ; CHECK-NEXT: ret i64 [[R]] ; ; ; CHECK-LABEL: define i64 @call_mdtuple_arg_with_cycle_not_equal_2() { -; CHECK-NEXT: [[R:%.*]] = call i64 @llvm.read_volatile_register.i64(metadata [[META2:![0-9]+]]) +; CHECK-NEXT: [[R:%.*]] = call i64 @llvm.read_volatile_register.i64(metadata [[META4:![0-9]+]]) ; CHECK-NEXT: ret i64 [[R]] ; ; ; CHECK-LABEL: define i64 @call_mdtuple_arg_equal_1() { -; CHECK-NEXT: [[R:%.*]] = call i64 @llvm.read_volatile_register.i64(metadata [[META3:![0-9]+]]) +; CHECK-NEXT: [[R:%.*]] = call i64 @llvm.read_volatile_register.i64(metadata [[META5:![0-9]+]]) ; CHECK-NEXT: ret i64 [[R]] ; ; -; CHECK-LABEL: define i64 @call_mdtuple_arg_not_equal_2() { -; CHECK-NEXT: [[TMP1:%.*]] = tail call i64 @call_mdtuple_arg_equal_1() -; CHECK-NEXT: ret i64 [[TMP1]] -; -; -; CHECK-LABEL: define i64 @call_mdtuple_arg_not_equal_1() { -; CHECK-NEXT: [[TMP1:%.*]] = tail call i64 @call_mdtuple_arg_equal_1() -; CHECK-NEXT: ret i64 [[TMP1]] -; -; ; CHECK-LABEL: define i64 @call_mdtuple_arg_equal_2() { ; CHECK-NEXT: [[TMP1:%.*]] = tail call i64 @call_mdtuple_arg_equal_1() ; CHECK-NEXT: ret i64 [[TMP1]] ; ;. -; CHECK: [[META0]] = distinct !{[[META0]], !"foo"} -; CHECK: [[META1]] = distinct !{[[META1]], !"foo"} -; CHECK: [[META2]] = distinct !{[[META2]], !"bar"} -; CHECK: [[META3]] = !{!"foo", i64 10} +; CHECK: [[META0]] = !{!"foo"} +; CHECK: [[META1]] = !{!"bar"} +; CHECK: [[META2]] = distinct !{[[META2]], !"foo"} +; CHECK: [[META3]] = distinct !{[[META3]], !"foo"} +; CHECK: [[META4]] = distinct !{[[META4]], !"bar"} +; CHECK: [[META5]] = !{!"foo", i64 10} ;.