Skip to content

Commit 87baa8f

Browse files
committed
[FuncComp] Compare MDNodes in cmpMetadata using cmpMDNode.
Use cmpMDNode in cmpMetadata to structurally compare MDNodes for metadata arguments. This fixes a mis-compile caused by cmpMetadata incorrectly returning 0 for different nodes. Note that metadata can contain cycles, so we need to make sure we don't get stuck in an infinite cycle.
1 parent 3c4fa5a commit 87baa8f

File tree

3 files changed

+81
-38
lines changed

3 files changed

+81
-38
lines changed

llvm/include/llvm/Transforms/Utils/FunctionComparator.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -328,8 +328,12 @@ class FunctionComparator {
328328
int cmpOrderings(AtomicOrdering L, AtomicOrdering R) const;
329329
int cmpInlineAsm(const InlineAsm *L, const InlineAsm *R) const;
330330
int cmpAttrs(const AttributeList L, const AttributeList R) const;
331-
int cmpMDNode(const MDNode *L, const MDNode *R) const;
332-
int cmpMetadata(const Metadata *L, const Metadata *R) const;
331+
int cmpMDNode(const MDNode *L, const MDNode *R,
332+
SmallPtrSetImpl<const MDNode *> &SeenL,
333+
SmallPtrSetImpl<const MDNode *> &SeenR) const;
334+
int cmpMetadata(const Metadata *L, const Metadata *R,
335+
SmallPtrSetImpl<const MDNode *> &SeenL,
336+
SmallPtrSetImpl<const MDNode *> &SeenR) const;
333337
int cmpInstMetadata(Instruction const *L, Instruction const *R) const;
334338
int cmpOperandBundlesSchema(const CallBase &LCS, const CallBase &RCS) const;
335339

llvm/lib/Transforms/Utils/FunctionComparator.cpp

Lines changed: 54 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -185,8 +185,13 @@ int FunctionComparator::cmpAttrs(const AttributeList L,
185185
return 0;
186186
}
187187

188-
int FunctionComparator::cmpMetadata(const Metadata *L,
189-
const Metadata *R) const {
188+
int FunctionComparator::cmpMetadata(
189+
const Metadata *L, const Metadata *R,
190+
SmallPtrSetImpl<const MDNode *> &SeenL,
191+
SmallPtrSetImpl<const MDNode *> &SeenR) const {
192+
if (L == R)
193+
return 0;
194+
190195
// TODO: the following routine coerce the metadata contents into constants
191196
// or MDStrings before comparison.
192197
// It ignores any other cases, so that the metadata nodes are considered
@@ -207,22 +212,51 @@ int FunctionComparator::cmpMetadata(const Metadata *L,
207212

208213
auto *CL = dyn_cast<ConstantAsMetadata>(L);
209214
auto *CR = dyn_cast<ConstantAsMetadata>(R);
210-
if (CL == CR)
211-
return 0;
212-
if (!CL)
215+
if (CL && CR) {
216+
if (!CL)
217+
return -1;
218+
if (!CR)
219+
return 1;
220+
return cmpConstants(CL->getValue(), CR->getValue());
221+
}
222+
223+
auto *NodeL = dyn_cast<const MDNode>(L);
224+
auto *NodeR = dyn_cast<const MDNode>(R);
225+
if (NodeL && NodeR)
226+
return cmpMDNode(NodeL, NodeR, SeenL, SeenR);
227+
228+
if (NodeR)
213229
return -1;
214-
if (!CR)
230+
231+
if (NodeL)
215232
return 1;
216-
return cmpConstants(CL->getValue(), CR->getValue());
233+
234+
assert(false);
235+
236+
return 0;
217237
}
218238

219-
int FunctionComparator::cmpMDNode(const MDNode *L, const MDNode *R) const {
239+
int FunctionComparator::cmpMDNode(
240+
const MDNode *L, const MDNode *R, SmallPtrSetImpl<const MDNode *> &SeenL,
241+
SmallPtrSetImpl<const MDNode *> &SeenR) const {
220242
if (L == R)
221243
return 0;
222244
if (!L)
223245
return -1;
224246
if (!R)
225247
return 1;
248+
249+
// Check if we already checked either L or R previosuly. This can be the case
250+
// for metadata nodes with cycles.
251+
bool AlreadySeenL = !SeenL.insert(L).second;
252+
bool AlreadySeenR = !SeenR.insert(R).second;
253+
if (AlreadySeenL && AlreadySeenR)
254+
return 0;
255+
if (AlreadySeenR)
256+
return -1;
257+
if (AlreadySeenL)
258+
return 1;
259+
226260
// TODO: Note that as this is metadata, it is possible to drop and/or merge
227261
// this data when considering functions to merge. Thus this comparison would
228262
// return 0 (i.e. equivalent), but merging would become more complicated
@@ -232,7 +266,7 @@ int FunctionComparator::cmpMDNode(const MDNode *L, const MDNode *R) const {
232266
if (int Res = cmpNumbers(L->getNumOperands(), R->getNumOperands()))
233267
return Res;
234268
for (size_t I = 0; I < L->getNumOperands(); ++I)
235-
if (int Res = cmpMetadata(L->getOperand(I), R->getOperand(I)))
269+
if (int Res = cmpMetadata(L->getOperand(I), R->getOperand(I), SeenL, SeenR))
236270
return Res;
237271
return 0;
238272
}
@@ -254,7 +288,10 @@ int FunctionComparator::cmpInstMetadata(Instruction const *L,
254288
auto const [KeyR, MR] = MDR[I];
255289
if (int Res = cmpNumbers(KeyL, KeyR))
256290
return Res;
257-
if (int Res = cmpMDNode(ML, MR))
291+
292+
SmallPtrSet<const MDNode *, 4> SeenL;
293+
SmallPtrSet<const MDNode *, 4> SeenR;
294+
if (int Res = cmpMDNode(ML, MR, SeenL, SeenR))
258295
return Res;
259296
}
260297
return 0;
@@ -721,8 +758,11 @@ int FunctionComparator::cmpOperations(const Instruction *L,
721758
if (int Res = cmpNumbers(CI->getTailCallKind(),
722759
cast<CallInst>(R)->getTailCallKind()))
723760
return Res;
761+
762+
SmallPtrSet<const MDNode *, 4> SeenL;
763+
SmallPtrSet<const MDNode *, 4> SeenR;
724764
return cmpMDNode(L->getMetadata(LLVMContext::MD_range),
725-
R->getMetadata(LLVMContext::MD_range));
765+
R->getMetadata(LLVMContext::MD_range), SeenL, SeenR);
726766
}
727767
if (const InsertValueInst *IVI = dyn_cast<InsertValueInst>(L)) {
728768
ArrayRef<unsigned> LIndices = IVI->getIndices();
@@ -895,11 +935,10 @@ int FunctionComparator::cmpValues(const Value *L, const Value *R) const {
895935
const MetadataAsValue *MetadataValueL = dyn_cast<MetadataAsValue>(L);
896936
const MetadataAsValue *MetadataValueR = dyn_cast<MetadataAsValue>(R);
897937
if (MetadataValueL && MetadataValueR) {
898-
if (MetadataValueL == MetadataValueR)
899-
return 0;
900-
938+
SmallPtrSet<const MDNode *, 4> SeenL;
939+
SmallPtrSet<const MDNode *, 4> SeenR;
901940
return cmpMetadata(MetadataValueL->getMetadata(),
902-
MetadataValueR->getMetadata());
941+
MetadataValueR->getMetadata(), SeenL, SeenR);
903942
}
904943

905944
if (MetadataValueL)

llvm/test/Transforms/MergeFunc/metadata-call-arguments.ll

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --include-generated-funcs --version 5
22
; RUN: opt -p mergefunc -S %s | FileCheck %s
33

4-
; FIXME: Should not be merged with @call_mdtuple_arg_not_equal_2.
54
define i64 @call_mdtuple_arg_not_equal_1() {
65
%r = call i64 @llvm.read_volatile_register.i64(metadata !0)
76
ret i64 %r
@@ -22,7 +21,6 @@ define i64 @call_mdtuple_arg_with_cycle_equal_2() {
2221
ret i64 %r
2322
}
2423

25-
; FIXME: Should not be merged with @call_mdtuple_arg_with_cycle_not_equal_2.
2624
define i64 @call_mdtuple_arg_with_cycle_not_equal_1() {
2725
%r = call i64 @llvm.read_volatile_register.i64(metadata !3)
2826
ret i64 %r
@@ -55,48 +53,50 @@ declare i64 @llvm.read_volatile_register.i64(metadata)
5553

5654
!5 = !{!"foo", i64 10}
5755
!6 = !{!"foo", i64 10}
56+
; CHECK-LABEL: define i64 @call_mdtuple_arg_not_equal_1() {
57+
; CHECK-NEXT: [[TMP1:%.*]] = call i64 @llvm.read_volatile_register.i64(metadata [[META0:![0-9]+]])
58+
; CHECK-NEXT: ret i64 [[TMP1]]
59+
;
60+
;
61+
; CHECK-LABEL: define i64 @call_mdtuple_arg_not_equal_2() {
62+
; CHECK-NEXT: [[TMP1:%.*]] = call i64 @llvm.read_volatile_register.i64(metadata [[META1:![0-9]+]])
63+
; CHECK-NEXT: ret i64 [[TMP1]]
64+
;
65+
;
5866
; CHECK-LABEL: define i64 @call_mdtuple_arg_with_cycle_equal_1() {
59-
; CHECK-NEXT: [[R:%.*]] = call i64 @llvm.read_volatile_register.i64(metadata [[META0:![0-9]+]])
67+
; CHECK-NEXT: [[R:%.*]] = call i64 @llvm.read_volatile_register.i64(metadata [[META2:![0-9]+]])
6068
; CHECK-NEXT: ret i64 [[R]]
6169
;
6270
;
6371
; CHECK-LABEL: define i64 @call_mdtuple_arg_with_cycle_equal_2() {
64-
; CHECK-NEXT: [[R:%.*]] = call i64 @llvm.read_volatile_register.i64(metadata [[META1:![0-9]+]])
72+
; CHECK-NEXT: [[R:%.*]] = call i64 @llvm.read_volatile_register.i64(metadata [[META3:![0-9]+]])
6573
; CHECK-NEXT: ret i64 [[R]]
6674
;
6775
;
6876
; CHECK-LABEL: define i64 @call_mdtuple_arg_with_cycle_not_equal_1() {
69-
; CHECK-NEXT: [[R:%.*]] = call i64 @llvm.read_volatile_register.i64(metadata [[META1]])
77+
; CHECK-NEXT: [[R:%.*]] = call i64 @llvm.read_volatile_register.i64(metadata [[META3]])
7078
; CHECK-NEXT: ret i64 [[R]]
7179
;
7280
;
7381
; CHECK-LABEL: define i64 @call_mdtuple_arg_with_cycle_not_equal_2() {
74-
; CHECK-NEXT: [[R:%.*]] = call i64 @llvm.read_volatile_register.i64(metadata [[META2:![0-9]+]])
82+
; CHECK-NEXT: [[R:%.*]] = call i64 @llvm.read_volatile_register.i64(metadata [[META4:![0-9]+]])
7583
; CHECK-NEXT: ret i64 [[R]]
7684
;
7785
;
7886
; CHECK-LABEL: define i64 @call_mdtuple_arg_equal_1() {
79-
; CHECK-NEXT: [[R:%.*]] = call i64 @llvm.read_volatile_register.i64(metadata [[META3:![0-9]+]])
87+
; CHECK-NEXT: [[R:%.*]] = call i64 @llvm.read_volatile_register.i64(metadata [[META5:![0-9]+]])
8088
; CHECK-NEXT: ret i64 [[R]]
8189
;
8290
;
83-
; CHECK-LABEL: define i64 @call_mdtuple_arg_not_equal_2() {
84-
; CHECK-NEXT: [[TMP1:%.*]] = tail call i64 @call_mdtuple_arg_equal_1()
85-
; CHECK-NEXT: ret i64 [[TMP1]]
86-
;
87-
;
88-
; CHECK-LABEL: define i64 @call_mdtuple_arg_not_equal_1() {
89-
; CHECK-NEXT: [[TMP1:%.*]] = tail call i64 @call_mdtuple_arg_equal_1()
90-
; CHECK-NEXT: ret i64 [[TMP1]]
91-
;
92-
;
9391
; CHECK-LABEL: define i64 @call_mdtuple_arg_equal_2() {
9492
; CHECK-NEXT: [[TMP1:%.*]] = tail call i64 @call_mdtuple_arg_equal_1()
9593
; CHECK-NEXT: ret i64 [[TMP1]]
9694
;
9795
;.
98-
; CHECK: [[META0]] = distinct !{[[META0]], !"foo"}
99-
; CHECK: [[META1]] = distinct !{[[META1]], !"foo"}
100-
; CHECK: [[META2]] = distinct !{[[META2]], !"bar"}
101-
; CHECK: [[META3]] = !{!"foo", i64 10}
96+
; CHECK: [[META0]] = !{!"foo"}
97+
; CHECK: [[META1]] = !{!"bar"}
98+
; CHECK: [[META2]] = distinct !{[[META2]], !"foo"}
99+
; CHECK: [[META3]] = distinct !{[[META3]], !"foo"}
100+
; CHECK: [[META4]] = distinct !{[[META4]], !"bar"}
101+
; CHECK: [[META5]] = !{!"foo", i64 10}
102102
;.

0 commit comments

Comments
 (0)