Skip to content
Closed
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions mlir/include/mlir/Bytecode/BytecodeImplementation.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "mlir/IR/OpImplementation.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/ErrorHandling.h"

namespace mlir {
//===--------------------------------------------------------------------===//
Expand Down Expand Up @@ -445,6 +446,14 @@ class BytecodeDialectInterface
return Type();
}

/// Fall back to an operation of this type if parsing an op from bytecode
/// fails for any reason. This can be used to handle new ops emitted from a
/// different version of the dialect, that cannot be read by an older version
/// of the dialect.
virtual FailureOr<OperationName> getFallbackOperationName() const {
return failure();
}

//===--------------------------------------------------------------------===//
// Writing
//===--------------------------------------------------------------------===//
Expand Down
24 changes: 24 additions & 0 deletions mlir/include/mlir/Bytecode/BytecodeOpInterface.td
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,28 @@ def BytecodeOpInterface : OpInterface<"BytecodeOpInterface"> {
];
}

// `FallbackBytecodeOpInterface`
def FallbackBytecodeOpInterface : OpInterface<"FallbackBytecodeOpInterface"> {
let description = [{
This interface allows fallback operations sideband access to the
original operation's intrinsic details.
}];
let cppNamespace = "::mlir";

let methods = [
StaticInterfaceMethod<[{
Set the original name for this operation from the bytecode.
}],
"void", "setOriginalOperationName", (ins
"const ::mlir::Twine&":$opName,
"::mlir::OperationState &":$state)
>,
InterfaceMethod<[{
Get the original name for this operation from the bytecode.
}],
"::mlir::StringRef", "getOriginalOperationName", (ins)
>
];
}

