@@ -293,6 +293,16 @@ class EncodingReader {
293293
294294 Location getLoc () const { return fileLoc; }
295295
296+ // / Snapshot the location of the BytecodeReader so that parsing can be rewound
297+ // / if needed.
298+ struct Snapshot {
299+ EncodingReader &reader;
300+ const uint8_t *dataIt;
301+
302+ Snapshot (EncodingReader &reader) : reader(reader), dataIt(reader.dataIt) {}
303+ void rewind () { reader.dataIt = dataIt; }
304+ };
305+
296306private:
297307 // / Parse a variable length encoded integer from the byte stream. This method
298308 // / is a fallback when the number of bytes used to encode the value is greater
@@ -1417,7 +1427,8 @@ class mlir::BytecodeReader::Impl {
14171427 // / `wasRegistered` flag that indicates if the bytecode was produced by a
14181428 // / context where opName was registered.
14191429 FailureOr<BytecodeOperationName *>
1420- parseOpName (EncodingReader &reader, std::optional<bool > &wasRegistered);
1430+ parseOpName (EncodingReader &reader, std::optional<bool > &wasRegistered,
1431+ bool useDialectFallback);
14211432
14221433 // ===--------------------------------------------------------------------===//
14231434 // Attribute/Type Section
@@ -1482,7 +1493,8 @@ class mlir::BytecodeReader::Impl {
14821493 RegionReadState &readState);
14831494 FailureOr<Operation *> parseOpWithoutRegions (EncodingReader &reader,
14841495 RegionReadState &readState,
1485- bool &isIsolatedFromAbove);
1496+ bool &isIsolatedFromAbove,
1497+ bool useDialectFallback);
14861498
14871499 LogicalResult parseRegion (RegionReadState &readState);
14881500 LogicalResult parseBlockHeader (EncodingReader &reader,
@@ -1851,14 +1863,18 @@ BytecodeReader::Impl::parseDialectSection(ArrayRef<uint8_t> sectionData) {
18511863
18521864FailureOr<BytecodeOperationName *>
18531865BytecodeReader::Impl::parseOpName (EncodingReader &reader,
1854- std::optional<bool > &wasRegistered) {
1866+ std::optional<bool > &wasRegistered,
1867+ bool useDialectFallback) {
18551868 BytecodeOperationName *opName = nullptr ;
18561869 if (failed (parseEntry (reader, opNames, opName, " operation name" )))
18571870 return failure ();
18581871 wasRegistered = opName->wasRegistered ;
18591872 // Check to see if this operation name has already been resolved. If we
18601873 // haven't, load the dialect and build the operation name.
1861- if (!opName->opName ) {
1874+ // If `useDialectFallback`, it's likely that parsing previously failed. We'll
1875+ // need to reset any previously resolved OperationName with that of the
1876+ // fallback op.
1877+ if (!opName->opName || useDialectFallback) {
18621878 // If the opName is empty, this is because we use to accept names such as
18631879 // `foo` without any `.` separator. We shouldn't tolerate this in textual
18641880 // format anymore but for now we'll be backward compatible. This can only
@@ -1872,21 +1888,19 @@ BytecodeReader::Impl::parseOpName(EncodingReader &reader,
18721888 if (failed (opName->dialect ->load (dialectReader, getContext ())))
18731889 return failure ();
18741890
1875- opName->opName .emplace ((opName->dialect ->name + " ." + opName->name ).str (),
1876- getContext ());
1877-
1878- // If the op is unregistered now, but was not marked as unregistered, try
1879- // to parse it as a fallback op if the dialect's bytecode interface
1880- // specifies one.
1881- // We don't treat this condition as an error because we may still be able
1882- // to parse the op as an unregistered op if it doesn't use custom
1883- // properties encoding.
1884- if (wasRegistered && !opName->opName ->isRegistered ()) {
1885- if (auto fallbackOp =
1886- opName->dialect ->interface ->getFallbackOperationName ();
1887- succeeded (fallbackOp)) {
1888- opName->opName .emplace (*fallbackOp);
1889- }
1891+ if (useDialectFallback) {
1892+ auto fallbackOp =
1893+ opName->dialect ->interface ->getFallbackOperationName ();
1894+
1895+ // If the dialect doesn't have a fallback operation, we can't parse as
1896+ // instructed.
1897+ if (failed (fallbackOp))
1898+ return failure ();
1899+
1900+ opName->opName .emplace (*fallbackOp);
1901+ } else {
1902+ opName->opName .emplace (
1903+ (opName->dialect ->name + " ." + opName->name ).str (), getContext ());
18901904 }
18911905 }
18921906 }
@@ -2164,10 +2178,27 @@ BytecodeReader::Impl::parseRegions(std::vector<RegionReadState> ®ionStack,
21642178 // Read in the next operation. We don't read its regions directly, we
21652179 // handle those afterwards as necessary.
21662180 bool isIsolatedFromAbove = false ;
2167- FailureOr<Operation *> op =
2168- parseOpWithoutRegions (reader, readState, isIsolatedFromAbove);
2169- if (failed (op))
2170- return failure ();
2181+ FailureOr<Operation *> op;
2182+
2183+ // Parse the bytecode.
2184+ {
2185+ EncodingReader::Snapshot snapshot (reader);
2186+ op = parseOpWithoutRegions (reader, readState, isIsolatedFromAbove,
2187+ /* useDialectFallback=*/ false );
2188+
2189+ // If reading fails, try parsing the op again as a dialect fallback
2190+ // op (if supported).
2191+ if (failed (op)) {
2192+ snapshot.rewind ();
2193+ op = parseOpWithoutRegions (reader, readState, isIsolatedFromAbove,
2194+ /* useDialectFallback=*/ true );
2195+ }
2196+
2197+ // If the dialect doesn't have a fallback op, or parsing as a fallback
2198+ // op fails, we can no longer continue.
2199+ if (failed (op))
2200+ return failure ();
2201+ }
21712202
21722203 // If the op has regions, add it to the stack for processing and return:
21732204 // we stop the processing of the current region and resume it after the
@@ -2229,14 +2260,13 @@ BytecodeReader::Impl::parseRegions(std::vector<RegionReadState> ®ionStack,
22292260 return success ();
22302261}
22312262
2232- FailureOr<Operation *>
2233- BytecodeReader::Impl::parseOpWithoutRegions (EncodingReader &reader,
2234- RegionReadState &readState,
2235- bool &isIsolatedFromAbove) {
2263+ FailureOr<Operation *> BytecodeReader::Impl::parseOpWithoutRegions (
2264+ EncodingReader &reader, RegionReadState &readState,
2265+ bool &isIsolatedFromAbove, bool useDialectFallback) {
22362266 // Parse the name of the operation.
22372267 std::optional<bool > wasRegistered;
22382268 FailureOr<BytecodeOperationName *> bytecodeOp =
2239- parseOpName (reader, wasRegistered);
2269+ parseOpName (reader, wasRegistered, useDialectFallback );
22402270 if (failed (bytecodeOp))
22412271 return failure ();
22422272 auto opName = (*bytecodeOp)->opName ;
0 commit comments