Skip to content

Commit c075fb8

Browse files
authored
[MLIR] Fix duplicated attribute nodes in MLIR bytecode deserialization (#151267)
Fixes #150163 MLIR bytecode does not preserve alias definitions, so each attribute encountered during deserialization is treated as a new one. This can generate duplicate `DISubprogram` nodes during deserialization. The patch adds a `StringMap` cache that records attributes and fetches them when encountered again.
1 parent 5f0515d commit c075fb8

File tree

5 files changed

+45
-4
lines changed

5 files changed

+45
-4
lines changed

mlir/include/mlir/AsmParser/AsmParser.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,8 @@ parseAsmSourceFile(const llvm::SourceMgr &sourceMgr, Block *block,
5353
/// null terminated.
5454
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context,
5555
Type type = {}, size_t *numRead = nullptr,
56-
bool isKnownNullTerminated = false);
56+
bool isKnownNullTerminated = false,
57+
llvm::StringMap<Attribute> *attributesCache = nullptr);
5758

5859
/// This parses a single MLIR type to an MLIR context if it was valid. If not,
5960
/// an error diagnostic is emitted to the context.

mlir/lib/AsmParser/DialectSymbolParser.cpp

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,15 @@ static Symbol parseExtendedSymbol(Parser &p, AsmParserState *asmState,
245245
return nullptr;
246246
}
247247

248+
if constexpr (std::is_same_v<Symbol, Attribute>) {
249+
auto &cache = p.getState().symbols.attributesCache;
250+
auto cacheIt = cache.find(symbolData);
251+
// Skip cached attribute if it has type.
252+
if (cacheIt != cache.end() && !p.getToken().is(Token::colon))
253+
return cacheIt->second;
254+
255+
return cache[symbolData] = createSymbol(dialectName, symbolData, loc);
256+
}
248257
return createSymbol(dialectName, symbolData, loc);
249258
}
250259

@@ -337,6 +346,7 @@ Type Parser::parseExtendedType() {
337346
template <typename T, typename ParserFn>
338347
static T parseSymbol(StringRef inputStr, MLIRContext *context,
339348
size_t *numReadOut, bool isKnownNullTerminated,
349+
llvm::StringMap<Attribute> *attributesCache,
340350
ParserFn &&parserFn) {
341351
// Set the buffer name to the string being parsed, so that it appears in error
342352
// diagnostics.
@@ -348,6 +358,9 @@ static T parseSymbol(StringRef inputStr, MLIRContext *context,
348358
SourceMgr sourceMgr;
349359
sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
350360
SymbolState aliasState;
361+
if (attributesCache)
362+
aliasState.attributesCache = *attributesCache;
363+
351364
ParserConfig config(context);
352365
ParserState state(sourceMgr, config, aliasState, /*asmState=*/nullptr,
353366
/*codeCompleteContext=*/nullptr);
@@ -358,6 +371,11 @@ static T parseSymbol(StringRef inputStr, MLIRContext *context,
358371
if (!symbol)
359372
return T();
360373

374+
if constexpr (std::is_same_v<T, Attribute>) {
375+
if (attributesCache)
376+
*attributesCache = state.symbols.attributesCache;
377+
}
378+
361379
// Provide the number of bytes that were read.
362380
Token endTok = parser.getToken();
363381
size_t numRead =
@@ -374,13 +392,15 @@ static T parseSymbol(StringRef inputStr, MLIRContext *context,
374392

375393
Attribute mlir::parseAttribute(StringRef attrStr, MLIRContext *context,
376394
Type type, size_t *numRead,
377-
bool isKnownNullTerminated) {
395+
bool isKnownNullTerminated,
396+
llvm::StringMap<Attribute> *attributesCache) {
378397
return parseSymbol<Attribute>(
379-
attrStr, context, numRead, isKnownNullTerminated,
398+
attrStr, context, numRead, isKnownNullTerminated, attributesCache,
380399
[type](Parser &parser) { return parser.parseAttribute(type); });
381400
}
382401
Type mlir::parseType(StringRef typeStr, MLIRContext *context, size_t *numRead,
383402
bool isKnownNullTerminated) {
384403
return parseSymbol<Type>(typeStr, context, numRead, isKnownNullTerminated,
404+
/*attributesCache=*/nullptr,
385405
[](Parser &parser) { return parser.parseType(); });
386406
}

mlir/lib/AsmParser/ParserState.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,9 @@ struct SymbolState {
4040

4141
/// A map from unique integer identifier to DistinctAttr.
4242
DenseMap<uint64_t, DistinctAttr> distinctAttributes;
43+
44+
/// A map from unique string identifier to Attribute.
45+
llvm::StringMap<Attribute> attributesCache;
4346
};
4447

4548
//===----------------------------------------------------------------------===//

mlir/lib/Bytecode/Reader/BytecodeReader.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -895,6 +895,10 @@ class AttrTypeReader {
895895
SmallVector<AttrEntry> attributes;
896896
SmallVector<TypeEntry> types;
897897

898+
/// The map of cached attributes, used to avoid re-parsing the same
899+
/// attribute multiple times.
900+
llvm::StringMap<Attribute> attributesCache;
901+
898902
/// A location used for error emission.
899903
Location fileLoc;
900904

@@ -1235,7 +1239,7 @@ LogicalResult AttrTypeReader::parseAsmEntry(T &result, EncodingReader &reader,
12351239
::parseType(asmStr, context, &numRead, /*isKnownNullTerminated=*/true);
12361240
else
12371241
result = ::parseAttribute(asmStr, context, Type(), &numRead,
1238-
/*isKnownNullTerminated=*/true);
1242+
/*isKnownNullTerminated=*/true, &attributesCache);
12391243
if (!result)
12401244
return failure();
12411245

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
// RUN: mlir-opt -emit-bytecode %s | mlir-opt --mlir-print-debuginfo | FileCheck %s
2+
3+
// Verify that the distinct attribute which is used transitively
4+
// through two aliases does not end up duplicated when round-tripped
5+
// through bytecode.
6+
7+
// CHECK: distinct[0]
8+
// CHECK-NOT: distinct[1]
9+
#attr_ugly = #test<attr_ugly begin distinct[0]<> end>
10+
#attr_ugly1 = #test<attr_ugly begin #attr_ugly end>
11+
12+
module attributes {test.alias = #attr_ugly, test.alias1 = #attr_ugly1} {
13+
}

0 commit comments

Comments
 (0)