Skip to content

Commit ec478a1

Browse files
authored
Merge branch 'main' into main
2 parents 019b3ba + 6b6e8e1 commit ec478a1

File tree

4 files changed

+105
-34
lines changed

4 files changed

+105
-34
lines changed

llvm/include/llvm/ADT/DenseMap.h

Lines changed: 23 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1184,10 +1184,14 @@ class DenseMapIterator : DebugEpochBase::HandleBase {
11841184
using iterator_category = std::forward_iterator_tag;
11851185

11861186
private:
1187-
pointer Ptr = nullptr;
1188-
pointer End = nullptr;
1187+
using BucketItTy =
1188+
std::conditional_t<shouldReverseIterate<KeyT>(),
1189+
std::reverse_iterator<pointer>, pointer>;
11891190

1190-
DenseMapIterator(pointer Pos, pointer E, const DebugEpochBase &Epoch)
1191+
BucketItTy Ptr = {};
1192+
BucketItTy End = {};
1193+
1194+
DenseMapIterator(BucketItTy Pos, BucketItTy E, const DebugEpochBase &Epoch)
11911195
: DebugEpochBase::HandleBase(&Epoch), Ptr(Pos), End(E) {
11921196
assert(isHandleInSync() && "invalid construction!");
11931197
}
@@ -1201,29 +1205,24 @@ class DenseMapIterator : DebugEpochBase::HandleBase {
12011205
// empty buckets.
12021206
if (IsEmpty)
12031207
return makeEnd(Buckets, Epoch);
1204-
if (shouldReverseIterate<KeyT>()) {
1205-
DenseMapIterator Iter(Buckets.end(), Buckets.begin(), Epoch);
1206-
Iter.RetreatPastEmptyBuckets();
1207-
return Iter;
1208-
}
1209-
DenseMapIterator Iter(Buckets.begin(), Buckets.end(), Epoch);
1208+
auto R = maybeReverse(Buckets);
1209+
DenseMapIterator Iter(R.begin(), R.end(), Epoch);
12101210
Iter.AdvancePastEmptyBuckets();
12111211
return Iter;
12121212
}
12131213

12141214
static DenseMapIterator makeEnd(iterator_range<pointer> Buckets,
12151215
const DebugEpochBase &Epoch) {
1216-
if (shouldReverseIterate<KeyT>())
1217-
return DenseMapIterator(Buckets.begin(), Buckets.begin(), Epoch);
1218-
return DenseMapIterator(Buckets.end(), Buckets.end(), Epoch);
1216+
auto R = maybeReverse(Buckets);
1217+
return DenseMapIterator(R.end(), R.end(), Epoch);
12191218
}
12201219

12211220
static DenseMapIterator makeIterator(pointer P,
12221221
iterator_range<pointer> Buckets,
12231222
const DebugEpochBase &Epoch) {
1224-
if (shouldReverseIterate<KeyT>())
1225-
return DenseMapIterator(P + 1, Buckets.begin(), Epoch);
1226-
return DenseMapIterator(P, Buckets.end(), Epoch);
1223+
auto R = maybeReverse(Buckets);
1224+
constexpr int Offset = shouldReverseIterate<KeyT>() ? 1 : 0;
1225+
return DenseMapIterator(BucketItTy(P + Offset), R.end(), Epoch);
12271226
}
12281227

12291228
// Converting ctor from non-const iterators to const iterators. SFINAE'd out
@@ -1238,16 +1237,16 @@ class DenseMapIterator : DebugEpochBase::HandleBase {
12381237
reference operator*() const {
12391238
assert(isHandleInSync() && "invalid iterator access!");
12401239
assert(Ptr != End && "dereferencing end() iterator");
1241-
if (shouldReverseIterate<KeyT>())
1242-
return Ptr[-1];
12431240
return *Ptr;
12441241
}
12451242
pointer operator->() const { return &operator*(); }
12461243

12471244
friend bool operator==(const DenseMapIterator &LHS,
12481245
const DenseMapIterator &RHS) {
1249-
assert((!LHS.Ptr || LHS.isHandleInSync()) && "handle not in sync!");
1250-
assert((!RHS.Ptr || RHS.isHandleInSync()) && "handle not in sync!");
1246+
assert((!LHS.getEpochAddress() || LHS.isHandleInSync()) &&
1247+
"handle not in sync!");
1248+
assert((!RHS.getEpochAddress() || RHS.isHandleInSync()) &&
1249+
"handle not in sync!");
12511250
assert(LHS.getEpochAddress() == RHS.getEpochAddress() &&
12521251
"comparing incomparable iterators!");
12531252
return LHS.Ptr == RHS.Ptr;
@@ -1261,11 +1260,6 @@ class DenseMapIterator : DebugEpochBase::HandleBase {
12611260
inline DenseMapIterator &operator++() { // Preincrement
12621261
assert(isHandleInSync() && "invalid iterator access!");
12631262
assert(Ptr != End && "incrementing end() iterator");
1264-
if (shouldReverseIterate<KeyT>()) {
1265-
--Ptr;
1266-
RetreatPastEmptyBuckets();
1267-
return *this;
1268-
}
12691263
++Ptr;
12701264
AdvancePastEmptyBuckets();
12711265
return *this;
@@ -1288,14 +1282,11 @@ class DenseMapIterator : DebugEpochBase::HandleBase {
12881282
++Ptr;
12891283
}
12901284

1291-
void RetreatPastEmptyBuckets() {
1292-
assert(Ptr >= End);
1293-
const KeyT Empty = KeyInfoT::getEmptyKey();
1294-
const KeyT Tombstone = KeyInfoT::getTombstoneKey();
1295-
1296-
while (Ptr != End && (KeyInfoT::isEqual(Ptr[-1].getFirst(), Empty) ||
1297-
KeyInfoT::isEqual(Ptr[-1].getFirst(), Tombstone)))
1298-
--Ptr;
1285+
static auto maybeReverse(iterator_range<pointer> Range) {
1286+
if constexpr (shouldReverseIterate<KeyT>())
1287+
return reverse(Range);
1288+
else
1289+
return Range;
12991290
}
13001291
};
13011292

llvm/lib/Transforms/IPO/FunctionSpecialization.cpp

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,8 @@ static cl::opt<bool> SpecializeLiteralConstant(
8989
"Enable specialization of functions that take a literal constant as an "
9090
"argument"));
9191

92+
extern cl::opt<bool> ProfcheckDisableMetadataFixes;
93+
9294
bool InstCostVisitor::canEliminateSuccessor(BasicBlock *BB,
9395
BasicBlock *Succ) const {
9496
unsigned I = 0;
@@ -784,9 +786,31 @@ bool FunctionSpecializer::run() {
784786

785787
// Update the known call sites to call the clone.
786788
for (CallBase *Call : S.CallSites) {
789+
Function *Clone = S.Clone;
787790
LLVM_DEBUG(dbgs() << "FnSpecialization: Redirecting " << *Call
788-
<< " to call " << S.Clone->getName() << "\n");
791+
<< " to call " << Clone->getName() << "\n");
789792
Call->setCalledFunction(S.Clone);
793+
auto &BFI = GetBFI(*Call->getFunction());
794+
std::optional<uint64_t> Count =
795+
BFI.getBlockProfileCount(Call->getParent());
796+
if (Count && !ProfcheckDisableMetadataFixes) {
797+
std::optional<llvm::Function::ProfileCount> MaybeCloneCount =
798+
Clone->getEntryCount();
799+
assert(MaybeCloneCount && "Clone entry count was not set!");
800+
uint64_t CallCount = *Count + MaybeCloneCount->getCount();
801+
Clone->setEntryCount(CallCount);
802+
if (std::optional<llvm::Function::ProfileCount> MaybeOriginalCount =
803+
S.F->getEntryCount()) {
804+
uint64_t OriginalCount = MaybeOriginalCount->getCount();
805+
if (OriginalCount >= CallCount) {
806+
S.F->setEntryCount(OriginalCount - CallCount);
807+
} else {
808+
// This should generally not happen as that would mean there are
809+
// more computed calls to the function than what was recorded.
810+
LLVM_DEBUG(S.F->setEntryCount(0));
811+
}
812+
}
813+
}
790814
}
791815

792816
Clones.push_back(S.Clone);
@@ -1043,6 +1067,9 @@ Function *FunctionSpecializer::createSpecialization(Function *F,
10431067
// clone must.
10441068
Clone->setLinkage(GlobalValue::InternalLinkage);
10451069

1070+
if (F->getEntryCount() && !ProfcheckDisableMetadataFixes)
1071+
Clone->setEntryCount(0);
1072+
10461073
// Initialize the lattice state of the arguments of the function clone,
10471074
// marking the argument on which we specialized the function constant
10481075
// with the given value.
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
; RUN: opt -passes="ipsccp<func-spec>" -force-specialization -S < %s | FileCheck %s
2+
target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128"
3+
4+
@A = external dso_local constant i32, align 4
5+
@B = external dso_local constant i32, align 4
6+
7+
; CHECK: define dso_local i32 @bar(i32 %x, i32 %y, ptr %z) !prof ![[BAR_PROF:[0-9]]] {
8+
define dso_local i32 @bar(i32 %x, i32 %y, ptr %z) !prof !0 {
9+
entry:
10+
%tobool = icmp ne i32 %x, 0
11+
; CHECK: br i1 %tobool, label %if.then, label %if.else, !prof ![[BRANCH_PROF:[0-9]]]
12+
br i1 %tobool, label %if.then, label %if.else, !prof !1
13+
14+
; CHECK: if.then:
15+
; CHECK: call i32 @foo.specialized.1(i32 %x, ptr @A)
16+
if.then:
17+
%call = call i32 @foo(i32 %x, ptr @A)
18+
br label %return
19+
20+
; CHECK: if.else:
21+
; CHECK: call i32 @foo.specialized.2(i32 %y, ptr @B)
22+
if.else:
23+
%call1 = call i32 @foo(i32 %y, ptr @B)
24+
br label %return
25+
26+
; CHECK: return:
27+
; CHECK: %call2 = call i32 @foo(i32 %x, ptr %z)
28+
return:
29+
%retval.0 = phi i32 [ %call, %if.then ], [ %call1, %if.else ]
30+
%call2 = call i32 @foo(i32 %x, ptr %z);
31+
%add = add i32 %retval.0, %call2
32+
ret i32 %add
33+
}
34+
35+
; CHECK: define internal i32 @foo(i32 %x, ptr %b) !prof ![[FOO_UNSPEC_PROF:[0-9]]]
36+
; CHECK: define internal i32 @foo.specialized.1(i32 %x, ptr %b) !prof ![[FOO_SPEC_1_PROF:[0-9]]]
37+
; CHECK: define internal i32 @foo.specialized.2(i32 %x, ptr %b) !prof ![[FOO_SPEC_2_PROF:[0-9]]]
38+
define internal i32 @foo(i32 %x, ptr %b) !prof !2 {
39+
entry:
40+
%0 = load i32, ptr %b, align 4
41+
%add = add nsw i32 %x, %0
42+
ret i32 %add
43+
}
44+
45+
; CHECK: ![[BAR_PROF]] = !{!"function_entry_count", i64 1000}
46+
; CHECK: ![[BRANCH_PROF]] = !{!"branch_weights", i32 1, i32 3}
47+
; CHECK: ![[FOO_UNSPEC_PROF]] = !{!"function_entry_count", i64 234}
48+
; CHECK: ![[FOO_SPEC_1_PROF]] = !{!"function_entry_count", i64 250}
49+
; CHECK: ![[FOO_SPEC_2_PROF]] = !{!"function_entry_count", i64 750}
50+
!0 = !{!"function_entry_count", i64 1000}
51+
!1 = !{!"branch_weights", i32 1, i32 3}
52+
!2 = !{!"function_entry_count", i64 1234}

mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -405,7 +405,8 @@ std::unique_ptr<MPIImplTraits> MPIImplTraits::get(ModuleOp &moduleOp) {
405405
return std::make_unique<OMPIImplTraits>(moduleOp);
406406
if (!strAttr || strAttr.getValue() != "MPICH")
407407
moduleOp.emitWarning() << "Unknown \"MPI:Implementation\" value in DLTI ("
408-
<< strAttr.getValue() << "), defaulting to MPICH";
408+
<< (strAttr ? strAttr.getValue() : "<NULL>")
409+
<< "), defaulting to MPICH";
409410
return std::make_unique<MPICHImplTraits>(moduleOp);
410411
}
411412

0 commit comments

Comments
 (0)