Skip to content

Commit 34587b8

Browse files
committed
reading working
1 parent fab9e00 commit 34587b8

File tree

5 files changed

+116
-20
lines changed

5 files changed

+116
-20
lines changed

mlir/include/mlir/Bytecode/BytecodeOpInterface.td

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,17 +49,24 @@ def FallbackBytecodeOpInterface : OpInterface<"FallbackBytecodeOpInterface"> {
4949
let cppNamespace = "::mlir";
5050

5151
let methods = [
52+
StaticInterfaceMethod<[{
53+
Set the original name for this operation from the bytecode.
54+
}],
55+
"void", "setOriginalOperationName", (ins
56+
"::mlir::StringRef":$opName,
57+
"::mlir::OperationState &":$state)
58+
>,
5259
StaticInterfaceMethod<[{
5360
Read the properties blob for this operation from the bytecode and populate the state.
5461
}],
55-
"LogicalResult", "readPropertiesBlob", (ins
56-
"ArrayRef<char>":$blob,
62+
"::mlir::LogicalResult", "readPropertiesBlob", (ins
63+
"::mlir::ArrayRef<char>":$blob,
5764
"::mlir::OperationState &":$state)
5865
>,
5966
InterfaceMethod<[{
6067
Get the properties blob for this operation to be emitted into the bytecode.
6168
}],
62-
"ArrayRef<char>", "getPropertiesBlob", (ins)
69+
"::mlir::ArrayRef<char>", "getPropertiesBlob", (ins)
6370
>,
6471
];
6572
}

mlir/lib/Bytecode/Reader/BytecodeReader.cpp

Lines changed: 30 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1415,8 +1415,8 @@ class mlir::BytecodeReader::Impl {
14151415
/// Parse an operation name reference using the given reader, and set the
14161416
/// `wasRegistered` flag that indicates if the bytecode was produced by a
14171417
/// context where opName was registered.
1418-
FailureOr<OperationName> parseOpName(EncodingReader &reader,
1419-
std::optional<bool> &wasRegistered);
1418+
FailureOr<BytecodeOperationName *>
1419+
parseOpName(EncodingReader &reader, std::optional<bool> &wasRegistered);
14201420

14211421
//===--------------------------------------------------------------------===//
14221422
// Attribute/Type Section
@@ -1848,7 +1848,7 @@ BytecodeReader::Impl::parseDialectSection(ArrayRef<uint8_t> sectionData) {
18481848
return success();
18491849
}
18501850

1851-
FailureOr<OperationName>
1851+
FailureOr<BytecodeOperationName *>
18521852
BytecodeReader::Impl::parseOpName(EncodingReader &reader,
18531853
std::optional<bool> &wasRegistered) {
18541854
BytecodeOperationName *opName = nullptr;
@@ -1868,21 +1868,28 @@ BytecodeReader::Impl::parseOpName(EncodingReader &reader,
18681868
// Load the dialect and its version.
18691869
DialectReader dialectReader(attrTypeReader, stringReader, resourceReader,
18701870
dialectsMap, reader, version);
1871-
if (succeeded(opName->dialect->load(dialectReader, getContext()))) {
1872-
opName->opName.emplace(
1873-
(opName->dialect->name + "." + opName->name).str(), getContext());
1874-
} else if (auto fallbackOp =
1875-
opName->dialect->interface->getFallbackOperationName();
1876-
succeeded(fallbackOp)) {
1877-
// If the dialect's bytecode interface specifies a fallback op, we want
1878-
// to use that instead of an unregistered op.
1879-
opName->opName.emplace(*fallbackOp);
1880-
} else {
1871+
if (failed(opName->dialect->load(dialectReader, getContext())))
18811872
return failure();
1873+
1874+
opName->opName.emplace((opName->dialect->name + "." + opName->name).str(),
1875+
getContext());
1876+
1877+
// If the op is unregistered now, but was not marked as unregistered, try
1878+
// to parse it as a fallback op if the dialect's bytecode interface
1879+
// specifies one.
1880+
// We don't treat this condition as an error because we may still be able
1881+
// to parse the op as an unregistered op if it doesn't use custom
1882+
// properties encoding.
1883+
if (wasRegistered && !opName->opName->isRegistered()) {
1884+
if (auto fallbackOp =
1885+
opName->dialect->interface->getFallbackOperationName();
1886+
succeeded(fallbackOp)) {
1887+
opName->opName.emplace(*fallbackOp);
1888+
}
18821889
}
18831890
}
18841891
}
1885-
return *opName->opName;
1892+
return opName;
18861893
}
18871894

18881895
//===----------------------------------------------------------------------===//
@@ -2227,8 +2234,12 @@ BytecodeReader::Impl::parseOpWithoutRegions(EncodingReader &reader,
22272234
bool &isIsolatedFromAbove) {
22282235
// Parse the name of the operation.
22292236
std::optional<bool> wasRegistered;
2230-
FailureOr<OperationName> opName = parseOpName(reader, wasRegistered);
2231-
if (failed(opName))
2237+
FailureOr<BytecodeOperationName *> bytecodeOp =
2238+
parseOpName(reader, wasRegistered);
2239+
if (failed(bytecodeOp))
2240+
return failure();
2241+
auto opName = (*bytecodeOp)->opName;
2242+
if (!opName)
22322243
return failure();
22332244

22342245
// Parse the operation mask, which indicates which components of the operation
@@ -2245,6 +2256,9 @@ BytecodeReader::Impl::parseOpWithoutRegions(EncodingReader &reader,
22452256
// With the location and name resolved, we can start building the operation
22462257
// state.
22472258
OperationState opState(opLoc, *opName);
2259+
// If this is a fallback op, provide the original name of the operation.
2260+
if (auto *iface = opName->getInterface<FallbackBytecodeOpInterface>())
2261+
iface->setOriginalOperationName((*bytecodeOp)->name, opState);
22482262

22492263
// Parse the attributes of the operation.
22502264
if (opMask & bytecode::OpEncodingMask::kHasAttrs) {

mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include "TestDialect.h"
1010
#include "TestOps.h"
11+
#include "mlir/IR/OperationSupport.h"
1112
#include "mlir/Interfaces/FoldInterfaces.h"
1213
#include "mlir/Reducer/ReductionPatternInterface.h"
1314
#include "mlir/Transforms/InliningUtils.h"
@@ -92,6 +93,11 @@ struct TestBytecodeDialectInterface : public BytecodeDialectInterface {
9293
return Attribute();
9394
}
9495

96+
FailureOr<OperationName> getFallbackOperationName() const final {
97+
return OperationName(TestBytecodeFallbackOp::getOperationName(),
98+
getContext());
99+
}
100+
95101
// Emit a specific version of the dialect.
96102
void writeVersion(DialectBytecodeWriter &writer) const final {
97103
// Construct the current dialect version.

mlir/test/lib/Dialect/Test/TestOpDefs.cpp

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,12 @@
99
#include "TestDialect.h"
1010
#include "TestOps.h"
1111
#include "mlir/Dialect/Tensor/IR/Tensor.h"
12+
#include "mlir/IR/BuiltinAttributes.h"
1213
#include "mlir/IR/Verifier.h"
1314
#include "mlir/Interfaces/FunctionImplementation.h"
1415
#include "mlir/Interfaces/MemorySlotInterfaces.h"
16+
#include "llvm/Support/LogicalResult.h"
17+
#include <cstdint>
1518

1619
using namespace mlir;
1720
using namespace test;
@@ -1230,6 +1233,53 @@ void TestVersionedOpA::writeProperties(mlir::DialectBytecodeWriter &writer) {
12301233
writer.writeAttribute(prop.modifier);
12311234
}
12321235

1236+
//===----------------------------------------------------------------------===//
1237+
// TestVersionedOpD
1238+
//===----------------------------------------------------------------------===//
1239+
1240+
// LogicalResult
1241+
// TestVersionedOpD::readProperties(mlir::DialectBytecodeReader &reader,
1242+
// mlir::OperationState &state) {
1243+
// auto &prop = state.getOrAddProperties<Properties>();
1244+
// StringRef res;
1245+
// if (failed(reader.readString(res)))
1246+
// return failure();
1247+
// if (failed(reader.readAttribute(prop.attribute)))
1248+
// return failure();
1249+
1250+
// return success();
1251+
// }
1252+
1253+
// void TestVersionedOpD::writeProperties(mlir::DialectBytecodeWriter &writer) {
1254+
// auto &prop = getProperties();
1255+
// writer.writeOwnedString("version 1");
1256+
// writer.writeAttribute(prop.attribute);
1257+
// }
1258+
1259+
//===----------------------------------------------------------------------===//
1260+
// TestBytecodeFallbackOp
1261+
//===----------------------------------------------------------------------===//
1262+
1263+
void TestBytecodeFallbackOp::setOriginalOperationName(StringRef name,
1264+
OperationState &state) {
1265+
state.getOrAddProperties<Properties>().setOpname(
1266+
StringAttr::get(state.getContext(), name));
1267+
}
1268+
1269+
LogicalResult
1270+
TestBytecodeFallbackOp::readPropertiesBlob(ArrayRef<char> blob,
1271+
OperationState &state) {
1272+
state.getOrAddProperties<Properties>().bytecodeProperties =
1273+
DenseI8ArrayAttr::get(state.getContext(),
1274+
ArrayRef((const int8_t *)blob.data(), blob.size()));
1275+
return success();
1276+
}
1277+
1278+
ArrayRef<char> TestBytecodeFallbackOp::getPropertiesBlob() {
1279+
return ArrayRef((const char *)getBytecodeProperties().data(),
1280+
getBytecodeProperties().size());
1281+
}
1282+
12331283
//===----------------------------------------------------------------------===//
12341284
// TestOpWithVersionedProperties
12351285
//===----------------------------------------------------------------------===//

mlir/test/lib/Dialect/Test/TestOps.td

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ include "mlir/IR/OpAsmInterface.td"
2020
include "mlir/IR/PatternBase.td"
2121
include "mlir/IR/RegionKindInterface.td"
2222
include "mlir/IR/SymbolInterfaces.td"
23+
include "mlir/Bytecode/BytecodeOpInterface.td"
2324
include "mlir/Interfaces/CallInterfaces.td"
2425
include "mlir/Interfaces/ControlFlowInterfaces.td"
2526
include "mlir/Interfaces/CopyOpInterface.td"
@@ -31,7 +32,6 @@ include "mlir/Interfaces/LoopLikeInterface.td"
3132
include "mlir/Interfaces/MemorySlotInterfaces.td"
3233
include "mlir/Interfaces/SideEffectInterfaces.td"
3334

34-
3535
// Include the attribute definitions.
3636
include "TestAttrDefs.td"
3737
// Include the type definitions.
@@ -3030,6 +3030,25 @@ def TestVersionedOpC : TEST_Op<"versionedC"> {
30303030
);
30313031
}
30323032

3033+
// def TestVersionedOpD : TEST_Op<"versionedD"> {
3034+
// let arguments = (ins AnyAttrOf<[TestAttrParams,
3035+
// I32ElementsAttr]>:$attribute
3036+
// );
3037+
3038+
// let useCustomPropertiesEncoding = 1;
3039+
// }
3040+
3041+
def TestBytecodeFallbackOp : TEST_Op<"bytecode.fallback", [
3042+
DeclareOpInterfaceMethods<FallbackBytecodeOpInterface, ["setOriginalOperationName", "readPropertiesBlob", "getPropertiesBlob"]>
3043+
]> {
3044+
let arguments = (ins
3045+
StrAttr:$opname,
3046+
DenseI8ArrayAttr:$bytecodeProperties,
3047+
Variadic<AnyType>:$operands);
3048+
let regions = (region VariadicRegion<AnyRegion>:$bodyRegions);
3049+
let results = (outs Variadic<AnyType>:$results);
3050+
}
3051+
30333052
//===----------------------------------------------------------------------===//
30343053
// Test Properties
30353054
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)