Skip to content

Commit 93d2ef1

Browse files
authored
[mlir][bytecode] Add support for deferred attribute/type parsing. (#170993)
Add ability to defer parsing and re-enqueueing oneself. This enables changing CallSiteLoc parsing to not recurse as deeply: previously this could fail (especially on large inputs in debug mode the recursion could overflow). Add a default depth cutoff, this could be a parameter later if needed.
1 parent 00bccfc commit 93d2ef1

File tree

2 files changed

+241
-53
lines changed

2 files changed

+241
-53
lines changed

mlir/lib/Bytecode/Reader/BytecodeReader.cpp

Lines changed: 204 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727

2828
#include <cstddef>
2929
#include <cstdint>
30+
#include <deque>
3031
#include <list>
3132
#include <memory>
3233
#include <numeric>
@@ -830,6 +831,23 @@ namespace {
830831
/// This class provides support for reading attribute and type entries from the
831832
/// bytecode. Attribute and Type entries are read lazily on demand, so we use
832833
/// this reader to manage when to actually parse them from the bytecode.
834+
///
835+
/// The parsing of attributes & types are generally recursive, this can lead to
836+
/// stack overflows for deeply nested structures, so we track a few extra pieces
837+
/// of information to avoid this:
838+
///
839+
/// - `depth`: The current depth while parsing nested attributes. We defer on
840+
/// parsing deeply nested attributes to avoid potential stack overflows. The
841+
/// deferred parsing is achieved by reporting a failure when parsing a nested
842+
/// attribute/type and registering the index of the encountered attribute/type
843+
/// in the deferred parsing worklist. Hence, a failure with deffered entry
844+
/// does not constitute a failure, it also requires that folks return on
845+
/// first failure rather than attempting additional parses.
846+
/// - `deferredWorklist`: A list of attribute/type indices that we could not
847+
/// parse due to hitting the depth limit. The worklist is used to capture the
848+
/// indices of attributes/types that need to be parsed/reparsed when we hit
849+
/// the depth limit. This enables moving the tracking of what needs to be
850+
/// parsed to the heap.
833851
class AttrTypeReader {
834852
/// This class represents a single attribute or type entry.
835853
template <typename T>
@@ -863,12 +881,34 @@ class AttrTypeReader {
863881
ArrayRef<uint8_t> sectionData,
864882
ArrayRef<uint8_t> offsetSectionData);
865883

884+
LogicalResult readAttribute(uint64_t index, Attribute &result,
885+
uint64_t depth = 0) {
886+
return readEntry(attributes, index, result, "attribute", depth);
887+
}
888+
889+
LogicalResult readType(uint64_t index, Type &result, uint64_t depth = 0) {
890+
return readEntry(types, index, result, "type", depth);
891+
}
892+
866893
/// Resolve the attribute or type at the given index. Returns nullptr on
867894
/// failure.
868-
Attribute resolveAttribute(size_t index) {
869-
return resolveEntry(attributes, index, "Attribute");
895+
Attribute resolveAttribute(size_t index, uint64_t depth = 0) {
896+
return resolveEntry(attributes, index, "Attribute", depth);
897+
}
898+
Type resolveType(size_t index, uint64_t depth = 0) {
899+
return resolveEntry(types, index, "Type", depth);
900+
}
901+
902+
Attribute getAttributeOrSentinel(size_t index) {
903+
if (index >= attributes.size())
904+
return nullptr;
905+
return attributes[index].entry;
906+
}
907+
Type getTypeOrSentinel(size_t index) {
908+
if (index >= types.size())
909+
return nullptr;
910+
return types[index].entry;
870911
}
871-
Type resolveType(size_t index) { return resolveEntry(types, index, "Type"); }
872912

873913
/// Parse a reference to an attribute or type using the given reader.
874914
LogicalResult parseAttribute(EncodingReader &reader, Attribute &result) {
@@ -909,23 +949,33 @@ class AttrTypeReader {
909949
llvm::getTypeName<T>(), ", but got: ", baseResult);
910950
}
911951

952+
/// Add an index to the deferred worklist for re-parsing.
953+
void addDeferredParsing(uint64_t index) { deferredWorklist.push_back(index); }
954+
912955
private:
913956
/// Resolve the given entry at `index`.
914957
template <typename T>
915-
T resolveEntry(SmallVectorImpl<Entry<T>> &entries, size_t index,
916-
StringRef entryType);
958+
T resolveEntry(SmallVectorImpl<Entry<T>> &entries, uint64_t index,
959+
StringRef entryType, uint64_t depth = 0);
917960

918-
/// Parse an entry using the given reader that was encoded using the textual
919-
/// assembly format.
961+
/// Read the entry at the given index, returning failure if the entry is not
962+
/// yet resolved.
920963
template <typename T>
921-
LogicalResult parseAsmEntry(T &result, EncodingReader &reader,
922-
StringRef entryType);
964+
LogicalResult readEntry(SmallVectorImpl<Entry<T>> &entries, uint64_t index,
965+
T &result, StringRef entryType, uint64_t depth);
923966

924967
/// Parse an entry using the given reader that was encoded using a custom
925968
/// bytecode format.
926969
template <typename T>
927970
LogicalResult parseCustomEntry(Entry<T> &entry, EncodingReader &reader,
928-
StringRef entryType);
971+
StringRef entryType, uint64_t index,
972+
uint64_t depth);
973+
974+
/// Parse an entry using the given reader that was encoded using the textual
975+
/// assembly format.
976+
template <typename T>
977+
LogicalResult parseAsmEntry(T &result, EncodingReader &reader,
978+
StringRef entryType);
929979

930980
/// The string section reader used to resolve string references when parsing
931981
/// custom encoded attribute/type entries.
@@ -951,6 +1001,10 @@ class AttrTypeReader {
9511001

9521002
/// Reference to the parser configuration.
9531003
const ParserConfig &parserConfig;
1004+
1005+
/// Worklist for deferred attribute/type parsing. This is used to handle
1006+
/// deeply nested structures like CallSiteLoc iteratively.
1007+
std::vector<uint64_t> deferredWorklist;
9541008
};
9551009

9561010
class DialectReader : public DialectBytecodeReader {
@@ -959,10 +1013,11 @@ class DialectReader : public DialectBytecodeReader {
9591013
const StringSectionReader &stringReader,
9601014
const ResourceSectionReader &resourceReader,
9611015
const llvm::StringMap<BytecodeDialect *> &dialectsMap,
962-
EncodingReader &reader, uint64_t &bytecodeVersion)
1016+
EncodingReader &reader, uint64_t &bytecodeVersion,
1017+
uint64_t depth = 0)
9631018
: attrTypeReader(attrTypeReader), stringReader(stringReader),
9641019
resourceReader(resourceReader), dialectsMap(dialectsMap),
965-
reader(reader), bytecodeVersion(bytecodeVersion) {}
1020+
reader(reader), bytecodeVersion(bytecodeVersion), depth(depth) {}
9661021

9671022
InFlightDiagnostic emitError(const Twine &msg) const override {
9681023
return reader.emitError(msg);
@@ -998,14 +1053,40 @@ class DialectReader : public DialectBytecodeReader {
9981053
// IR
9991054
//===--------------------------------------------------------------------===//
10001055

1056+
/// The maximum depth to eagerly parse nested attributes/types before
1057+
/// deferring.
1058+
static constexpr uint64_t maxAttrTypeDepth = 5;
1059+
10011060
LogicalResult readAttribute(Attribute &result) override {
1002-
return attrTypeReader.parseAttribute(reader, result);
1061+
uint64_t index;
1062+
if (failed(reader.parseVarInt(index)))
1063+
return failure();
1064+
if (depth > maxAttrTypeDepth) {
1065+
if (Attribute attr = attrTypeReader.getAttributeOrSentinel(index)) {
1066+
result = attr;
1067+
return success();
1068+
}
1069+
attrTypeReader.addDeferredParsing(index);
1070+
return failure();
1071+
}
1072+
return attrTypeReader.readAttribute(index, result, depth + 1);
10031073
}
10041074
LogicalResult readOptionalAttribute(Attribute &result) override {
10051075
return attrTypeReader.parseOptionalAttribute(reader, result);
10061076
}
10071077
LogicalResult readType(Type &result) override {
1008-
return attrTypeReader.parseType(reader, result);
1078+
uint64_t index;
1079+
if (failed(reader.parseVarInt(index)))
1080+
return failure();
1081+
if (depth > maxAttrTypeDepth) {
1082+
if (Type type = attrTypeReader.getTypeOrSentinel(index)) {
1083+
result = type;
1084+
return success();
1085+
}
1086+
attrTypeReader.addDeferredParsing(index);
1087+
return failure();
1088+
}
1089+
return attrTypeReader.readType(index, result, depth + 1);
10091090
}
10101091

10111092
FailureOr<AsmDialectResourceHandle> readResourceHandle() override {
@@ -1095,6 +1176,7 @@ class DialectReader : public DialectBytecodeReader {
10951176
const llvm::StringMap<BytecodeDialect *> &dialectsMap;
10961177
EncodingReader &reader;
10971178
uint64_t &bytecodeVersion;
1179+
uint64_t depth;
10981180
};
10991181

11001182
/// Wraps the properties section and handles reading properties out of it.
@@ -1239,68 +1321,110 @@ LogicalResult AttrTypeReader::initialize(
12391321

12401322
template <typename T>
12411323
T AttrTypeReader::resolveEntry(SmallVectorImpl<Entry<T>> &entries, size_t index,
1242-
StringRef entryType) {
1324+
StringRef entryType, uint64_t depth) {
12431325
if (index >= entries.size()) {
12441326
emitError(fileLoc) << "invalid " << entryType << " index: " << index;
12451327
return {};
12461328
}
12471329

1248-
// If the entry has already been resolved, there is nothing left to do.
1249-
Entry<T> &entry = entries[index];
1250-
if (entry.entry)
1251-
return entry.entry;
1330+
// Fast path: Try direct parsing without worklist overhead. This handles the
1331+
// common case where there are no deferred dependencies.
1332+
assert(deferredWorklist.empty());
1333+
T result;
1334+
if (succeeded(readEntry(entries, index, result, entryType, depth))) {
1335+
assert(deferredWorklist.empty());
1336+
return result;
1337+
}
1338+
if (deferredWorklist.empty()) {
1339+
// Failed with no deferred entries is error.
1340+
return T();
1341+
}
12521342

1253-
// Parse the entry.
1254-
EncodingReader reader(entry.data, fileLoc);
1343+
// Slow path: Use worklist to handle deferred dependencies. Use a deque to
1344+
// iteratively resolve entries with dependencies.
1345+
// - Pop from front to process
1346+
// - Push new dependencies to front (depth-first)
1347+
// - Move failed entries to back (retry after dependencies)
1348+
std::deque<size_t> worklist;
1349+
llvm::DenseSet<size_t> inWorklist;
12551350

1256-
// Parse based on how the entry was encoded.
1257-
if (entry.hasCustomEncoding) {
1258-
if (failed(parseCustomEntry(entry, reader, entryType)))
1259-
return T();
1260-
} else if (failed(parseAsmEntry(entry.entry, reader, entryType))) {
1261-
return T();
1351+
// Add the original index and any dependencies from the fast path attempt.
1352+
worklist.push_back(index);
1353+
inWorklist.insert(index);
1354+
for (uint64_t idx : llvm::reverse(deferredWorklist)) {
1355+
if (inWorklist.insert(idx).second)
1356+
worklist.push_front(idx);
12621357
}
12631358

1264-
if (!reader.empty()) {
1265-
reader.emitError("unexpected trailing bytes after " + entryType + " entry");
1266-
return T();
1359+
while (!worklist.empty()) {
1360+
size_t currentIndex = worklist.front();
1361+
worklist.pop_front();
1362+
1363+
// Clear the deferred worklist before parsing to capture any new entries.
1364+
deferredWorklist.clear();
1365+
1366+
T result;
1367+
if (succeeded(readEntry(entries, currentIndex, result, entryType, depth))) {
1368+
inWorklist.erase(currentIndex);
1369+
continue;
1370+
}
1371+
1372+
if (deferredWorklist.empty()) {
1373+
// Parsing failed with no deferred entries which implies an error.
1374+
return T();
1375+
}
1376+
1377+
// Move this entry to the back to retry after dependencies.
1378+
worklist.push_back(currentIndex);
1379+
1380+
// Add dependencies to the front (in reverse so they maintain order).
1381+
for (uint64_t idx : llvm::reverse(deferredWorklist)) {
1382+
if (inWorklist.insert(idx).second)
1383+
worklist.push_front(idx);
1384+
}
1385+
deferredWorklist.clear();
12671386
}
1268-
return entry.entry;
1387+
return entries[index].entry;
12691388
}
12701389

12711390
template <typename T>
1272-
LogicalResult AttrTypeReader::parseAsmEntry(T &result, EncodingReader &reader,
1273-
StringRef entryType) {
1274-
StringRef asmStr;
1275-
if (failed(reader.parseNullTerminatedString(asmStr)))
1276-
return failure();
1391+
LogicalResult AttrTypeReader::readEntry(SmallVectorImpl<Entry<T>> &entries,
1392+
uint64_t index, T &result,
1393+
StringRef entryType, uint64_t depth) {
1394+
if (index >= entries.size())
1395+
return emitError(fileLoc) << "invalid " << entryType << " index: " << index;
12771396

1278-
// Invoke the MLIR assembly parser to parse the entry text.
1279-
size_t numRead = 0;
1280-
MLIRContext *context = fileLoc->getContext();
1281-
if constexpr (std::is_same_v<T, Type>)
1282-
result =
1283-
::parseType(asmStr, context, &numRead, /*isKnownNullTerminated=*/true);
1284-
else
1285-
result = ::parseAttribute(asmStr, context, Type(), &numRead,
1286-
/*isKnownNullTerminated=*/true);
1287-
if (!result)
1397+
// If the entry has already been resolved, return it.
1398+
Entry<T> &entry = entries[index];
1399+
if (entry.entry) {
1400+
result = entry.entry;
1401+
return success();
1402+
}
1403+
1404+
// If the entry hasn't been resolved, try to parse it.
1405+
EncodingReader reader(entry.data, fileLoc);
1406+
LogicalResult parseResult =
1407+
entry.hasCustomEncoding
1408+
? parseCustomEntry(entry, reader, entryType, index, depth)
1409+
: parseAsmEntry(entry.entry, reader, entryType);
1410+
if (failed(parseResult))
12881411
return failure();
12891412

1290-
// Ensure there weren't dangling characters after the entry.
1291-
if (numRead != asmStr.size()) {
1292-
return reader.emitError("trailing characters found after ", entryType,
1293-
" assembly format: ", asmStr.drop_front(numRead));
1294-
}
1413+
if (!reader.empty())
1414+
return reader.emitError("unexpected trailing bytes after " + entryType +
1415+
" entry");
1416+
1417+
result = entry.entry;
12951418
return success();
12961419
}
12971420

12981421
template <typename T>
12991422
LogicalResult AttrTypeReader::parseCustomEntry(Entry<T> &entry,
13001423
EncodingReader &reader,
1301-
StringRef entryType) {
1424+
StringRef entryType,
1425+
uint64_t index, uint64_t depth) {
13021426
DialectReader dialectReader(*this, stringReader, resourceReader, dialectsMap,
1303-
reader, bytecodeVersion);
1427+
reader, bytecodeVersion, depth);
13041428
if (failed(entry.dialect->load(dialectReader, fileLoc.getContext())))
13051429
return failure();
13061430

@@ -1350,6 +1474,33 @@ LogicalResult AttrTypeReader::parseCustomEntry(Entry<T> &entry,
13501474
return success(!!entry.entry);
13511475
}
13521476

1477+
template <typename T>
1478+
LogicalResult AttrTypeReader::parseAsmEntry(T &result, EncodingReader &reader,
1479+
StringRef entryType) {
1480+
StringRef asmStr;
1481+
if (failed(reader.parseNullTerminatedString(asmStr)))
1482+
return failure();
1483+
1484+
// Invoke the MLIR assembly parser to parse the entry text.
1485+
size_t numRead = 0;
1486+
MLIRContext *context = fileLoc->getContext();
1487+
if constexpr (std::is_same_v<T, Type>)
1488+
result =
1489+
::parseType(asmStr, context, &numRead, /*isKnownNullTerminated=*/true);
1490+
else
1491+
result = ::parseAttribute(asmStr, context, Type(), &numRead,
1492+
/*isKnownNullTerminated=*/true);
1493+
if (!result)
1494+
return failure();
1495+
1496+
// Ensure there weren't dangling characters after the entry.
1497+
if (numRead != asmStr.size()) {
1498+
return reader.emitError("trailing characters found after ", entryType,
1499+
" assembly format: ", asmStr.drop_front(numRead));
1500+
}
1501+
return success();
1502+
}
1503+
13531504
//===----------------------------------------------------------------------===//
13541505
// Bytecode Reader
13551506
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)