diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index d961482885300..7b790e90e0d87 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -97,6 +97,10 @@ static const char kOperationPrintDocstring[] = binary: Whether to write bytes (True) or str (False). Defaults to False. large_elements_limit: Whether to elide elements attributes above this number of elements. Defaults to None (no limit). + large_resource_limit: Whether to elide resource attributes above this + number of characters. Defaults to None (no limit). If large_elements_limit + is set and this is None, the behavior will be to use large_elements_limit + as large_resource_limit. enable_debug_info: Whether to print debug/location information. Defaults to False. pretty_debug_info: Whether to format debug information for easier reading @@ -1303,6 +1307,7 @@ void PyOperation::checkValid() const { } void PyOperationBase::print(std::optional largeElementsLimit, + std::optional largeResourceLimit, bool enableDebugInfo, bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope, bool useNameLocAsPrefix, bool assumeVerified, @@ -1314,10 +1319,10 @@ void PyOperationBase::print(std::optional largeElementsLimit, fileObject = nb::module_::import_("sys").attr("stdout"); MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate(); - if (largeElementsLimit) { + if (largeElementsLimit) mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit); - mlirOpPrintingFlagsElideLargeResourceString(flags, *largeElementsLimit); - } + if (largeResourceLimit) + mlirOpPrintingFlagsElideLargeResourceString(flags, *largeResourceLimit); if (enableDebugInfo) mlirOpPrintingFlagsEnableDebugInfo(flags, /*enable=*/true, /*prettyForm=*/prettyDebugInfo); @@ -1405,6 +1410,7 @@ void PyOperationBase::walk( nb::object PyOperationBase::getAsm(bool binary, std::optional largeElementsLimit, + std::optional largeResourceLimit, bool enableDebugInfo, bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope, bool useNameLocAsPrefix, bool assumeVerified, @@ -1416,6 +1422,7 @@ nb::object PyOperationBase::getAsm(bool binary, fileObject = nb::module_::import_("io").attr("StringIO")(); } print(/*largeElementsLimit=*/largeElementsLimit, + /*largeResourceLimit=*/largeResourceLimit, /*enableDebugInfo=*/enableDebugInfo, /*prettyDebugInfo=*/prettyDebugInfo, /*printGenericOpForm=*/printGenericOpForm, @@ -3348,6 +3355,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { [](PyOperationBase &self) { return self.getAsm(/*binary=*/false, /*largeElementsLimit=*/std::nullopt, + /*largeResourceLimit=*/std::nullopt, /*enableDebugInfo=*/false, /*prettyDebugInfo=*/false, /*printGenericOpForm=*/false, @@ -3363,11 +3371,12 @@ void mlir::python::populateIRCore(nb::module_ &m) { nb::arg("state"), nb::arg("file").none() = nb::none(), nb::arg("binary") = false, kOperationPrintStateDocstring) .def("print", - nb::overload_cast, bool, bool, bool, bool, - bool, bool, nb::object, bool, bool>( - &PyOperationBase::print), + nb::overload_cast, std::optional, + bool, bool, bool, bool, bool, bool, nb::object, + bool, bool>(&PyOperationBase::print), // Careful: Lots of arguments must match up with print method. nb::arg("large_elements_limit").none() = nb::none(), + nb::arg("large_resource_limit").none() = nb::none(), nb::arg("enable_debug_info") = false, nb::arg("pretty_debug_info") = false, nb::arg("print_generic_op_form") = false, @@ -3383,6 +3392,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { // Careful: Lots of arguments must match up with get_asm method. nb::arg("binary") = false, nb::arg("large_elements_limit").none() = nb::none(), + nb::arg("large_resource_limit").none() = nb::none(), nb::arg("enable_debug_info") = false, nb::arg("pretty_debug_info") = false, nb::arg("print_generic_op_form") = false, diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index 9befcce725bb7..0fdd2d1a7eff6 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -599,18 +599,18 @@ class PyOperationBase { public: virtual ~PyOperationBase() = default; /// Implements the bound 'print' method and helps with others. - void print(std::optional largeElementsLimit, bool enableDebugInfo, + void print(std::optional largeElementsLimit, + std::optional largeResourceLimit, bool enableDebugInfo, bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope, bool useNameLocAsPrefix, bool assumeVerified, nanobind::object fileObject, bool binary, bool skipRegions); void print(PyAsmState &state, nanobind::object fileObject, bool binary); - nanobind::object getAsm(bool binary, - std::optional largeElementsLimit, - bool enableDebugInfo, bool prettyDebugInfo, - bool printGenericOpForm, bool useLocalScope, - bool useNameLocAsPrefix, bool assumeVerified, - bool skipRegions); + nanobind::object + getAsm(bool binary, std::optional largeElementsLimit, + std::optional largeResourceLimit, bool enableDebugInfo, + bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope, + bool useNameLocAsPrefix, bool assumeVerified, bool skipRegions); // Implement the bound 'writeBytecode' method. void writeBytecode(const nanobind::object &fileObject, diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp index 8d84864b9db4d..20017e25b69bb 100644 --- a/mlir/lib/Bindings/Python/Pass.cpp +++ b/mlir/lib/Bindings/Python/Pass.cpp @@ -78,12 +78,19 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) { [](PyPassManager &passManager, bool printBeforeAll, bool printAfterAll, bool printModuleScope, bool printAfterChange, bool printAfterFailure, std::optional largeElementsLimit, - bool enableDebugInfo, bool printGenericOpForm, + std::optional largeResourceLimit, bool enableDebugInfo, + bool printGenericOpForm, std::optional optionalTreePrintingPath) { MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate(); - if (largeElementsLimit) + if (largeElementsLimit) { mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit); + mlirOpPrintingFlagsElideLargeResourceString(flags, + *largeElementsLimit); + } + if (largeResourceLimit) + mlirOpPrintingFlagsElideLargeResourceString(flags, + *largeResourceLimit); if (enableDebugInfo) mlirOpPrintingFlagsEnableDebugInfo(flags, /*enable=*/true, /*prettyForm=*/false); @@ -103,6 +110,7 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) { "print_module_scope"_a = false, "print_after_change"_a = false, "print_after_failure"_a = false, "large_elements_limit"_a.none() = nb::none(), + "large_resource_limit"_a.none() = nb::none(), "enable_debug_info"_a = false, "print_generic_op_form"_a = false, "tree_printing_dir_path"_a.none() = nb::none(), "Enable IR printing, default as mlir-print-ir-after-all.") diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi index ed476da28d6be..be71737e4b5b4 100644 --- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi +++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi @@ -200,6 +200,7 @@ class _OperationBase: def get_asm( binary: Literal[True], large_elements_limit: int | None = None, + large_resource_limit: int | None = None, enable_debug_info: bool = False, pretty_debug_info: bool = False, print_generic_op_form: bool = False, @@ -212,6 +213,7 @@ class _OperationBase: self, binary: bool = False, large_elements_limit: int | None = None, + large_resource_limit: int | None = None, enable_debug_info: bool = False, pretty_debug_info: bool = False, print_generic_op_form: bool = False, @@ -253,6 +255,7 @@ class _OperationBase: def print( self, large_elements_limit: int | None = None, + large_resource_limit: int | None = None, enable_debug_info: bool = False, pretty_debug_info: bool = False, print_generic_op_form: bool = False, @@ -270,6 +273,10 @@ class _OperationBase: binary: Whether to write bytes (True) or str (False). Defaults to False. large_elements_limit: Whether to elide elements attributes above this number of elements. Defaults to None (no limit). + large_resource_limit: Whether to elide resource strings above this + number of characters. Defaults to None (no limit). If large_elements_limit + is set and this is None, the behavior will be to use large_elements_limit + as large_resource_limit. enable_debug_info: Whether to print debug/location information. Defaults to False. pretty_debug_info: Whether to format debug information for easier reading diff --git a/mlir/python/mlir/_mlir_libs/_mlir/passmanager.pyi b/mlir/python/mlir/_mlir_libs/_mlir/passmanager.pyi index 0d2eaffe16d3e..1010daddae2aa 100644 --- a/mlir/python/mlir/_mlir_libs/_mlir/passmanager.pyi +++ b/mlir/python/mlir/_mlir_libs/_mlir/passmanager.pyi @@ -23,6 +23,7 @@ class PassManager: print_after_change: bool = False, print_after_failure: bool = False, large_elements_limit: int | None = None, + large_resource_limit: int | None = None, enable_debug_info: bool = False, print_generic_op_form: bool = False, tree_printing_dir_path: str | None = None, diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py index b08fe98397fbc..ede1571f940f6 100644 --- a/mlir/test/python/ir/operation.py +++ b/mlir/test/python/ir/operation.py @@ -686,6 +686,15 @@ def testOperationPrint(): skip_regions=True, ) + # Test print with large_resource_limit. + # CHECK: func.func @f1(%arg0: i32) -> i32 + # CHECK-NOT: resource1: "0x08 + module.operation.print(large_resource_limit=2) + + # Test large_elements_limit has no effect on resource string + # CHECK: func.func @f1(%arg0: i32) -> i32 + # CHECK: resource1: "0x08 + module.operation.print(large_elements_limit=2) # CHECK-LABEL: TEST: testKnownOpView @run diff --git a/mlir/test/python/pass_manager.py b/mlir/test/python/pass_manager.py index 85d2eb304882e..e26d42bb32913 100644 --- a/mlir/test/python/pass_manager.py +++ b/mlir/test/python/pass_manager.py @@ -363,6 +363,63 @@ def testPrintIrLargeLimitElements(): pm.run(module) +# CHECK-LABEL: TEST: testPrintIrLargeResourceLimit +@run +def testPrintIrLargeResourceLimit(): + with Context() as ctx: + module = ModuleOp.parse( + """ + module { + func.func @main() -> tensor<3xi64> { + %0 = arith.constant dense_resource : tensor<3xi64> + return %0 : tensor<3xi64> + } + } + {-# + dialect_resources: { + builtin: { + blob1: "0x010000000000000002000000000000000300000000000000" + } + } + #-} + """ + ) + pm = PassManager.parse("builtin.module(canonicalize)") + ctx.enable_multithreading(False) + pm.enable_ir_printing(large_resource_limit=4) + # CHECK-NOT: blob1: "0x01 + pm.run(module) + + +# CHECK-LABEL: TEST: testPrintIrLargeResourceLimitVsElementsLimit +@run +def testPrintIrLargeResourceLimitVsElementsLimit(): + """Test that large_elements_limit does not affect the printing of resources.""" + with Context() as ctx: + module = ModuleOp.parse( + """ + module { + func.func @main() -> tensor<3xi64> { + %0 = arith.constant dense_resource : tensor<3xi64> + return %0 : tensor<3xi64> + } + } + {-# + dialect_resources: { + builtin: { + blob1: "0x010000000000000002000000000000000300000000000000" + } + } + #-} + """ + ) + pm = PassManager.parse("builtin.module(canonicalize)") + ctx.enable_multithreading(False) + pm.enable_ir_printing(large_elements_limit=1) + # CHECK-NOT: blob1: "0x01 + pm.run(module) + + # CHECK-LABEL: TEST: testPrintIrTree @run def testPrintIrTree():