#endif // MLIR_BYTECODE_BYTECODEOPINTERFACES
101 changes: 81 additions & 20 deletions mlir/lib/Bytecode/Reader/BytecodeReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/Endian.h"
#include "llvm/Support/MemoryBufferRef.h"
#include "llvm/Support/SourceMgr.h"
Expand Down Expand Up @@ -292,6 +293,16 @@ class EncodingReader {

Location getLoc() const { return fileLoc; }

/// Snapshot the location of the BytecodeReader so that parsing can be rewound
/// if needed.
struct Snapshot {
EncodingReader &reader;
const uint8_t *dataIt;

Snapshot(EncodingReader &reader) : reader(reader), dataIt(reader.dataIt) {}
void rewind() { reader.dataIt = dataIt; }
};

private:
/// Parse a variable length encoded integer from the byte stream. This method
/// is a fallback when the number of bytes used to encode the value is greater
Expand Down Expand Up @@ -1410,8 +1421,9 @@ class mlir::BytecodeReader::Impl {
/// Parse an operation name reference using the given reader, and set the
/// `wasRegistered` flag that indicates if the bytecode was produced by a
/// context where opName was registered.
FailureOr<OperationName> parseOpName(EncodingReader &reader,
std::optional<bool> &wasRegistered);
FailureOr<BytecodeOperationName *>
parseOpName(EncodingReader &reader, std::optional<bool> &wasRegistered,
bool useDialectFallback);

//===--------------------------------------------------------------------===//
// Attribute/Type Section
Expand Down Expand Up @@ -1476,7 +1488,8 @@ class mlir::BytecodeReader::Impl {
RegionReadState &readState);
FailureOr<Operation *> parseOpWithoutRegions(EncodingReader &reader,
RegionReadState &readState,
bool &isIsolatedFromAbove);
bool &isIsolatedFromAbove,
bool useDialectFallback);

LogicalResult parseRegion(RegionReadState &readState);
LogicalResult parseBlockHeader(EncodingReader &reader,
Expand Down Expand Up @@ -1506,7 +1519,7 @@ class mlir::BytecodeReader::Impl {
UseListOrderStorage(bool isIndexPairEncoding,
SmallVector<unsigned, 4> &&indices)
: indices(std::move(indices)),
isIndexPairEncoding(isIndexPairEncoding){};
isIndexPairEncoding(isIndexPairEncoding) {};
/// The vector containing the information required to reorder the
/// use-list of a value.
SmallVector<unsigned, 4> indices;
Expand Down Expand Up @@ -1843,16 +1856,20 @@ BytecodeReader::Impl::parseDialectSection(ArrayRef<uint8_t> sectionData) {
return success();
}

FailureOr<OperationName>
FailureOr<BytecodeOperationName *>
BytecodeReader::Impl::parseOpName(EncodingReader &reader,
std::optional<bool> &wasRegistered) {
std::optional<bool> &wasRegistered,
bool useDialectFallback) {
BytecodeOperationName *opName = nullptr;
if (failed(parseEntry(reader, opNames, opName, "operation name")))
return failure();
wasRegistered = opName->wasRegistered;
// Check to see if this operation name has already been resolved. If we
// haven't, load the dialect and build the operation name.
if (!opName->opName) {
// If `useDialectFallback`, it's likely that parsing previously failed. We'll
// need to reset any previously resolved OperationName with that of the
// fallback op.
if (!opName->opName || useDialectFallback) {
// If the opName is empty, this is because we use to accept names such as
// `foo` without any `.` separator. We shouldn't tolerate this in textual
// format anymore but for now we'll be backward compatible. This can only
Expand All @@ -1865,11 +1882,26 @@ BytecodeReader::Impl::parseOpName(EncodingReader &reader,
dialectsMap, reader, version);
if (failed(opName->dialect->load(dialectReader, getContext())))
return failure();
opName->opName.emplace((opName->dialect->name + "." + opName->name).str(),
getContext());

const BytecodeDialectInterface *dialectIface = opName->dialect->interface;
if (useDialectFallback) {
FailureOr<OperationName> fallbackOp =
dialectIface ? dialectIface->getFallbackOperationName()
: FailureOr<OperationName>{};

// If the dialect doesn't have a fallback operation, we can't parse as
// instructed.
if (failed(fallbackOp))
return failure();

opName->opName.emplace(*fallbackOp);
} else {
opName->opName.emplace(
(opName->dialect->name + "." + opName->name).str(), getContext());
}
}
}
return *opName->opName;
return opName;
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -2143,10 +2175,30 @@ BytecodeReader::Impl::parseRegions(std::vector<RegionReadState> &regionStack,
// Read in the next operation. We don't read its regions directly, we
// handle those afterwards as necessary.
bool isIsolatedFromAbove = false;
FailureOr<Operation *> op =
parseOpWithoutRegions(reader, readState, isIsolatedFromAbove);
if (failed(op))
return failure();
FailureOr<Operation *> op;

// Parse the bytecode.
{
// If the op is registered (and serialized in a compatible manner), or
// unregistered but uses standard properties encoding, parsing without
// going through the fallback path should work.
EncodingReader::Snapshot snapshot(reader);
op = parseOpWithoutRegions(reader, readState, isIsolatedFromAbove,
/*useDialectFallback=*/false);

// If reading fails, try parsing the op again as a dialect fallback
// op (if supported).
if (failed(op)) {
snapshot.rewind();
op = parseOpWithoutRegions(reader, readState, isIsolatedFromAbove,
/*useDialectFallback=*/true);
}

// If the dialect doesn't have a fallback op, or parsing as a fallback
// op fails, we can no longer continue.
if (failed(op))
return failure();
}

// If the op has regions, add it to the stack for processing and return:
// we stop the processing of the current region and resume it after the
Expand Down Expand Up @@ -2208,14 +2260,17 @@ BytecodeReader::Impl::parseRegions(std::vector<RegionReadState> &regionStack,
return success();
}

FailureOr<Operation *>
BytecodeReader::Impl::parseOpWithoutRegions(EncodingReader &reader,
RegionReadState &readState,
bool &isIsolatedFromAbove) {
FailureOr<Operation *> BytecodeReader::Impl::parseOpWithoutRegions(
EncodingReader &reader, RegionReadState &readState,
bool &isIsolatedFromAbove, bool useDialectFallback) {
// Parse the name of the operation.
std::optional<bool> wasRegistered;
FailureOr<OperationName> opName = parseOpName(reader, wasRegistered);
if (failed(opName))
FailureOr<BytecodeOperationName *> bytecodeOp =
parseOpName(reader, wasRegistered, useDialectFallback);
if (failed(bytecodeOp))
return failure();
auto opName = (*bytecodeOp)->opName;
if (!opName)
return failure();

// Parse the operation mask, which indicates which components of the operation
Expand All @@ -2232,6 +2287,12 @@ BytecodeReader::Impl::parseOpWithoutRegions(EncodingReader &reader,
// With the location and name resolved, we can start building the operation
// state.
OperationState opState(opLoc, *opName);
// If this is a fallback op, provide the original name of the operation.
if (auto *iface = opName->getInterface<FallbackBytecodeOpInterface>()) {
const Twine originalName =
opName->getDialect()->getNamespace() + "." + (*bytecodeOp)->name;
iface->setOriginalOperationName(originalName, opState);
}

// Parse the attributes of the operation.
if (opMask & bytecode::OpEncodingMask::kHasAttrs) {
Expand Down
13 changes: 10 additions & 3 deletions mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -841,12 +841,12 @@ void BytecodeWriter::writeDialectSection(EncodingEmitter &emitter) {

// Emit the referenced operation names grouped by dialect.
auto emitOpName = [&](OpNameNumbering &name) {
const bool isKnownOp = name.isOpaqueEntry || name.name.isRegistered();
size_t stringId = stringSection.insert(name.name.stripDialect());
if (config.bytecodeVersion < bytecode::kNativePropertiesEncoding)
dialectEmitter.emitVarInt(stringId, "dialect op name");
else
dialectEmitter.emitVarIntWithFlag(stringId, name.name.isRegistered(),
"dialect op name");
dialectEmitter.emitVarIntWithFlag(stringId, isKnownOp, "dialect op name");
};
writeDialectGrouping(dialectEmitter, numberingState.getOpNames(), emitOpName);

Expand Down Expand Up @@ -984,7 +984,14 @@ LogicalResult BytecodeWriter::writeBlock(EncodingEmitter &emitter,
}

LogicalResult BytecodeWriter::writeOp(EncodingEmitter &emitter, Operation *op) {
emitter.emitVarInt(numberingState.getNumber(op->getName()), "op name ID");
OperationName opName = op->getName();
// For fallback ops, create a new operation name referencing the original op
// instead.
if (auto fallback = dyn_cast<FallbackBytecodeOpInterface>(op))
opName =
OperationName(fallback.getOriginalOperationName(), op->getContext());

emitter.emitVarInt(numberingState.getNumber(opName), "op name ID");

// Emit a mask for the operation components. We need to fill this in later
// (when we actually know what needs to be emitted), so emit a placeholder for
Expand Down
17 changes: 13 additions & 4 deletions mlir/lib/Bytecode/Writer/IRNumbering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,16 @@ void IRNumberingState::number(Region &region) {
void IRNumberingState::number(Operation &op) {
// Number the components of an operation that won't be numbered elsewhere
// (e.g. we don't number operands, regions, or successors here).
number(op.getName());

// For fallback ops, create a new OperationName referencing the original op
// instead.
if (auto fallback = dyn_cast<FallbackBytecodeOpInterface>(op)) {
OperationName opName(fallback.getOriginalOperationName(), op.getContext());
number(opName, /*isOpaque=*/true);
} else {
number(op.getName(), /*isOpaque=*/false);
}

for (OpResult result : op.getResults()) {
valueIDs.try_emplace(result, nextValueID++);
number(result.getType());
Expand Down Expand Up @@ -457,7 +466,7 @@ void IRNumberingState::number(Operation &op) {
number(op.getLoc());
}

void IRNumberingState::number(OperationName opName) {
void IRNumberingState::number(OperationName opName, bool isOpaque) {
OpNameNumbering *&numbering = opNames[opName];
if (numbering) {
++numbering->refCount;
Expand All @@ -469,8 +478,8 @@ void IRNumberingState::number(OperationName opName) {
else
dialectNumber = &numberDialect(opName.getDialectNamespace());

numbering =
new (opNameAllocator.Allocate()) OpNameNumbering(dialectNumber, opName);
numbering = new (opNameAllocator.Allocate())
OpNameNumbering(dialectNumber, opName, isOpaque);
orderedOpNames.push_back(numbering);
}

Expand Down
11 changes: 7 additions & 4 deletions mlir/lib/Bytecode/Writer/IRNumbering.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,15 +63,18 @@ struct TypeNumbering : public AttrTypeNumbering {

/// This class represents the numbering entry of an operation name.
struct OpNameNumbering {
OpNameNumbering(DialectNumbering *dialect, OperationName name)
: dialect(dialect), name(name) {}
OpNameNumbering(DialectNumbering *dialect, OperationName name, bool isOpaque)
: dialect(dialect), name(name), isOpaqueEntry(isOpaque) {}

/// The dialect of this value.
DialectNumbering *dialect;

/// The concrete name.
OperationName name;

/// This entry represents an opaque operation entry.
bool isOpaqueEntry = false;

/// The number assigned to this name.
unsigned number = 0;

Expand Down Expand Up @@ -210,7 +213,7 @@ class IRNumberingState {

/// Get the set desired bytecode version to emit.
int64_t getDesiredBytecodeVersion() const;

private:
/// This class is used to provide a fake dialect writer for numbering nested
/// attributes and types.
Expand All @@ -225,7 +228,7 @@ class IRNumberingState {
DialectNumbering &numberDialect(Dialect *dialect);
DialectNumbering &numberDialect(StringRef dialect);
void number(Operation &op);
void number(OperationName opName);
void number(OperationName opName, bool isOpaque);
void number(Region &region);
void number(Type type);

Expand Down
13 changes: 13 additions & 0 deletions mlir/test/Bytecode/versioning/versioning-fallback.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
// RUN: mlir-opt %s --emit-bytecode > %T/versioning-fallback.mlirbc
"test.versionedD"() <{attribute = #test.attr_params<42, 24>}> : () -> ()

// COM: check that versionedD was parsed as a fallback op.
// RUN: mlir-opt %T/versioning-fallback.mlirbc | FileCheck %s --check-prefix=CHECK-PARSE
// CHECK-PARSE: test.bytecode.fallback
// CHECK-PARSE-SAME: opname = "test.versionedD"

// COM: check that the bytecode roundtrip was successful
// RUN: mlir-opt %T/versioning-fallback.mlirbc --verify-roundtrip

// COM: check that the bytecode roundtrip is bitwise exact
// RUN: mlir-opt %T/versioning-fallback.mlirbc --emit-bytecode | diff %T/versioning-fallback.mlirbc -
6 changes: 6 additions & 0 deletions mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include "TestDialect.h"
#include "TestOps.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/Interfaces/FoldInterfaces.h"
#include "mlir/Reducer/ReductionPatternInterface.h"
#include "mlir/Transforms/InliningUtils.h"
Expand Down Expand Up @@ -92,6 +93,11 @@ struct TestBytecodeDialectInterface : public BytecodeDialectInterface {
return Attribute();
}

FailureOr<OperationName> getFallbackOperationName() const final {
return OperationName(TestBytecodeFallbackOp::getOperationName(),
getContext());
}

// Emit a specific version of the dialect.
void writeVersion(DialectBytecodeWriter &writer) const final {
// Construct the current dialect version.
Expand Down
Loading
Loading