Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 16 additions & 6 deletions mlir/lib/Bindings/Python/IRCore.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,10 @@ static const char kOperationPrintDocstring[] =
binary: Whether to write bytes (True) or str (False). Defaults to False.
large_elements_limit: Whether to elide elements attributes above this
number of elements. Defaults to None (no limit).
large_resource_limit: Whether to elide resource attributes above this
number of characters. Defaults to None (no limit). If large_elements_limit
is set and this is None, the behavior will be to use large_elements_limit
as large_resource_limit.
enable_debug_info: Whether to print debug/location information. Defaults
to False.
pretty_debug_info: Whether to format debug information for easier reading
Expand Down Expand Up @@ -1303,6 +1307,7 @@ void PyOperation::checkValid() const {
}

void PyOperationBase::print(std::optional<int64_t> largeElementsLimit,
std::optional<int64_t> largeResourceLimit,
bool enableDebugInfo, bool prettyDebugInfo,
bool printGenericOpForm, bool useLocalScope,
bool useNameLocAsPrefix, bool assumeVerified,
Expand All @@ -1314,10 +1319,10 @@ void PyOperationBase::print(std::optional<int64_t> largeElementsLimit,
fileObject = nb::module_::import_("sys").attr("stdout");

MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
if (largeElementsLimit) {
if (largeElementsLimit)
mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit);
mlirOpPrintingFlagsElideLargeResourceString(flags, *largeElementsLimit);
}
if (largeResourceLimit)
mlirOpPrintingFlagsElideLargeResourceString(flags, *largeResourceLimit);
if (enableDebugInfo)
mlirOpPrintingFlagsEnableDebugInfo(flags, /*enable=*/true,
/*prettyForm=*/prettyDebugInfo);
Expand Down Expand Up @@ -1405,6 +1410,7 @@ void PyOperationBase::walk(

nb::object PyOperationBase::getAsm(bool binary,
std::optional<int64_t> largeElementsLimit,
std::optional<int64_t> largeResourceLimit,
bool enableDebugInfo, bool prettyDebugInfo,
bool printGenericOpForm, bool useLocalScope,
bool useNameLocAsPrefix, bool assumeVerified,
Expand All @@ -1416,6 +1422,7 @@ nb::object PyOperationBase::getAsm(bool binary,
fileObject = nb::module_::import_("io").attr("StringIO")();
}
print(/*largeElementsLimit=*/largeElementsLimit,
/*largeResourceLimit=*/largeResourceLimit,
/*enableDebugInfo=*/enableDebugInfo,
/*prettyDebugInfo=*/prettyDebugInfo,
/*printGenericOpForm=*/printGenericOpForm,
Expand Down Expand Up @@ -3348,6 +3355,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
[](PyOperationBase &self) {
return self.getAsm(/*binary=*/false,
/*largeElementsLimit=*/std::nullopt,
/*largeResourceLimit=*/std::nullopt,
/*enableDebugInfo=*/false,
/*prettyDebugInfo=*/false,
/*printGenericOpForm=*/false,
Expand All @@ -3363,11 +3371,12 @@ void mlir::python::populateIRCore(nb::module_ &m) {
nb::arg("state"), nb::arg("file").none() = nb::none(),
nb::arg("binary") = false, kOperationPrintStateDocstring)
.def("print",
nb::overload_cast<std::optional<int64_t>, bool, bool, bool, bool,
bool, bool, nb::object, bool, bool>(
&PyOperationBase::print),
nb::overload_cast<std::optional<int64_t>, std::optional<int64_t>,
bool, bool, bool, bool, bool, bool, nb::object,
bool, bool>(&PyOperationBase::print),
// Careful: Lots of arguments must match up with print method.
nb::arg("large_elements_limit").none() = nb::none(),
nb::arg("large_resource_limit").none() = nb::none(),
nb::arg("enable_debug_info") = false,
nb::arg("pretty_debug_info") = false,
nb::arg("print_generic_op_form") = false,
Expand All @@ -3383,6 +3392,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
// Careful: Lots of arguments must match up with get_asm method.
nb::arg("binary") = false,
nb::arg("large_elements_limit").none() = nb::none(),
nb::arg("large_resource_limit").none() = nb::none(),
nb::arg("enable_debug_info") = false,
nb::arg("pretty_debug_info") = false,
nb::arg("print_generic_op_form") = false,
Expand Down
14 changes: 7 additions & 7 deletions mlir/lib/Bindings/Python/IRModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -599,18 +599,18 @@ class PyOperationBase {
public:
virtual ~PyOperationBase() = default;
/// Implements the bound 'print' method and helps with others.
void print(std::optional<int64_t> largeElementsLimit, bool enableDebugInfo,
void print(std::optional<int64_t> largeElementsLimit,
std::optional<int64_t> largeResourceLimit, bool enableDebugInfo,
bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope,
bool useNameLocAsPrefix, bool assumeVerified,
nanobind::object fileObject, bool binary, bool skipRegions);
void print(PyAsmState &state, nanobind::object fileObject, bool binary);

nanobind::object getAsm(bool binary,
std::optional<int64_t> largeElementsLimit,
bool enableDebugInfo, bool prettyDebugInfo,
bool printGenericOpForm, bool useLocalScope,
bool useNameLocAsPrefix, bool assumeVerified,
bool skipRegions);
nanobind::object
getAsm(bool binary, std::optional<int64_t> largeElementsLimit,
std::optional<int64_t> largeResourceLimit, bool enableDebugInfo,
bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope,
bool useNameLocAsPrefix, bool assumeVerified, bool skipRegions);

// Implement the bound 'writeBytecode' method.
void writeBytecode(const nanobind::object &fileObject,
Expand Down
12 changes: 10 additions & 2 deletions mlir/lib/Bindings/Python/Pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,12 +78,19 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
[](PyPassManager &passManager, bool printBeforeAll,
bool printAfterAll, bool printModuleScope, bool printAfterChange,
bool printAfterFailure, std::optional<int64_t> largeElementsLimit,
bool enableDebugInfo, bool printGenericOpForm,
std::optional<int64_t> largeResourceLimit, bool enableDebugInfo,
bool printGenericOpForm,
std::optional<std::string> optionalTreePrintingPath) {
MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
if (largeElementsLimit)
if (largeElementsLimit) {
mlirOpPrintingFlagsElideLargeElementsAttrs(flags,
*largeElementsLimit);
mlirOpPrintingFlagsElideLargeResourceString(flags,
*largeElementsLimit);
}
if (largeResourceLimit)
mlirOpPrintingFlagsElideLargeResourceString(flags,
*largeResourceLimit);
if (enableDebugInfo)
mlirOpPrintingFlagsEnableDebugInfo(flags, /*enable=*/true,
/*prettyForm=*/false);
Expand All @@ -103,6 +110,7 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
"print_module_scope"_a = false, "print_after_change"_a = false,
"print_after_failure"_a = false,
"large_elements_limit"_a.none() = nb::none(),
"large_resource_limit"_a.none() = nb::none(),
"enable_debug_info"_a = false, "print_generic_op_form"_a = false,
"tree_printing_dir_path"_a.none() = nb::none(),
"Enable IR printing, default as mlir-print-ir-after-all.")
Expand Down
7 changes: 7 additions & 0 deletions mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ class _OperationBase:
def get_asm(
binary: Literal[True],
large_elements_limit: int | None = None,
large_resource_limit: int | None = None,
enable_debug_info: bool = False,
pretty_debug_info: bool = False,
print_generic_op_form: bool = False,
Expand All @@ -212,6 +213,7 @@ class _OperationBase:
self,
binary: bool = False,
large_elements_limit: int | None = None,
large_resource_limit: int | None = None,
enable_debug_info: bool = False,
pretty_debug_info: bool = False,
print_generic_op_form: bool = False,
Expand Down Expand Up @@ -253,6 +255,7 @@ class _OperationBase:
def print(
self,
large_elements_limit: int | None = None,
large_resource_limit: int | None = None,
enable_debug_info: bool = False,
pretty_debug_info: bool = False,
print_generic_op_form: bool = False,
Expand All @@ -270,6 +273,10 @@ class _OperationBase:
binary: Whether to write bytes (True) or str (False). Defaults to False.
large_elements_limit: Whether to elide elements attributes above this
number of elements. Defaults to None (no limit).
large_resource_limit: Whether to elide resource strings above this
number of characters. Defaults to None (no limit). If large_elements_limit
is set and this is None, the behavior will be to use large_elements_limit
as large_resource_limit.
enable_debug_info: Whether to print debug/location information. Defaults
to False.
pretty_debug_info: Whether to format debug information for easier reading
Expand Down
1 change: 1 addition & 0 deletions mlir/python/mlir/_mlir_libs/_mlir/passmanager.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ class PassManager:
print_after_change: bool = False,
print_after_failure: bool = False,
large_elements_limit: int | None = None,
large_resource_limit: int | None = None,
enable_debug_info: bool = False,
print_generic_op_form: bool = False,
tree_printing_dir_path: str | None = None,
Expand Down
9 changes: 9 additions & 0 deletions mlir/test/python/ir/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -686,6 +686,15 @@ def testOperationPrint():
skip_regions=True,
)

# Test print with large_resource_limit.
# CHECK: func.func @f1(%arg0: i32) -> i32
# CHECK-NOT: resource1: "0x08
module.operation.print(large_resource_limit=2)

# Test large_elements_limit has no effect on resource string
# CHECK: func.func @f1(%arg0: i32) -> i32
# CHECK: resource1: "0x08
module.operation.print(large_elements_limit=2)

# CHECK-LABEL: TEST: testKnownOpView
@run
Expand Down
57 changes: 57 additions & 0 deletions mlir/test/python/pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,63 @@ def testPrintIrLargeLimitElements():
pm.run(module)


# CHECK-LABEL: TEST: testPrintIrLargeResourceLimit
@run
def testPrintIrLargeResourceLimit():
with Context() as ctx:
module = ModuleOp.parse(
"""
module {
func.func @main() -> tensor<3xi64> {
%0 = arith.constant dense_resource<blob1> : tensor<3xi64>
return %0 : tensor<3xi64>
}
}
{-#
dialect_resources: {
builtin: {
blob1: "0x010000000000000002000000000000000300000000000000"
}
}
#-}
"""
)
pm = PassManager.parse("builtin.module(canonicalize)")
ctx.enable_multithreading(False)
pm.enable_ir_printing(large_resource_limit=4)
# CHECK-NOT: blob1: "0x01
pm.run(module)


# CHECK-LABEL: TEST: testPrintIrLargeResourceLimitVsElementsLimit
@run
def testPrintIrLargeResourceLimitVsElementsLimit():
"""Test that large_elements_limit does not affect the printing of resources."""
with Context() as ctx:
module = ModuleOp.parse(
"""
module {
func.func @main() -> tensor<3xi64> {
%0 = arith.constant dense_resource<blob1> : tensor<3xi64>
return %0 : tensor<3xi64>
}
}
{-#
dialect_resources: {
builtin: {
blob1: "0x010000000000000002000000000000000300000000000000"
}
}
#-}
"""
)
pm = PassManager.parse("builtin.module(canonicalize)")
ctx.enable_multithreading(False)
pm.enable_ir_printing(large_elements_limit=1)
# CHECK-NOT: blob1: "0x01
pm.run(module)


# CHECK-LABEL: TEST: testPrintIrTree
@run
def testPrintIrTree():
Expand Down