From 9db2dc4c039e87572433c9cd714c92574ea7c74a Mon Sep 17 00:00:00 2001 From: Nikhil Kalra Date: Tue, 4 Feb 2025 10:10:49 -0800 Subject: [PATCH 1/3] [mlir] Python: Parse ModuleOp from file path For extremely large models, it may be inefficient to load the model into memory in Python prior to passing it to the MLIR C APIs for deserialization. This change adds an API to parse a ModuleOp directly from a file path. --- mlir/include/mlir-c/IR.h | 4 ++++ mlir/include/mlir/Bindings/Python/Nanobind.h | 1 + mlir/lib/Bindings/Python/IRCore.cpp | 16 +++++++++++++++- mlir/lib/CAPI/IR/IR.cpp | 10 ++++++++++ mlir/python/mlir/_mlir_libs/_mlir/ir.pyi | 3 ++- 5 files changed, 32 insertions(+), 2 deletions(-) diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index 7d2fd89e8560f..14ccae650606a 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -309,6 +309,10 @@ MLIR_CAPI_EXPORTED MlirModule mlirModuleCreateEmpty(MlirLocation location); MLIR_CAPI_EXPORTED MlirModule mlirModuleCreateParse(MlirContext context, MlirStringRef module); +/// Parses a module from file and transfers ownership to the caller. +MLIR_CAPI_EXPORTED MlirModule +mlirModuleCreateParseFromFile(MlirContext context, MlirStringRef fileName); + /// Gets the context that a module was created with. MLIR_CAPI_EXPORTED MlirContext mlirModuleGetContext(MlirModule module); diff --git a/mlir/include/mlir/Bindings/Python/Nanobind.h b/mlir/include/mlir/Bindings/Python/Nanobind.h index ca942c83d3e2f..bc8bddf4caf7e 100644 --- a/mlir/include/mlir/Bindings/Python/Nanobind.h +++ b/mlir/include/mlir/Bindings/Python/Nanobind.h @@ -23,6 +23,7 @@ #endif #include #include +#include #include #include #include diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 8e351cb22eb94..b772c9a583a6b 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -6,6 +6,7 @@ // //===----------------------------------------------------------------------===// +#include #include #include @@ -299,7 +300,7 @@ struct PyAttrBuilderMap { return *builder; } static void dunderSetItemNamed(const std::string &attributeKind, - nb::callable func, bool replace) { + nb::callable func, bool replace) { PyGlobals::get().registerAttributeBuilder(attributeKind, std::move(func), replace); } @@ -3047,6 +3048,19 @@ void mlir::python::populateIRCore(nb::module_ &m) { }, nb::arg("asm"), nb::arg("context").none() = nb::none(), kModuleParseDocstring) + .def_static( + "parse", + [](const std::filesystem::path &path, + DefaultingPyMlirContext context) { + PyMlirContext::ErrorCapture errors(context->getRef()); + MlirModule module = mlirModuleCreateParseFromFile( + context->get(), toMlirStringRef(path.string())); + if (mlirModuleIsNull(module)) + throw MLIRError("Unable to parse module assembly", errors.take()); + return PyModule::forModule(module).releaseObject(); + }, + nb::arg("asm"), nb::arg("context").none() = nb::none(), + kModuleParseDocstring) .def_static( "create", [](DefaultingPyLocation loc) { diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index f27af0ca9a2c7..999e8cbda1295 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -22,6 +22,7 @@ #include "mlir/IR/Location.h" #include "mlir/IR/Operation.h" #include "mlir/IR/OperationSupport.h" +#include "mlir/IR/OwningOpRef.h" #include "mlir/IR/Types.h" #include "mlir/IR/Value.h" #include "mlir/IR/Verifier.h" @@ -328,6 +329,15 @@ MlirModule mlirModuleCreateParse(MlirContext context, MlirStringRef module) { return MlirModule{owning.release().getOperation()}; } +MlirModule mlirModuleCreateParseFromFile(MlirContext context, + MlirStringRef fileName) { + OwningOpRef owning = + parseSourceFile(unwrap(fileName), unwrap(context)); + if (!owning) + return MlirModule{nullptr}; + return MlirModule{owning.release().getOperation()}; +} + MlirContext mlirModuleGetContext(MlirModule module) { return wrap(unwrap(module).getContext()); } diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi index fb7efb8cd28a5..096b87b362443 100644 --- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi +++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi @@ -46,6 +46,7 @@ import abc import collections from collections.abc import Callable, Sequence import io +from pathlib import Path from typing import Any, ClassVar, TypeVar, overload __all__ = [ @@ -2123,7 +2124,7 @@ class Module: Creates an empty module """ @staticmethod - def parse(asm: str | bytes, context: Context | None = None) -> Module: + def parse(asm: str | bytes | Path, context: Context | None = None) -> Module: """ Parses a module's assembly format from a string. From b2e0c9b6e0b4f64f7f6ed426c45d542e5d478f1d Mon Sep 17 00:00:00 2001 From: Nikhil Kalra Date: Tue, 4 Feb 2025 16:25:12 -0800 Subject: [PATCH 2/3] add a test --- mlir/test/python/ir/module.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/mlir/test/python/ir/module.py b/mlir/test/python/ir/module.py index ecafcb46af217..d0ef69f39b4e3 100644 --- a/mlir/test/python/ir/module.py +++ b/mlir/test/python/ir/module.py @@ -1,6 +1,8 @@ # RUN: %PYTHON %s | FileCheck %s import gc +from pathlib import Path +from tempfile import NamedTemporaryFile from mlir.ir import * @@ -27,6 +29,24 @@ def testParseSuccess(): print(str(module)) +# Verify successful parse from file. +# CHECK-LABEL: TEST: testParseFromFileSuccess +# CHECK: module @successfulParse +@run +def testParseFromFileSuccess(): + ctx = Context() + with NamedTemporaryFile(mode="w") as tmp_file: + tmp_file.write(r"""module @successfulParse {}""") + tmp_file.flush() + module = Module.parse(Path(tmp_file.name), ctx) + assert module.context is ctx + print("CLEAR CONTEXT") + ctx = None # Ensure that module captures the context. + gc.collect() + module.dump() # Just outputs to stderr. Verifies that it functions. + print(str(module)) + + # Verify parse error. # CHECK-LABEL: TEST: testParseError # CHECK: testParseError: < From 7f78298dd022076237960c4a36af9d71f1af54d4 Mon Sep 17 00:00:00 2001 From: Nikhil Kalra Date: Tue, 4 Feb 2025 16:33:58 -0800 Subject: [PATCH 3/3] use verify instead --- mlir/test/python/ir/module.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/test/python/ir/module.py b/mlir/test/python/ir/module.py index d0ef69f39b4e3..441916b38ee73 100644 --- a/mlir/test/python/ir/module.py +++ b/mlir/test/python/ir/module.py @@ -43,7 +43,7 @@ def testParseFromFileSuccess(): print("CLEAR CONTEXT") ctx = None # Ensure that module captures the context. gc.collect() - module.dump() # Just outputs to stderr. Verifies that it functions. + module.operation.verify() print(str(module))