Skip to content

Commit b597360

Browse files
authored
[Synth] Fix race condition and memory corruption in longest path analysis caching (#9098)
This commit addresses a race condition and memory corruption issue in the LongestPathAnalysis where concurrent access to cached results could lead to iterator invalidation and memory corruption. The issue occurred when the cachedResults DenseMap was modified during iteration, causing existing iterators and references to become invalid and point to corrupted memory. The fix changes the cachedResults map to store unique_ptr<SmallVector<OpenPath>> instead of SmallVector<OpenPath> directly. This ensures that the underlying SmallVector objects remain at stable memory locations even when the DenseMap is rehashed or modified, preventing iterator invalidation and memory corruption.
1 parent 8c321f9 commit b597360

File tree

1 file changed

+19
-10
lines changed

1 file changed

+19
-10
lines changed

lib/Dialect/Synth/Analysis/LongestPathAnalysis.cpp

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -730,7 +730,8 @@ class LocalVisitor {
730730
std::unique_ptr<llvm::ImmutableListFactory<DebugPoint>> debugPointFactory;
731731

732732
// A map from the value point to the longest paths.
733-
DenseMap<std::pair<Value, size_t>, SmallVector<OpenPath>> cachedResults;
733+
DenseMap<std::pair<Value, size_t>, std::unique_ptr<SmallVector<OpenPath>>>
734+
cachedResults;
734735

735736
// A map from the object to the longest paths.
736737
DenseMap<Object, SmallVector<OpenPath>> endPointResults;
@@ -775,7 +776,7 @@ ArrayRef<OpenPath> LocalVisitor::getCachedPaths(Value value,
775776
// If not found, then consider it to be a constant.
776777
if (it == cachedResults.end())
777778
return {};
778-
return it->second;
779+
return *it->second;
779780
}
780781

781782
void LocalVisitor::putUnclosedResult(const Object &object, int64_t delay,
@@ -817,6 +818,10 @@ LogicalResult LocalVisitor::markRegEndPoint(Value endPoint, Value start,
817818

818819
// Get paths for each bit, and record them.
819820
for (size_t i = 0, e = bitWidth; i < e; ++i) {
821+
// Call getOrComputePaths to make sure the paths are computed for endPoint.
822+
// This avoids a race condition.
823+
if (failed(getOrComputePaths(endPoint, i)))
824+
return failure();
820825
if (failed(record(i, start, i)))
821826
return failure();
822827
}
@@ -1104,6 +1109,10 @@ LogicalResult LocalVisitor::visit(mlir::BlockArgument arg, size_t bitPos,
11041109

11051110
FailureOr<ArrayRef<OpenPath>> LocalVisitor::getOrComputePaths(Value value,
11061111
size_t bitPos) {
1112+
1113+
if (value.getDefiningOp<hw::ConstantOp>())
1114+
return ArrayRef<OpenPath>{};
1115+
11071116
if (ec.contains({value, bitPos})) {
11081117
auto leader = ec.findLeader({value, bitPos});
11091118
// If this is not the leader, then use the leader.
@@ -1114,19 +1123,19 @@ FailureOr<ArrayRef<OpenPath>> LocalVisitor::getOrComputePaths(Value value,
11141123

11151124
auto it = cachedResults.find({value, bitPos});
11161125
if (it != cachedResults.end())
1117-
return ArrayRef<OpenPath>(it->second);
1126+
return ArrayRef<OpenPath>(*it->second);
11181127

1119-
SmallVector<OpenPath> results;
1120-
if (failed(visitValue(value, bitPos, results)))
1128+
auto results = std::make_unique<SmallVector<OpenPath>>();
1129+
if (failed(visitValue(value, bitPos, *results)))
11211130
return {};
11221131

11231132
// Unique the results.
1124-
filterPaths(results, ctx->doKeepOnlyMaxDelayPaths(), ctx->isLocalScope());
1133+
filterPaths(*results, ctx->doKeepOnlyMaxDelayPaths(), ctx->isLocalScope());
11251134
LLVM_DEBUG({
11261135
llvm::dbgs() << value << "[" << bitPos << "] "
1127-
<< "Found " << results.size() << " paths\n";
1136+
<< "Found " << results->size() << " paths\n";
11281137
llvm::dbgs() << "====Paths:\n";
1129-
for (auto &path : results) {
1138+
for (auto &path : *results) {
11301139
path.print(llvm::dbgs());
11311140
llvm::dbgs() << "\n";
11321141
}
@@ -1136,7 +1145,7 @@ FailureOr<ArrayRef<OpenPath>> LocalVisitor::getOrComputePaths(Value value,
11361145
auto insertedResult =
11371146
cachedResults.try_emplace({value, bitPos}, std::move(results));
11381147
assert(insertedResult.second);
1139-
return ArrayRef<OpenPath>(insertedResult.first->second);
1148+
return ArrayRef<OpenPath>(*insertedResult.first->second);
11401149
}
11411150

11421151
LogicalResult LocalVisitor::visitValue(Value value, size_t bitPos,
@@ -1294,7 +1303,7 @@ LogicalResult LocalVisitor::initializeAndRun() {
12941303
op.getEnable());
12951304
})
12961305
.Case<aig::AndInverterOp, comb::AndOp, comb::OrOp, comb::XorOp,
1297-
comb::MuxOp>([&](auto op) {
1306+
comb::MuxOp, seq::FirMemReadOp>([&](auto op) {
12981307
// NOTE: Visiting and-inverter is not necessary but
12991308
// useful to reduce recursion depth.
13001309
for (size_t i = 0, e = getBitWidth(op); i < e; ++i)

0 commit comments

Comments
 (0)