-
Notifications
You must be signed in to change notification settings - Fork 15.1k
[mlir][RFC] Bytecode: op fallback path #129784
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
@llvm/pr-subscribers-mlir-core Author: Nikhil Kalra (nikalra) ChangesBytecode dialect versioning today is built around the premise that MLIR bytecode is used as a long-term storage format: specifically, that the producing process has knowledge about the version of dialect to target, and is able to downgrade the dialect prior to serialization. If it cannot, this is an error at serialization time. This poses a problem if bytecode is used as an exchange format between distributed compilers with different versions of the dialect. If a given module references an operation that the receiving process doesn't know about, or utilizes a newer version of the op that the receiving process doesn't have support for, the receiving process will fail to parse the bytecode in its entirety, regardless of whether or not the failing operation is relevant to the receiving process. This proposal adds a fallback mechanism for dialects to construct an operation that maintains the semantics of the unknown operation, while supporting roundtrip bytecode serialization in a bitwise exact manner. Specifically, the flow changes to the following:
Existing options considered:
Patch is 23.99 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/129784.diff 10 Files Affected:
diff --git a/mlir/include/mlir/Bytecode/BytecodeImplementation.h b/mlir/include/mlir/Bytecode/BytecodeImplementation.h
index 0ddc531073e23..36fa010f7e11e 100644
--- a/mlir/include/mlir/Bytecode/BytecodeImplementation.h
+++ b/mlir/include/mlir/Bytecode/BytecodeImplementation.h
@@ -21,6 +21,7 @@
#include "mlir/IR/OpImplementation.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Twine.h"
+#include "llvm/Support/ErrorHandling.h"
namespace mlir {
//===--------------------------------------------------------------------===//
@@ -445,6 +446,14 @@ class BytecodeDialectInterface
return Type();
}
+ /// Fall back to an operation of this type if parsing an op from bytecode
+ /// fails for any reason. This can be used to handle new ops emitted from a
+ /// different version of the dialect, that cannot be read by an older version
+ /// of the dialect.
+ virtual FailureOr<OperationName> getFallbackOperationName() const {
+ return failure();
+ }
+
//===--------------------------------------------------------------------===//
// Writing
//===--------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Bytecode/BytecodeOpInterface.td b/mlir/include/mlir/Bytecode/BytecodeOpInterface.td
index 54fb03e34ec51..87ba27ad6ac27 100644
--- a/mlir/include/mlir/Bytecode/BytecodeOpInterface.td
+++ b/mlir/include/mlir/Bytecode/BytecodeOpInterface.td
@@ -40,4 +40,28 @@ def BytecodeOpInterface : OpInterface<"BytecodeOpInterface"> {
];
}
+// `FallbackBytecodeOpInterface`
+def FallbackBytecodeOpInterface : OpInterface<"FallbackBytecodeOpInterface"> {
+ let description = [{
+ This interface allows fallback operations sideband access to the
+ original operation's intrinsic details.
+ }];
+ let cppNamespace = "::mlir";
+
+ let methods = [
+ StaticInterfaceMethod<[{
+ Set the original name for this operation from the bytecode.
+ }],
+ "void", "setOriginalOperationName", (ins
+ "const ::mlir::Twine&":$opName,
+ "::mlir::OperationState &":$state)
+ >,
+ InterfaceMethod<[{
+ Get the original name for this operation from the bytecode.
+ }],
+ "::mlir::StringRef", "getOriginalOperationName", (ins)
+ >
+ ];
+}
+
#endif // MLIR_BYTECODE_BYTECODEOPINTERFACES
diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
index 1204f1c069b1e..64fcc4ed7c6dc 100644
--- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
+++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
@@ -21,6 +21,7 @@
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
+#include "llvm/ADT/Twine.h"
#include "llvm/Support/Endian.h"
#include "llvm/Support/MemoryBufferRef.h"
#include "llvm/Support/SourceMgr.h"
@@ -292,6 +293,16 @@ class EncodingReader {
Location getLoc() const { return fileLoc; }
+ /// Snapshot the location of the BytecodeReader so that parsing can be rewound
+ /// if needed.
+ struct Snapshot {
+ EncodingReader &reader;
+ const uint8_t *dataIt;
+
+ Snapshot(EncodingReader &reader) : reader(reader), dataIt(reader.dataIt) {}
+ void rewind() { reader.dataIt = dataIt; }
+ };
+
private:
/// Parse a variable length encoded integer from the byte stream. This method
/// is a fallback when the number of bytes used to encode the value is greater
@@ -1410,8 +1421,9 @@ class mlir::BytecodeReader::Impl {
/// Parse an operation name reference using the given reader, and set the
/// `wasRegistered` flag that indicates if the bytecode was produced by a
/// context where opName was registered.
- FailureOr<OperationName> parseOpName(EncodingReader &reader,
- std::optional<bool> &wasRegistered);
+ FailureOr<BytecodeOperationName *>
+ parseOpName(EncodingReader &reader, std::optional<bool> &wasRegistered,
+ bool useDialectFallback);
//===--------------------------------------------------------------------===//
// Attribute/Type Section
@@ -1476,7 +1488,8 @@ class mlir::BytecodeReader::Impl {
RegionReadState &readState);
FailureOr<Operation *> parseOpWithoutRegions(EncodingReader &reader,
RegionReadState &readState,
- bool &isIsolatedFromAbove);
+ bool &isIsolatedFromAbove,
+ bool useDialectFallback);
LogicalResult parseRegion(RegionReadState &readState);
LogicalResult parseBlockHeader(EncodingReader &reader,
@@ -1506,7 +1519,7 @@ class mlir::BytecodeReader::Impl {
UseListOrderStorage(bool isIndexPairEncoding,
SmallVector<unsigned, 4> &&indices)
: indices(std::move(indices)),
- isIndexPairEncoding(isIndexPairEncoding){};
+ isIndexPairEncoding(isIndexPairEncoding) {};
/// The vector containing the information required to reorder the
/// use-list of a value.
SmallVector<unsigned, 4> indices;
@@ -1843,16 +1856,20 @@ BytecodeReader::Impl::parseDialectSection(ArrayRef<uint8_t> sectionData) {
return success();
}
-FailureOr<OperationName>
+FailureOr<BytecodeOperationName *>
BytecodeReader::Impl::parseOpName(EncodingReader &reader,
- std::optional<bool> &wasRegistered) {
+ std::optional<bool> &wasRegistered,
+ bool useDialectFallback) {
BytecodeOperationName *opName = nullptr;
if (failed(parseEntry(reader, opNames, opName, "operation name")))
return failure();
wasRegistered = opName->wasRegistered;
// Check to see if this operation name has already been resolved. If we
// haven't, load the dialect and build the operation name.
- if (!opName->opName) {
+ // If `useDialectFallback`, it's likely that parsing previously failed. We'll
+ // need to reset any previously resolved OperationName with that of the
+ // fallback op.
+ if (!opName->opName || useDialectFallback) {
// If the opName is empty, this is because we use to accept names such as
// `foo` without any `.` separator. We shouldn't tolerate this in textual
// format anymore but for now we'll be backward compatible. This can only
@@ -1865,11 +1882,26 @@ BytecodeReader::Impl::parseOpName(EncodingReader &reader,
dialectsMap, reader, version);
if (failed(opName->dialect->load(dialectReader, getContext())))
return failure();
- opName->opName.emplace((opName->dialect->name + "." + opName->name).str(),
- getContext());
+
+ const BytecodeDialectInterface *dialectIface = opName->dialect->interface;
+ if (useDialectFallback) {
+ FailureOr<OperationName> fallbackOp =
+ dialectIface ? dialectIface->getFallbackOperationName()
+ : FailureOr<OperationName>{};
+
+ // If the dialect doesn't have a fallback operation, we can't parse as
+ // instructed.
+ if (failed(fallbackOp))
+ return failure();
+
+ opName->opName.emplace(*fallbackOp);
+ } else {
+ opName->opName.emplace(
+ (opName->dialect->name + "." + opName->name).str(), getContext());
+ }
}
}
- return *opName->opName;
+ return opName;
}
//===----------------------------------------------------------------------===//
@@ -2143,10 +2175,30 @@ BytecodeReader::Impl::parseRegions(std::vector<RegionReadState> ®ionStack,
// Read in the next operation. We don't read its regions directly, we
// handle those afterwards as necessary.
bool isIsolatedFromAbove = false;
- FailureOr<Operation *> op =
- parseOpWithoutRegions(reader, readState, isIsolatedFromAbove);
- if (failed(op))
- return failure();
+ FailureOr<Operation *> op;
+
+ // Parse the bytecode.
+ {
+ // If the op is registered (and serialized in a compatible manner), or
+ // unregistered but uses standard properties encoding, parsing without
+ // going through the fallback path should work.
+ EncodingReader::Snapshot snapshot(reader);
+ op = parseOpWithoutRegions(reader, readState, isIsolatedFromAbove,
+ /*useDialectFallback=*/false);
+
+ // If reading fails, try parsing the op again as a dialect fallback
+ // op (if supported).
+ if (failed(op)) {
+ snapshot.rewind();
+ op = parseOpWithoutRegions(reader, readState, isIsolatedFromAbove,
+ /*useDialectFallback=*/true);
+ }
+
+ // If the dialect doesn't have a fallback op, or parsing as a fallback
+ // op fails, we can no longer continue.
+ if (failed(op))
+ return failure();
+ }
// If the op has regions, add it to the stack for processing and return:
// we stop the processing of the current region and resume it after the
@@ -2208,14 +2260,17 @@ BytecodeReader::Impl::parseRegions(std::vector<RegionReadState> ®ionStack,
return success();
}
-FailureOr<Operation *>
-BytecodeReader::Impl::parseOpWithoutRegions(EncodingReader &reader,
- RegionReadState &readState,
- bool &isIsolatedFromAbove) {
+FailureOr<Operation *> BytecodeReader::Impl::parseOpWithoutRegions(
+ EncodingReader &reader, RegionReadState &readState,
+ bool &isIsolatedFromAbove, bool useDialectFallback) {
// Parse the name of the operation.
std::optional<bool> wasRegistered;
- FailureOr<OperationName> opName = parseOpName(reader, wasRegistered);
- if (failed(opName))
+ FailureOr<BytecodeOperationName *> bytecodeOp =
+ parseOpName(reader, wasRegistered, useDialectFallback);
+ if (failed(bytecodeOp))
+ return failure();
+ auto opName = (*bytecodeOp)->opName;
+ if (!opName)
return failure();
// Parse the operation mask, which indicates which components of the operation
@@ -2232,6 +2287,12 @@ BytecodeReader::Impl::parseOpWithoutRegions(EncodingReader &reader,
// With the location and name resolved, we can start building the operation
// state.
OperationState opState(opLoc, *opName);
+ // If this is a fallback op, provide the original name of the operation.
+ if (auto *iface = opName->getInterface<FallbackBytecodeOpInterface>()) {
+ const Twine originalName =
+ opName->getDialect()->getNamespace() + "." + (*bytecodeOp)->name;
+ iface->setOriginalOperationName(originalName, opState);
+ }
// Parse the attributes of the operation.
if (opMask & bytecode::OpEncodingMask::kHasAttrs) {
diff --git a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
index cc5aaed416512..526dfb3654492 100644
--- a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
+++ b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
@@ -841,12 +841,12 @@ void BytecodeWriter::writeDialectSection(EncodingEmitter &emitter) {
// Emit the referenced operation names grouped by dialect.
auto emitOpName = [&](OpNameNumbering &name) {
+ const bool isKnownOp = name.isOpaqueEntry || name.name.isRegistered();
size_t stringId = stringSection.insert(name.name.stripDialect());
if (config.bytecodeVersion < bytecode::kNativePropertiesEncoding)
dialectEmitter.emitVarInt(stringId, "dialect op name");
else
- dialectEmitter.emitVarIntWithFlag(stringId, name.name.isRegistered(),
- "dialect op name");
+ dialectEmitter.emitVarIntWithFlag(stringId, isKnownOp, "dialect op name");
};
writeDialectGrouping(dialectEmitter, numberingState.getOpNames(), emitOpName);
@@ -984,7 +984,14 @@ LogicalResult BytecodeWriter::writeBlock(EncodingEmitter &emitter,
}
LogicalResult BytecodeWriter::writeOp(EncodingEmitter &emitter, Operation *op) {
- emitter.emitVarInt(numberingState.getNumber(op->getName()), "op name ID");
+ OperationName opName = op->getName();
+ // For fallback ops, create a new operation name referencing the original op
+ // instead.
+ if (auto fallback = dyn_cast<FallbackBytecodeOpInterface>(op))
+ opName =
+ OperationName(fallback.getOriginalOperationName(), op->getContext());
+
+ emitter.emitVarInt(numberingState.getNumber(opName), "op name ID");
// Emit a mask for the operation components. We need to fill this in later
// (when we actually know what needs to be emitted), so emit a placeholder for
diff --git a/mlir/lib/Bytecode/Writer/IRNumbering.cpp b/mlir/lib/Bytecode/Writer/IRNumbering.cpp
index 1bc02e1721573..60bc6bd5170c5 100644
--- a/mlir/lib/Bytecode/Writer/IRNumbering.cpp
+++ b/mlir/lib/Bytecode/Writer/IRNumbering.cpp
@@ -419,7 +419,16 @@ void IRNumberingState::number(Region ®ion) {
void IRNumberingState::number(Operation &op) {
// Number the components of an operation that won't be numbered elsewhere
// (e.g. we don't number operands, regions, or successors here).
- number(op.getName());
+
+ // For fallback ops, create a new OperationName referencing the original op
+ // instead.
+ if (auto fallback = dyn_cast<FallbackBytecodeOpInterface>(op)) {
+ OperationName opName(fallback.getOriginalOperationName(), op.getContext());
+ number(opName, /*isOpaque=*/true);
+ } else {
+ number(op.getName(), /*isOpaque=*/false);
+ }
+
for (OpResult result : op.getResults()) {
valueIDs.try_emplace(result, nextValueID++);
number(result.getType());
@@ -457,7 +466,7 @@ void IRNumberingState::number(Operation &op) {
number(op.getLoc());
}
-void IRNumberingState::number(OperationName opName) {
+void IRNumberingState::number(OperationName opName, bool isOpaque) {
OpNameNumbering *&numbering = opNames[opName];
if (numbering) {
++numbering->refCount;
@@ -469,8 +478,8 @@ void IRNumberingState::number(OperationName opName) {
else
dialectNumber = &numberDialect(opName.getDialectNamespace());
- numbering =
- new (opNameAllocator.Allocate()) OpNameNumbering(dialectNumber, opName);
+ numbering = new (opNameAllocator.Allocate())
+ OpNameNumbering(dialectNumber, opName, isOpaque);
orderedOpNames.push_back(numbering);
}
diff --git a/mlir/lib/Bytecode/Writer/IRNumbering.h b/mlir/lib/Bytecode/Writer/IRNumbering.h
index 9b7ac0d3688e3..033b3771b46a3 100644
--- a/mlir/lib/Bytecode/Writer/IRNumbering.h
+++ b/mlir/lib/Bytecode/Writer/IRNumbering.h
@@ -63,8 +63,8 @@ struct TypeNumbering : public AttrTypeNumbering {
/// This class represents the numbering entry of an operation name.
struct OpNameNumbering {
- OpNameNumbering(DialectNumbering *dialect, OperationName name)
- : dialect(dialect), name(name) {}
+ OpNameNumbering(DialectNumbering *dialect, OperationName name, bool isOpaque)
+ : dialect(dialect), name(name), isOpaqueEntry(isOpaque) {}
/// The dialect of this value.
DialectNumbering *dialect;
@@ -72,6 +72,9 @@ struct OpNameNumbering {
/// The concrete name.
OperationName name;
+ /// This entry represents an opaque operation entry.
+ bool isOpaqueEntry = false;
+
/// The number assigned to this name.
unsigned number = 0;
@@ -210,7 +213,7 @@ class IRNumberingState {
/// Get the set desired bytecode version to emit.
int64_t getDesiredBytecodeVersion() const;
-
+
private:
/// This class is used to provide a fake dialect writer for numbering nested
/// attributes and types.
@@ -225,7 +228,7 @@ class IRNumberingState {
DialectNumbering &numberDialect(Dialect *dialect);
DialectNumbering &numberDialect(StringRef dialect);
void number(Operation &op);
- void number(OperationName opName);
+ void number(OperationName opName, bool isOpaque);
void number(Region ®ion);
void number(Type type);
diff --git a/mlir/test/Bytecode/versioning/versioning-fallback.mlir b/mlir/test/Bytecode/versioning/versioning-fallback.mlir
new file mode 100644
index 0000000000000..a078613360af6
--- /dev/null
+++ b/mlir/test/Bytecode/versioning/versioning-fallback.mlir
@@ -0,0 +1,13 @@
+// RUN: mlir-opt %s --emit-bytecode > %T/versioning-fallback.mlirbc
+"test.versionedD"() <{attribute = #test.attr_params<42, 24>}> : () -> ()
+
+// COM: check that versionedD was parsed as a fallback op.
+// RUN: mlir-opt %T/versioning-fallback.mlirbc | FileCheck %s --check-prefix=CHECK-PARSE
+// CHECK-PARSE: test.bytecode.fallback
+// CHECK-PARSE-SAME: opname = "test.versionedD"
+
+// COM: check that the bytecode roundtrip was successful
+// RUN: mlir-opt %T/versioning-fallback.mlirbc --verify-roundtrip
+
+// COM: check that the bytecode roundtrip is bitwise exact
+// RUN: mlir-opt %T/versioning-fallback.mlirbc --emit-bytecode | diff %T/versioning-fallback.mlirbc -
diff --git a/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp b/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
index 64add8cef3698..24e31ec44a85b 100644
--- a/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
@@ -8,6 +8,7 @@
#include "TestDialect.h"
#include "TestOps.h"
+#include "mlir/IR/OperationSupport.h"
#include "mlir/Interfaces/FoldInterfaces.h"
#include "mlir/Reducer/ReductionPatternInterface.h"
#include "mlir/Transforms/InliningUtils.h"
@@ -92,6 +93,11 @@ struct TestBytecodeDialectInterface : public BytecodeDialectInterface {
return Attribute();
}
+ FailureOr<OperationName> getFallbackOperationName() const final {
+ return OperationName(TestBytecodeFallbackOp::getOperationName(),
+ getContext());
+ }
+
// Emit a specific version of the dialect.
void writeVersion(DialectBytecodeWriter &writer) const final {
// Construct the current dialect version.
diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
index f6b8a0005f285..77428517f2b12 100644
--- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
+++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
@@ -8,10 +8,17 @@
#include "TestDialect.h"
#include "TestOps.h"
+#include "mlir/Bytecode/BytecodeImplementation.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Interfaces/FunctionImplementation.h"
#include "mlir/Interfaces/MemorySlotInterfaces.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/LogicalResult.h"
+#include <cstdint>
using namespace mlir;
using namespace test;
@@ -1230,6 +1237,106 @@ void TestVersionedOpA::writeProperties(mlir::DialectBytecodeWriter &writer) {
writer.writeAttribute(prop.modifier);
}
+//===----------------------------------------------------------------------===//
+// TestVersionedOpD
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+TestVersionedOpD::readProperties(mlir::DialectBytecodeReader &reader,
+ mlir::OperationState &state) {
+ // Always fail so that this uses the fallback path.
+ return failure();
+}
+
+struct FallbackCompliantPropertiesEncoding {
+ int64_t version;
+ SmallVector<Attribute> requiredAttributes;
+ SmallVector<Attribute> optionalAttributes;
+
+ void writeProperties(DialectBytecodeWriter &writer) const {
+ // Write the op version.
+ writer.writeSignedVarInt(version);
+
+ // Write the required attributes.
+ writer.writeList(requiredAttributes,
+ [&](Attribute attr) { writer.writeAttribute(attr); });
+
+ // Write the optional attributes.
+ writer.writeList(optionalAttributes, [&](Attribute attr) {
+ writer.writeOptionalAttribute(attr);
+ });
+ }
+
+ LogicalResult readProperties(DialectBytecodeReader &reader) {
+ // Read the op version.
+ if (failed(reader.readSignedVarInt(version)))
+ return failure();
+
+ // Read the required attributes.
+ if (failed(reader.readList(requiredAttributes, [&](Attribute &attr) {
+ return reader.readAttribute(attr);
+ })))
+ return failure();
+
+ // Read the optional attributes.
+ if (failed(reader.readList(optionalAttributes, [&](Attribute &attr) {
+ return reader.readOptionalAttribute(attr);
+ ...
[truncated]
|
|
@llvm/pr-subscribers-mlir Author: Nikhil Kalra (nikalra) ChangesBytecode dialect versioning today is built around the premise that MLIR bytecode is used as a long-term storage format: specifically, that the producing process has knowledge about the version of dialect to target, and is able to downgrade the dialect prior to serialization. If it cannot, this is an error at serialization time. This poses a problem if bytecode is used as an exchange format between distributed compilers with different versions of the dialect. If a given module references an operation that the receiving process doesn't know about, or utilizes a newer version of the op that the receiving process doesn't have support for, the receiving process will fail to parse the bytecode in its entirety, regardless of whether or not the failing operation is relevant to the receiving process. This proposal adds a fallback mechanism for dialects to construct an operation that maintains the semantics of the unknown operation, while supporting roundtrip bytecode serialization in a bitwise exact manner. Specifically, the flow changes to the following:
Existing options considered:
Patch is 23.99 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/129784.diff 10 Files Affected:
diff --git a/mlir/include/mlir/Bytecode/BytecodeImplementation.h b/mlir/include/mlir/Bytecode/BytecodeImplementation.h
index 0ddc531073e23..36fa010f7e11e 100644
--- a/mlir/include/mlir/Bytecode/BytecodeImplementation.h
+++ b/mlir/include/mlir/Bytecode/BytecodeImplementation.h
@@ -21,6 +21,7 @@
#include "mlir/IR/OpImplementation.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Twine.h"
+#include "llvm/Support/ErrorHandling.h"
namespace mlir {
//===--------------------------------------------------------------------===//
@@ -445,6 +446,14 @@ class BytecodeDialectInterface
return Type();
}
+ /// Fall back to an operation of this type if parsing an op from bytecode
+ /// fails for any reason. This can be used to handle new ops emitted from a
+ /// different version of the dialect, that cannot be read by an older version
+ /// of the dialect.
+ virtual FailureOr<OperationName> getFallbackOperationName() const {
+ return failure();
+ }
+
//===--------------------------------------------------------------------===//
// Writing
//===--------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Bytecode/BytecodeOpInterface.td b/mlir/include/mlir/Bytecode/BytecodeOpInterface.td
index 54fb03e34ec51..87ba27ad6ac27 100644
--- a/mlir/include/mlir/Bytecode/BytecodeOpInterface.td
+++ b/mlir/include/mlir/Bytecode/BytecodeOpInterface.td
@@ -40,4 +40,28 @@ def BytecodeOpInterface : OpInterface<"BytecodeOpInterface"> {
];
}
+// `FallbackBytecodeOpInterface`
+def FallbackBytecodeOpInterface : OpInterface<"FallbackBytecodeOpInterface"> {
+ let description = [{
+ This interface allows fallback operations sideband access to the
+ original operation's intrinsic details.
+ }];
+ let cppNamespace = "::mlir";
+
+ let methods = [
+ StaticInterfaceMethod<[{
+ Set the original name for this operation from the bytecode.
+ }],
+ "void", "setOriginalOperationName", (ins
+ "const ::mlir::Twine&":$opName,
+ "::mlir::OperationState &":$state)
+ >,
+ InterfaceMethod<[{
+ Get the original name for this operation from the bytecode.
+ }],
+ "::mlir::StringRef", "getOriginalOperationName", (ins)
+ >
+ ];
+}
+
#endif // MLIR_BYTECODE_BYTECODEOPINTERFACES
diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
index 1204f1c069b1e..64fcc4ed7c6dc 100644
--- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
+++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
@@ -21,6 +21,7 @@
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
+#include "llvm/ADT/Twine.h"
#include "llvm/Support/Endian.h"
#include "llvm/Support/MemoryBufferRef.h"
#include "llvm/Support/SourceMgr.h"
@@ -292,6 +293,16 @@ class EncodingReader {
Location getLoc() const { return fileLoc; }
+ /// Snapshot the location of the BytecodeReader so that parsing can be rewound
+ /// if needed.
+ struct Snapshot {
+ EncodingReader &reader;
+ const uint8_t *dataIt;
+
+ Snapshot(EncodingReader &reader) : reader(reader), dataIt(reader.dataIt) {}
+ void rewind() { reader.dataIt = dataIt; }
+ };
+
private:
/// Parse a variable length encoded integer from the byte stream. This method
/// is a fallback when the number of bytes used to encode the value is greater
@@ -1410,8 +1421,9 @@ class mlir::BytecodeReader::Impl {
/// Parse an operation name reference using the given reader, and set the
/// `wasRegistered` flag that indicates if the bytecode was produced by a
/// context where opName was registered.
- FailureOr<OperationName> parseOpName(EncodingReader &reader,
- std::optional<bool> &wasRegistered);
+ FailureOr<BytecodeOperationName *>
+ parseOpName(EncodingReader &reader, std::optional<bool> &wasRegistered,
+ bool useDialectFallback);
//===--------------------------------------------------------------------===//
// Attribute/Type Section
@@ -1476,7 +1488,8 @@ class mlir::BytecodeReader::Impl {
RegionReadState &readState);
FailureOr<Operation *> parseOpWithoutRegions(EncodingReader &reader,
RegionReadState &readState,
- bool &isIsolatedFromAbove);
+ bool &isIsolatedFromAbove,
+ bool useDialectFallback);
LogicalResult parseRegion(RegionReadState &readState);
LogicalResult parseBlockHeader(EncodingReader &reader,
@@ -1506,7 +1519,7 @@ class mlir::BytecodeReader::Impl {
UseListOrderStorage(bool isIndexPairEncoding,
SmallVector<unsigned, 4> &&indices)
: indices(std::move(indices)),
- isIndexPairEncoding(isIndexPairEncoding){};
+ isIndexPairEncoding(isIndexPairEncoding) {};
/// The vector containing the information required to reorder the
/// use-list of a value.
SmallVector<unsigned, 4> indices;
@@ -1843,16 +1856,20 @@ BytecodeReader::Impl::parseDialectSection(ArrayRef<uint8_t> sectionData) {
return success();
}
-FailureOr<OperationName>
+FailureOr<BytecodeOperationName *>
BytecodeReader::Impl::parseOpName(EncodingReader &reader,
- std::optional<bool> &wasRegistered) {
+ std::optional<bool> &wasRegistered,
+ bool useDialectFallback) {
BytecodeOperationName *opName = nullptr;
if (failed(parseEntry(reader, opNames, opName, "operation name")))
return failure();
wasRegistered = opName->wasRegistered;
// Check to see if this operation name has already been resolved. If we
// haven't, load the dialect and build the operation name.
- if (!opName->opName) {
+ // If `useDialectFallback`, it's likely that parsing previously failed. We'll
+ // need to reset any previously resolved OperationName with that of the
+ // fallback op.
+ if (!opName->opName || useDialectFallback) {
// If the opName is empty, this is because we use to accept names such as
// `foo` without any `.` separator. We shouldn't tolerate this in textual
// format anymore but for now we'll be backward compatible. This can only
@@ -1865,11 +1882,26 @@ BytecodeReader::Impl::parseOpName(EncodingReader &reader,
dialectsMap, reader, version);
if (failed(opName->dialect->load(dialectReader, getContext())))
return failure();
- opName->opName.emplace((opName->dialect->name + "." + opName->name).str(),
- getContext());
+
+ const BytecodeDialectInterface *dialectIface = opName->dialect->interface;
+ if (useDialectFallback) {
+ FailureOr<OperationName> fallbackOp =
+ dialectIface ? dialectIface->getFallbackOperationName()
+ : FailureOr<OperationName>{};
+
+ // If the dialect doesn't have a fallback operation, we can't parse as
+ // instructed.
+ if (failed(fallbackOp))
+ return failure();
+
+ opName->opName.emplace(*fallbackOp);
+ } else {
+ opName->opName.emplace(
+ (opName->dialect->name + "." + opName->name).str(), getContext());
+ }
}
}
- return *opName->opName;
+ return opName;
}
//===----------------------------------------------------------------------===//
@@ -2143,10 +2175,30 @@ BytecodeReader::Impl::parseRegions(std::vector<RegionReadState> ®ionStack,
// Read in the next operation. We don't read its regions directly, we
// handle those afterwards as necessary.
bool isIsolatedFromAbove = false;
- FailureOr<Operation *> op =
- parseOpWithoutRegions(reader, readState, isIsolatedFromAbove);
- if (failed(op))
- return failure();
+ FailureOr<Operation *> op;
+
+ // Parse the bytecode.
+ {
+ // If the op is registered (and serialized in a compatible manner), or
+ // unregistered but uses standard properties encoding, parsing without
+ // going through the fallback path should work.
+ EncodingReader::Snapshot snapshot(reader);
+ op = parseOpWithoutRegions(reader, readState, isIsolatedFromAbove,
+ /*useDialectFallback=*/false);
+
+ // If reading fails, try parsing the op again as a dialect fallback
+ // op (if supported).
+ if (failed(op)) {
+ snapshot.rewind();
+ op = parseOpWithoutRegions(reader, readState, isIsolatedFromAbove,
+ /*useDialectFallback=*/true);
+ }
+
+ // If the dialect doesn't have a fallback op, or parsing as a fallback
+ // op fails, we can no longer continue.
+ if (failed(op))
+ return failure();
+ }
// If the op has regions, add it to the stack for processing and return:
// we stop the processing of the current region and resume it after the
@@ -2208,14 +2260,17 @@ BytecodeReader::Impl::parseRegions(std::vector<RegionReadState> ®ionStack,
return success();
}
-FailureOr<Operation *>
-BytecodeReader::Impl::parseOpWithoutRegions(EncodingReader &reader,
- RegionReadState &readState,
- bool &isIsolatedFromAbove) {
+FailureOr<Operation *> BytecodeReader::Impl::parseOpWithoutRegions(
+ EncodingReader &reader, RegionReadState &readState,
+ bool &isIsolatedFromAbove, bool useDialectFallback) {
// Parse the name of the operation.
std::optional<bool> wasRegistered;
- FailureOr<OperationName> opName = parseOpName(reader, wasRegistered);
- if (failed(opName))
+ FailureOr<BytecodeOperationName *> bytecodeOp =
+ parseOpName(reader, wasRegistered, useDialectFallback);
+ if (failed(bytecodeOp))
+ return failure();
+ auto opName = (*bytecodeOp)->opName;
+ if (!opName)
return failure();
// Parse the operation mask, which indicates which components of the operation
@@ -2232,6 +2287,12 @@ BytecodeReader::Impl::parseOpWithoutRegions(EncodingReader &reader,
// With the location and name resolved, we can start building the operation
// state.
OperationState opState(opLoc, *opName);
+ // If this is a fallback op, provide the original name of the operation.
+ if (auto *iface = opName->getInterface<FallbackBytecodeOpInterface>()) {
+ const Twine originalName =
+ opName->getDialect()->getNamespace() + "." + (*bytecodeOp)->name;
+ iface->setOriginalOperationName(originalName, opState);
+ }
// Parse the attributes of the operation.
if (opMask & bytecode::OpEncodingMask::kHasAttrs) {
diff --git a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
index cc5aaed416512..526dfb3654492 100644
--- a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
+++ b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
@@ -841,12 +841,12 @@ void BytecodeWriter::writeDialectSection(EncodingEmitter &emitter) {
// Emit the referenced operation names grouped by dialect.
auto emitOpName = [&](OpNameNumbering &name) {
+ const bool isKnownOp = name.isOpaqueEntry || name.name.isRegistered();
size_t stringId = stringSection.insert(name.name.stripDialect());
if (config.bytecodeVersion < bytecode::kNativePropertiesEncoding)
dialectEmitter.emitVarInt(stringId, "dialect op name");
else
- dialectEmitter.emitVarIntWithFlag(stringId, name.name.isRegistered(),
- "dialect op name");
+ dialectEmitter.emitVarIntWithFlag(stringId, isKnownOp, "dialect op name");
};
writeDialectGrouping(dialectEmitter, numberingState.getOpNames(), emitOpName);
@@ -984,7 +984,14 @@ LogicalResult BytecodeWriter::writeBlock(EncodingEmitter &emitter,
}
LogicalResult BytecodeWriter::writeOp(EncodingEmitter &emitter, Operation *op) {
- emitter.emitVarInt(numberingState.getNumber(op->getName()), "op name ID");
+ OperationName opName = op->getName();
+ // For fallback ops, create a new operation name referencing the original op
+ // instead.
+ if (auto fallback = dyn_cast<FallbackBytecodeOpInterface>(op))
+ opName =
+ OperationName(fallback.getOriginalOperationName(), op->getContext());
+
+ emitter.emitVarInt(numberingState.getNumber(opName), "op name ID");
// Emit a mask for the operation components. We need to fill this in later
// (when we actually know what needs to be emitted), so emit a placeholder for
diff --git a/mlir/lib/Bytecode/Writer/IRNumbering.cpp b/mlir/lib/Bytecode/Writer/IRNumbering.cpp
index 1bc02e1721573..60bc6bd5170c5 100644
--- a/mlir/lib/Bytecode/Writer/IRNumbering.cpp
+++ b/mlir/lib/Bytecode/Writer/IRNumbering.cpp
@@ -419,7 +419,16 @@ void IRNumberingState::number(Region ®ion) {
void IRNumberingState::number(Operation &op) {
// Number the components of an operation that won't be numbered elsewhere
// (e.g. we don't number operands, regions, or successors here).
- number(op.getName());
+
+ // For fallback ops, create a new OperationName referencing the original op
+ // instead.
+ if (auto fallback = dyn_cast<FallbackBytecodeOpInterface>(op)) {
+ OperationName opName(fallback.getOriginalOperationName(), op.getContext());
+ number(opName, /*isOpaque=*/true);
+ } else {
+ number(op.getName(), /*isOpaque=*/false);
+ }
+
for (OpResult result : op.getResults()) {
valueIDs.try_emplace(result, nextValueID++);
number(result.getType());
@@ -457,7 +466,7 @@ void IRNumberingState::number(Operation &op) {
number(op.getLoc());
}
-void IRNumberingState::number(OperationName opName) {
+void IRNumberingState::number(OperationName opName, bool isOpaque) {
OpNameNumbering *&numbering = opNames[opName];
if (numbering) {
++numbering->refCount;
@@ -469,8 +478,8 @@ void IRNumberingState::number(OperationName opName) {
else
dialectNumber = &numberDialect(opName.getDialectNamespace());
- numbering =
- new (opNameAllocator.Allocate()) OpNameNumbering(dialectNumber, opName);
+ numbering = new (opNameAllocator.Allocate())
+ OpNameNumbering(dialectNumber, opName, isOpaque);
orderedOpNames.push_back(numbering);
}
diff --git a/mlir/lib/Bytecode/Writer/IRNumbering.h b/mlir/lib/Bytecode/Writer/IRNumbering.h
index 9b7ac0d3688e3..033b3771b46a3 100644
--- a/mlir/lib/Bytecode/Writer/IRNumbering.h
+++ b/mlir/lib/Bytecode/Writer/IRNumbering.h
@@ -63,8 +63,8 @@ struct TypeNumbering : public AttrTypeNumbering {
/// This class represents the numbering entry of an operation name.
struct OpNameNumbering {
- OpNameNumbering(DialectNumbering *dialect, OperationName name)
- : dialect(dialect), name(name) {}
+ OpNameNumbering(DialectNumbering *dialect, OperationName name, bool isOpaque)
+ : dialect(dialect), name(name), isOpaqueEntry(isOpaque) {}
/// The dialect of this value.
DialectNumbering *dialect;
@@ -72,6 +72,9 @@ struct OpNameNumbering {
/// The concrete name.
OperationName name;
+ /// This entry represents an opaque operation entry.
+ bool isOpaqueEntry = false;
+
/// The number assigned to this name.
unsigned number = 0;
@@ -210,7 +213,7 @@ class IRNumberingState {
/// Get the set desired bytecode version to emit.
int64_t getDesiredBytecodeVersion() const;
-
+
private:
/// This class is used to provide a fake dialect writer for numbering nested
/// attributes and types.
@@ -225,7 +228,7 @@ class IRNumberingState {
DialectNumbering &numberDialect(Dialect *dialect);
DialectNumbering &numberDialect(StringRef dialect);
void number(Operation &op);
- void number(OperationName opName);
+ void number(OperationName opName, bool isOpaque);
void number(Region ®ion);
void number(Type type);
diff --git a/mlir/test/Bytecode/versioning/versioning-fallback.mlir b/mlir/test/Bytecode/versioning/versioning-fallback.mlir
new file mode 100644
index 0000000000000..a078613360af6
--- /dev/null
+++ b/mlir/test/Bytecode/versioning/versioning-fallback.mlir
@@ -0,0 +1,13 @@
+// RUN: mlir-opt %s --emit-bytecode > %T/versioning-fallback.mlirbc
+"test.versionedD"() <{attribute = #test.attr_params<42, 24>}> : () -> ()
+
+// COM: check that versionedD was parsed as a fallback op.
+// RUN: mlir-opt %T/versioning-fallback.mlirbc | FileCheck %s --check-prefix=CHECK-PARSE
+// CHECK-PARSE: test.bytecode.fallback
+// CHECK-PARSE-SAME: opname = "test.versionedD"
+
+// COM: check that the bytecode roundtrip was successful
+// RUN: mlir-opt %T/versioning-fallback.mlirbc --verify-roundtrip
+
+// COM: check that the bytecode roundtrip is bitwise exact
+// RUN: mlir-opt %T/versioning-fallback.mlirbc --emit-bytecode | diff %T/versioning-fallback.mlirbc -
diff --git a/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp b/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
index 64add8cef3698..24e31ec44a85b 100644
--- a/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
@@ -8,6 +8,7 @@
#include "TestDialect.h"
#include "TestOps.h"
+#include "mlir/IR/OperationSupport.h"
#include "mlir/Interfaces/FoldInterfaces.h"
#include "mlir/Reducer/ReductionPatternInterface.h"
#include "mlir/Transforms/InliningUtils.h"
@@ -92,6 +93,11 @@ struct TestBytecodeDialectInterface : public BytecodeDialectInterface {
return Attribute();
}
+ FailureOr<OperationName> getFallbackOperationName() const final {
+ return OperationName(TestBytecodeFallbackOp::getOperationName(),
+ getContext());
+ }
+
// Emit a specific version of the dialect.
void writeVersion(DialectBytecodeWriter &writer) const final {
// Construct the current dialect version.
diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
index f6b8a0005f285..77428517f2b12 100644
--- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
+++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
@@ -8,10 +8,17 @@
#include "TestDialect.h"
#include "TestOps.h"
+#include "mlir/Bytecode/BytecodeImplementation.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Interfaces/FunctionImplementation.h"
#include "mlir/Interfaces/MemorySlotInterfaces.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/LogicalResult.h"
+#include <cstdint>
using namespace mlir;
using namespace test;
@@ -1230,6 +1237,106 @@ void TestVersionedOpA::writeProperties(mlir::DialectBytecodeWriter &writer) {
writer.writeAttribute(prop.modifier);
}
+//===----------------------------------------------------------------------===//
+// TestVersionedOpD
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+TestVersionedOpD::readProperties(mlir::DialectBytecodeReader &reader,
+ mlir::OperationState &state) {
+ // Always fail so that this uses the fallback path.
+ return failure();
+}
+
+struct FallbackCompliantPropertiesEncoding {
+ int64_t version;
+ SmallVector<Attribute> requiredAttributes;
+ SmallVector<Attribute> optionalAttributes;
+
+ void writeProperties(DialectBytecodeWriter &writer) const {
+ // Write the op version.
+ writer.writeSignedVarInt(version);
+
+ // Write the required attributes.
+ writer.writeList(requiredAttributes,
+ [&](Attribute attr) { writer.writeAttribute(attr); });
+
+ // Write the optional attributes.
+ writer.writeList(optionalAttributes, [&](Attribute attr) {
+ writer.writeOptionalAttribute(attr);
+ });
+ }
+
+ LogicalResult readProperties(DialectBytecodeReader &reader) {
+ // Read the op version.
+ if (failed(reader.readSignedVarInt(version)))
+ return failure();
+
+ // Read the required attributes.
+ if (failed(reader.readList(requiredAttributes, [&](Attribute &attr) {
+ return reader.readAttribute(attr);
+ })))
+ return failure();
+
+ // Read the optional attributes.
+ if (failed(reader.readList(optionalAttributes, [&](Attribute &attr) {
+ return reader.readOptionalAttribute(attr);
+ ...
[truncated]
|
|
Maybe i am missing something, but if an op is not supported on an "older" version of dialect, how the fallback can guarantee that the semantic is preserved? It seems to me that the only point where this can be guaranteed is at serialization. Downgrade before serialization is already available and it can be implemented by downstream dialects as a legalization pass before producing bytecode. If the "unsupported" op is not semantically relevant on a given target, why this cannot be handled at serialization? |
The fallback can't guarantee that operation semantics are preserved, but it can guarantee that program semantics are preserved. Consider a module where 99% of the operations are recognized by the receiving process, but 1 operation is not. Using a strict dialect versioning scheme, the receiving process would fail to parse the module. Instead, the arbitrary receiving process should be able to parse the module in a way that it can recognize whichever ops it does support, perform downstream work on those ops, and faithfully encode the ops it did not recognize back into bytecode while still maintaining the original structure of the module.
We considered this approach, but it ends up getting pretty expensive. A full fledged downgrade (i.e. representing the unsupported op in a supported manner) may change the semantics of the program in a way that is detrimental to the performance of the system, by forcing the different compilers to operate on the oldest supported program semantics. This could have further downstream effects on partitioning, etc. Consider STFT: STFT can be decomposed into matmul/conv, but that has worse complexity than native STFT. If a partition is made in the middle of the decomposition, it's impossible to recreate the original semantics even on a version of the dialect that does provide STFT. If we apply the fallback approach at serialization time, we end up with effectively the same resultant module representation as proposed by this RFC; nevertheless, it requires the host compiler to query for the supported dialect version from each of the receiving processes, rewrite unsupported ops as "fallback" ops, serialize to bytecode and do the same in reverse when receiving bytecode back from those processes. This RFC enables processes in this setup to avoid that overhead since everything is faithfully represented the same way in bytecode. |
|
Overall the idea is interesting. Is there an RFC on discourse? Some more questions/comments below.
Isn't this somewhat equivalent to parsing unregistered/unverifiable ops? It seems you are basically trying to read and partition the ops that are registered/verified by the dialect. How the fallback would extend to types/attributes for the same versioned dialect?
Wouldn't this indicate a pass ordering problem, rather than a serialization problem?
Have you considered using lazy loading + multiple serializations of the same top level func for different targets? |
I'll create one: https://discourse.llvm.org/t/rfc-bytecode-op-fallback-path/84993
Yes, but that the ops are registered at serialization time. The problem with using unregistered ops directly is that they encode properties as attributes, which breaks down for versioned ops that utilize custom properties encoding to support reading/writing older versions. There's another option here which involves converting newer ops to the fallback op pre-serialization. It would still require determining the dialect version used by the receiving process, but would turn this into a serialization problem vs a deserialization problem.
Types/attributes are a little easier because they can be treated as OpaqueType/OpaqueAttr as long as the dialect has an encoding scheme that supports round-tripping between them.
True, unless bytecode is used as a transfer medium inside of a pass. In that case, changing the semantics of the module would end up changing the results of the pass. The alternative to using bytecode would be to create a sideband encoding scheme for operations and constants to be communicated to an external system, but that would effectively duplicate the capabilities provided by bytecode today.
Yeah, but it ends up getting pretty complicated: the module may have many top-level funcs, and reconciling the different versions together at the end is non-trivial due to the number of permutations that exist. |
|
I took a look over and I have a few concerns about the approach here. This seems to be trying to handle a very specific situation, but I don't think it's fully correct. The op not being understood is one aspect, but you can't directly grab data from the bytecode and rely on that being able to roundtrip (the bytecode has references to different parts of itself that you won't be able to replicate and understand). The easiest example here is what to do about new attributes/types that weren't present in an older dialect, if that attribute/type uses a custom encoding you can't take the bytecode for it directly and roundtrip it. I also don't want to encorage trying to make expectations on the structure of bytecode or how things are referenced. Consider a new container attribute |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Mark
Agreed, but unless I'm misunderstanding what you're saying, I don't think this proposal does that. Specifically, the type and attribute sections (which I believe is what you're referring to) still need to get created at time of writing. Rather, this proposal just creates an additional path during parsing where the dialect has an opportunity to deserialize an op that cannot be deserialized using the default process.
Agreed, but the BytcodeDialectInterface leaves attribute and type parsing up to the dialect to do in an implementation defined manner. That's unchanged with this proposal, and would be up to the dialect to handle that complexity.
I don't think I'm following this -- In the same way that the current bytecode interfaces allow a dialect to define ser/de for an Attribute and Type, this proposal adds a path for a dialect to define custom deserialization for an Op that cannot be parsed using the default mechanisms. While the proposed use case is for passing through new ops, the infrastructure still provides a fallback path for parsing existing registered ops that could not be deserialized. The dialect can opt-into this behavior but is not required to. The fallback path is also not required to succeed if the dialect cannot interpret the encoded contents. Considering the new container attribute, it would be on the dialect to encode the container in a way that the nested attributes can be serialized/deserialized using a dialect-defined encoding scheme—in the same way as would be required for the fallback op to successfully deserialize an unknown op. 8d8ac5f adds an example of what this could look like, but isn't meant to be prescriptive on what that encoding should look like (nor does it support all cases). It might be worth mentioning that this proposal doesn't create a passthrough mechanism for bytecode to be round-tripped; it just creates a mechanism for the dialect to round-trip unknown ops if it needs to. The onus is still on the dialect to get that right, as it is for any of the custom dialect representations that are currently enabled for types/attributes/properties. In other words, the dialect bytecode interface already gives dialects the ability to replace one type or attribute with another during deserialization; this does the same for ops. |
|
Can you describe how does the round-trip work when the properties contains types/attributes?
But I'm not sure to follow... |
This isn't a fully complete example, but shows what a dialect-wide properties encoding scheme could look like:
Then, compliant attributes/types would also have to use a similar encoding scheme:
|
|
Seems like it implies serializing the properties as attributes?
You could have a flag on the serialization that always convert properties to an attribute. |
Serializing as attributes with a flag would work, but that'd still require implementing a hook for an operation to read encoded attributes from the serialized file that may not exist on the newest version of the op, or any of the other flexibilities that using the custom encoding hook provides. But the encoding scheme also doesn't need to imply that properties may only be serialized as attributes; ops can still use a custom scheme for writing out properties as long as this scheme is not unique to a single op. For example, it'd still work if a dialect chose to serialize properties as JSON, Flatbuffer, or some other custom format that can be roundtripped. |
|
@joker-eph @River707 pinging for feedback :) |
Bytecode dialect versioning today is built around the premise that MLIR bytecode is used as a long-term storage format: specifically, that the producing process has knowledge about the version of dialect to target, and is able to downgrade the dialect prior to serialization. If it cannot, this is an error at serialization time.
This poses a problem if bytecode is used as an exchange format between distributed compilers with different versions of the dialect. If a given module references an operation that the receiving process doesn't know about, or utilizes a newer version of the op that the receiving process doesn't have support for, the receiving process will fail to parse the bytecode in its entirety, regardless of whether or not the failing operation is relevant to the receiving process.
This proposal adds a fallback mechanism for dialects to construct an operation that maintains the semantics of the unknown operation, while supporting roundtrip bytecode serialization in a bitwise exact manner. Specifically, the flow changes to the following:
Existing options considered: