Skip to content

Commit d76f47b

Browse files
committed
roundtrip working
1 parent 34587b8 commit d76f47b

File tree

4 files changed

+37
-3
lines changed

4 files changed

+37
-3
lines changed

mlir/include/mlir/Bytecode/BytecodeOpInterface.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,11 @@ def FallbackBytecodeOpInterface : OpInterface<"FallbackBytecodeOpInterface"> {
5656
"::mlir::StringRef":$opName,
5757
"::mlir::OperationState &":$state)
5858
>,
59+
InterfaceMethod<[{
60+
Get the original name for this operation from the bytecode.
61+
}],
62+
"::mlir::StringRef", "getOriginalOperationName", (ins)
63+
>,
5964
StaticInterfaceMethod<[{
6065
Read the properties blob for this operation from the bytecode and populate the state.
6166
}],

mlir/lib/Bytecode/Writer/BytecodeWriter.cpp

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -848,11 +848,16 @@ void BytecodeWriter::writeDialectSection(EncodingEmitter &emitter) {
848848

849849
// Emit the referenced operation names grouped by dialect.
850850
auto emitOpName = [&](OpNameNumbering &name) {
851+
bool isRegistered = name.name.isRegistered();
852+
// If we're writing a fallback op, write it as if it were a registered op.
853+
if (name.name.hasInterface<FallbackBytecodeOpInterface>())
854+
isRegistered = true;
855+
851856
size_t stringId = stringSection.insert(name.name.stripDialect());
852857
if (config.bytecodeVersion < bytecode::kNativePropertiesEncoding)
853858
dialectEmitter.emitVarInt(stringId, "dialect op name");
854859
else
855-
dialectEmitter.emitVarIntWithFlag(stringId, name.name.isRegistered(),
860+
dialectEmitter.emitVarIntWithFlag(stringId, isRegistered,
856861
"dialect op name");
857862
};
858863
writeDialectGrouping(dialectEmitter, numberingState.getOpNames(), emitOpName);
@@ -991,7 +996,16 @@ LogicalResult BytecodeWriter::writeBlock(EncodingEmitter &emitter,
991996
}
992997

993998
LogicalResult BytecodeWriter::writeOp(EncodingEmitter &emitter, Operation *op) {
994-
emitter.emitVarInt(numberingState.getNumber(op->getName()), "op name ID");
999+
OperationName opName = op->getName();
1000+
// For fallback ops, create a new operation name referencing the original op
1001+
// instead.
1002+
if (auto fallback = dyn_cast<FallbackBytecodeOpInterface>(op))
1003+
opName = OperationName((fallback->getDialect()->getNamespace() + "." +
1004+
fallback.getOriginalOperationName())
1005+
.str(),
1006+
op->getContext());
1007+
1008+
emitter.emitVarInt(numberingState.getNumber(opName), "op name ID");
9951009

9961010
// Emit a mask for the operation components. We need to fill this in later
9971011
// (when we actually know what needs to be emitted), so emit a placeholder for

mlir/lib/Bytecode/Writer/IRNumbering.cpp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -419,7 +419,17 @@ void IRNumberingState::number(Region &region) {
419419
void IRNumberingState::number(Operation &op) {
420420
// Number the components of an operation that won't be numbered elsewhere
421421
// (e.g. we don't number operands, regions, or successors here).
422-
number(op.getName());
422+
423+
OperationName opName = op.getName();
424+
// For fallback ops, create a new operation name referencing the original op
425+
// instead.
426+
if (auto fallback = dyn_cast<FallbackBytecodeOpInterface>(op))
427+
opName = OperationName((fallback->getDialect()->getNamespace() + "." +
428+
fallback.getOriginalOperationName())
429+
.str(),
430+
op.getContext());
431+
number(opName);
432+
423433
for (OpResult result : op.getResults()) {
424434
valueIDs.try_emplace(result, nextValueID++);
425435
number(result.getType());

mlir/test/lib/Dialect/Test/TestOpDefs.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include "TestDialect.h"
1010
#include "TestOps.h"
11+
#include "mlir/Bytecode/BytecodeImplementation.h"
1112
#include "mlir/Dialect/Tensor/IR/Tensor.h"
1213
#include "mlir/IR/BuiltinAttributes.h"
1314
#include "mlir/IR/Verifier.h"
@@ -1266,6 +1267,10 @@ void TestBytecodeFallbackOp::setOriginalOperationName(StringRef name,
12661267
StringAttr::get(state.getContext(), name));
12671268
}
12681269

1270+
StringRef TestBytecodeFallbackOp::getOriginalOperationName() {
1271+
return getProperties().getOpname().getValue();
1272+
}
1273+
12691274
LogicalResult
12701275
TestBytecodeFallbackOp::readPropertiesBlob(ArrayRef<char> blob,
12711276
OperationState &state) {

0 commit comments

Comments
 (0)