Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 32 additions & 10 deletions mlir/lib/Bindings/Python/IRCore.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,24 @@
//===----------------------------------------------------------------------===//

#include <optional>
#include <system_error>
#include <utility>

#include "Globals.h"
#include "IRModule.h"
#include "NanobindUtils.h"
#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/Debug.h"
#include "mlir-c/Diagnostics.h"
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
#include "nanobind/nanobind.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/raw_ostream.h"

namespace nb = nanobind;
using namespace nb::literals;
Expand Down Expand Up @@ -1329,20 +1332,18 @@ void PyOperationBase::print(PyAsmState &state, nb::object fileObject,
accum.getUserData());
}

void PyOperationBase::writeBytecode(const nb::object &fileObject,
std::optional<int64_t> bytecodeVersion) {
PyOperation &operation = getOperation();
operation.checkValid();
PyFileAccumulator accum(fileObject, /*binary=*/true);

template <typename T>
static void
writeBytecodeForOperation(T &accumulator, MlirOperation operation,
const std::optional<int64_t> &bytecodeVersion) {
if (!bytecodeVersion.has_value())
return mlirOperationWriteBytecode(operation, accum.getCallback(),
accum.getUserData());
return mlirOperationWriteBytecode(operation, accumulator.getCallback(),
accumulator.getUserData());

MlirBytecodeWriterConfig config = mlirBytecodeWriterConfigCreate();
mlirBytecodeWriterConfigDesiredEmitVersion(config, *bytecodeVersion);
MlirLogicalResult res = mlirOperationWriteBytecodeWithConfig(
operation, config, accum.getCallback(), accum.getUserData());
operation, config, accumulator.getCallback(), accumulator.getUserData());
mlirBytecodeWriterConfigDestroy(config);
if (mlirLogicalResultIsFailure(res))
throw nb::value_error((Twine("Unable to honor desired bytecode version ") +
Expand All @@ -1351,6 +1352,27 @@ void PyOperationBase::writeBytecode(const nb::object &fileObject,
.c_str());
}

void PyOperationBase::writeBytecode(const nb::object &fileObject,
std::optional<int64_t> bytecodeVersion) {
PyOperation &operation = getOperation();
operation.checkValid();

std::string filePath;
if (nb::try_cast<std::string>(fileObject, filePath)) {
std::error_code ec;
llvm::raw_fd_ostream ostream(filePath, ec);
if (ec) {
throw nb::value_error("Unable to open file for writing");
}

OstreamAccumulator accum(ostream);
writeBytecodeForOperation(accum, operation, bytecodeVersion);
} else {
PyFileAccumulator accum(fileObject, /*binary=*/true);
writeBytecodeForOperation(accum, operation, bytecodeVersion);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would this logic be something we could just move to the PyFileAccumulator itself so that every usage benefits from this optimization instead of specializing it everywhere?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure -- done!

}
}

void PyOperationBase::walk(
std::function<MlirWalkResult(MlirOperation)> callback,
MlirWalkOrder walkOrder) {
Expand Down
22 changes: 21 additions & 1 deletion mlir/lib/Bindings/Python/NanobindUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@
#include "mlir-c/Support.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/DataTypes.h"
#include "llvm/Support/raw_ostream.h"

template <>
struct std::iterator_traits<nanobind::detail::fast_iterator> {
Expand Down Expand Up @@ -128,7 +130,7 @@ struct PyPrintAccumulator {
}
};

/// Accumulates int a python file-like object, either writing text (default)
/// Accumulates into a python file-like object, either writing text (default)
/// or binary.
class PyFileAccumulator {
public:
Expand Down Expand Up @@ -158,6 +160,24 @@ class PyFileAccumulator {
bool binary;
};

/// Accumulates into a LLVM ostream.
class OstreamAccumulator {
public:
OstreamAccumulator(llvm::raw_ostream &ostream) : ostream(ostream) {}

void *getUserData() { return this; }

MlirStringCallback getCallback() {
return [](MlirStringRef part, void *userData) {
OstreamAccumulator *accum = static_cast<OstreamAccumulator *>(userData);
accum->ostream << llvm::StringRef(part.data, part.length);
};
}

private:
llvm::raw_ostream &ostream;
};

/// Accumulates into a python string from a method that is expected to make
/// one (no more, no less) call to the callback (asserts internally on
/// violation).
Expand Down
6 changes: 3 additions & 3 deletions mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ import collections
from collections.abc import Callable, Sequence
import io
from pathlib import Path
from typing import Any, ClassVar, TypeVar, overload
from typing import Any, BinaryIO, ClassVar, TypeVar, overload

__all__ = [
"AffineAddExpr",
Expand Down Expand Up @@ -285,12 +285,12 @@ class _OperationBase:
"""
Verify the operation. Raises MLIRError if verification fails, and returns true otherwise.
"""
def write_bytecode(self, file: Any, desired_version: int | None = None) -> None:
def write_bytecode(self, file: BinaryIO | str, desired_version: int | None = None) -> None:
"""
Write the bytecode form of the operation to a file like object.

Args:
file: The file like object to write to.
file: The file like object or path to write to.
desired_version: The version of bytecode to emit.
Returns:
The bytecode writer status.
Expand Down
7 changes: 7 additions & 0 deletions mlir/test/python/ir/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import gc
import io
import itertools
from tempfile import NamedTemporaryFile
from mlir.ir import *
from mlir.dialects.builtin import ModuleOp
from mlir.dialects import arith
Expand Down Expand Up @@ -617,6 +618,12 @@ def testOperationPrint():
module.operation.write_bytecode(bytecode_stream, desired_version=1)
bytecode = bytecode_stream.getvalue()
assert bytecode.startswith(b"ML\xefR"), "Expected bytecode to start with MLïR"
with NamedTemporaryFile() as tmpfile:
module.operation.write_bytecode(str(tmpfile.name), desired_version=1)
tmpfile.seek(0)
assert tmpfile.read().startswith(
b"ML\xefR"
), "Expected bytecode to start with MLïR"
ctx2 = Context()
module_roundtrip = Module.parse(bytecode, ctx2)
f = io.StringIO()
Expand Down