Skip to content

Commit d2ad942

Browse files
committed
use fallback on parsing failure
1 parent f88ae0d commit d2ad942

File tree

1 file changed

+58
-28
lines changed

1 file changed

+58
-28
lines changed

mlir/lib/Bytecode/Reader/BytecodeReader.cpp

Lines changed: 58 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,16 @@ class EncodingReader {
293293

294294
Location getLoc() const { return fileLoc; }
295295

296+
/// Snapshot the location of the BytecodeReader so that parsing can be rewound
297+
/// if needed.
298+
struct Snapshot {
299+
EncodingReader &reader;
300+
const uint8_t *dataIt;
301+
302+
Snapshot(EncodingReader &reader) : reader(reader), dataIt(reader.dataIt) {}
303+
void rewind() { reader.dataIt = dataIt; }
304+
};
305+
296306
private:
297307
/// Parse a variable length encoded integer from the byte stream. This method
298308
/// is a fallback when the number of bytes used to encode the value is greater
@@ -1417,7 +1427,8 @@ class mlir::BytecodeReader::Impl {
14171427
/// `wasRegistered` flag that indicates if the bytecode was produced by a
14181428
/// context where opName was registered.
14191429
FailureOr<BytecodeOperationName *>
1420-
parseOpName(EncodingReader &reader, std::optional<bool> &wasRegistered);
1430+
parseOpName(EncodingReader &reader, std::optional<bool> &wasRegistered,
1431+
bool useDialectFallback);
14211432

14221433
//===--------------------------------------------------------------------===//
14231434
// Attribute/Type Section
@@ -1482,7 +1493,8 @@ class mlir::BytecodeReader::Impl {
14821493
RegionReadState &readState);
14831494
FailureOr<Operation *> parseOpWithoutRegions(EncodingReader &reader,
14841495
RegionReadState &readState,
1485-
bool &isIsolatedFromAbove);
1496+
bool &isIsolatedFromAbove,
1497+
bool useDialectFallback);
14861498

14871499
LogicalResult parseRegion(RegionReadState &readState);
14881500
LogicalResult parseBlockHeader(EncodingReader &reader,
@@ -1851,14 +1863,18 @@ BytecodeReader::Impl::parseDialectSection(ArrayRef<uint8_t> sectionData) {
18511863

18521864
FailureOr<BytecodeOperationName *>
18531865
BytecodeReader::Impl::parseOpName(EncodingReader &reader,
1854-
std::optional<bool> &wasRegistered) {
1866+
std::optional<bool> &wasRegistered,
1867+
bool useDialectFallback) {
18551868
BytecodeOperationName *opName = nullptr;
18561869
if (failed(parseEntry(reader, opNames, opName, "operation name")))
18571870
return failure();
18581871
wasRegistered = opName->wasRegistered;
18591872
// Check to see if this operation name has already been resolved. If we
18601873
// haven't, load the dialect and build the operation name.
1861-
if (!opName->opName) {
1874+
// If `useDialectFallback`, it's likely that parsing previously failed. We'll
1875+
// need to reset any previously resolved OperationName with that of the
1876+
// fallback op.
1877+
if (!opName->opName || useDialectFallback) {
18621878
// If the opName is empty, this is because we use to accept names such as
18631879
// `foo` without any `.` separator. We shouldn't tolerate this in textual
18641880
// format anymore but for now we'll be backward compatible. This can only
@@ -1872,21 +1888,19 @@ BytecodeReader::Impl::parseOpName(EncodingReader &reader,
18721888
if (failed(opName->dialect->load(dialectReader, getContext())))
18731889
return failure();
18741890

1875-
opName->opName.emplace((opName->dialect->name + "." + opName->name).str(),
1876-
getContext());
1877-
1878-
// If the op is unregistered now, but was not marked as unregistered, try
1879-
// to parse it as a fallback op if the dialect's bytecode interface
1880-
// specifies one.
1881-
// We don't treat this condition as an error because we may still be able
1882-
// to parse the op as an unregistered op if it doesn't use custom
1883-
// properties encoding.
1884-
if (wasRegistered && !opName->opName->isRegistered()) {
1885-
if (auto fallbackOp =
1886-
opName->dialect->interface->getFallbackOperationName();
1887-
succeeded(fallbackOp)) {
1888-
opName->opName.emplace(*fallbackOp);
1889-
}
1891+
if (useDialectFallback) {
1892+
auto fallbackOp =
1893+
opName->dialect->interface->getFallbackOperationName();
1894+
1895+
// If the dialect doesn't have a fallback operation, we can't parse as
1896+
// instructed.
1897+
if (failed(fallbackOp))
1898+
return failure();
1899+
1900+
opName->opName.emplace(*fallbackOp);
1901+
} else {
1902+
opName->opName.emplace(
1903+
(opName->dialect->name + "." + opName->name).str(), getContext());
18901904
}
18911905
}
18921906
}
@@ -2164,10 +2178,27 @@ BytecodeReader::Impl::parseRegions(std::vector<RegionReadState> &regionStack,
21642178
// Read in the next operation. We don't read its regions directly, we
21652179
// handle those afterwards as necessary.
21662180
bool isIsolatedFromAbove = false;
2167-
FailureOr<Operation *> op =
2168-
parseOpWithoutRegions(reader, readState, isIsolatedFromAbove);
2169-
if (failed(op))
2170-
return failure();
2181+
FailureOr<Operation *> op;
2182+
2183+
// Parse the bytecode.
2184+
{
2185+
EncodingReader::Snapshot snapshot(reader);
2186+
op = parseOpWithoutRegions(reader, readState, isIsolatedFromAbove,
2187+
/*useDialectFallback=*/false);
2188+
2189+
// If reading fails, try parsing the op again as a dialect fallback
2190+
// op (if supported).
2191+
if (failed(op)) {
2192+
snapshot.rewind();
2193+
op = parseOpWithoutRegions(reader, readState, isIsolatedFromAbove,
2194+
/*useDialectFallback=*/true);
2195+
}
2196+
2197+
// If the dialect doesn't have a fallback op, or parsing as a fallback
2198+
// op fails, we can no longer continue.
2199+
if (failed(op))
2200+
return failure();
2201+
}
21712202

21722203
// If the op has regions, add it to the stack for processing and return:
21732204
// we stop the processing of the current region and resume it after the
@@ -2229,14 +2260,13 @@ BytecodeReader::Impl::parseRegions(std::vector<RegionReadState> &regionStack,
22292260
return success();
22302261
}
22312262

2232-
FailureOr<Operation *>
2233-
BytecodeReader::Impl::parseOpWithoutRegions(EncodingReader &reader,
2234-
RegionReadState &readState,
2235-
bool &isIsolatedFromAbove) {
2263+
FailureOr<Operation *> BytecodeReader::Impl::parseOpWithoutRegions(
2264+
EncodingReader &reader, RegionReadState &readState,
2265+
bool &isIsolatedFromAbove, bool useDialectFallback) {
22362266
// Parse the name of the operation.
22372267
std::optional<bool> wasRegistered;
22382268
FailureOr<BytecodeOperationName *> bytecodeOp =
2239-
parseOpName(reader, wasRegistered);
2269+
parseOpName(reader, wasRegistered, useDialectFallback);
22402270
if (failed(bytecodeOp))
22412271
return failure();
22422272
auto opName = (*bytecodeOp)->opName;

0 commit comments

Comments
 (0)