Skip to content

Commit d502b35

Browse files
committed
[MLIR][Transform][Python] expose transform.debug extension in Python
1 parent 050f628 commit d502b35

File tree

4 files changed

+161
-0
lines changed

4 files changed

+161
-0
lines changed

mlir/python/CMakeLists.txt

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,15 @@ ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
171171
DIALECT_NAME transform
172172
EXTENSION_NAME transform_pdl_extension)
173173

174+
declare_mlir_dialect_extension_python_bindings(
175+
ADD_TO_PARENT MLIRPythonSources.Dialects
176+
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
177+
TD_FILE dialects/TransformDebugExtensionOps.td
178+
SOURCES
179+
dialects/transform/debug.py
180+
DIALECT_NAME transform
181+
EXTENSION_NAME transform_debug_extension)
182+
174183
declare_mlir_dialect_python_bindings(
175184
ADD_TO_PARENT MLIRPythonSources.Dialects
176185
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
//===-- TransformDebugExtensionOps.td - Binding entry point *- tablegen -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Entry point of the generated Python bindings for the Debug extension of the
10+
// Transform dialect.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#ifndef PYTHON_BINDINGS_TRANSFORM_DEBUG_EXTENSION_OPS
15+
#define PYTHON_BINDINGS_TRANSFORM_DEBUG_EXTENSION_OPS
16+
17+
include "mlir/Dialect/Transform/DebugExtension/DebugExtensionOps.td"
18+
19+
#endif // PYTHON_BINDINGS_TRANSFORM_DEBUG_EXTENSION_OPS
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2+
# See https://llvm.org/LICENSE.txt for license information.
3+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4+
5+
from typing import Optional
6+
7+
from ...ir import Attribute, Operation, Value, StringAttr
8+
from .._transform_debug_extension_ops_gen import *
9+
from .._transform_pdl_extension_ops_gen import _Dialect
10+
11+
try:
12+
from .._ods_common import _cext as _ods_cext
13+
except ImportError as e:
14+
raise RuntimeError("Error loading imports from extension module") from e
15+
16+
from typing import Union
17+
18+
19+
@_ods_cext.register_operation(_Dialect, replace=True)
20+
class DebugEmitParamAsRemarkOp(DebugEmitParamAsRemarkOp):
21+
def __init__(
22+
self,
23+
param: Attribute,
24+
*,
25+
anchor: Optional[Operation] = None,
26+
message: Optional[Union[StringAttr, str]] = None,
27+
loc=None,
28+
ip=None,
29+
):
30+
if isinstance(message, str):
31+
message = StringAttr.get(message)
32+
33+
super().__init__(
34+
param,
35+
anchor=anchor,
36+
message=message,
37+
loc=loc,
38+
ip=ip,
39+
)
40+
41+
42+
def emit_param_as_remark(
43+
param: Attribute,
44+
*,
45+
anchor: Optional[Operation] = None,
46+
message: Optional[Union[StringAttr, str]] = None,
47+
loc=None,
48+
ip=None,
49+
):
50+
return DebugEmitParamAsRemarkOp(
51+
param, anchor=anchor, message=message, loc=loc, ip=ip
52+
)
53+
54+
del debug_emit_param_as_remark
55+
56+
@_ods_cext.register_operation(_Dialect, replace=True)
57+
class DebugEmitRemarkAtOp(DebugEmitRemarkAtOp):
58+
def __init__(
59+
self,
60+
at: Union[Operation, Value],
61+
message: Optional[Union[StringAttr, str]] = None,
62+
*,
63+
loc=None,
64+
ip=None,
65+
):
66+
if isinstance(message, str):
67+
message = StringAttr.get(message)
68+
69+
super().__init__(
70+
at,
71+
message,
72+
loc=loc,
73+
ip=ip,
74+
)
75+
76+
77+
def emit_remark_at(
78+
at: Union[Operation, Value],
79+
message: Optional[Union[StringAttr, str]] = None,
80+
*,
81+
loc=None,
82+
ip=None,
83+
):
84+
return DebugEmitRemarkAtOp(at, message, loc=loc, ip=ip)
85+
86+
del debug_emit_remark_at
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# RUN: %PYTHON %s | FileCheck %s
2+
3+
from mlir.ir import *
4+
from mlir.dialects import transform
5+
from mlir.dialects.transform import debug
6+
7+
8+
def run(f):
9+
print("\nTEST:", f.__name__)
10+
with Context(), Location.unknown():
11+
module = Module.create()
12+
with InsertionPoint(module.body):
13+
sequence = transform.SequenceOp(
14+
transform.FailurePropagationMode.Propagate,
15+
[],
16+
transform.AnyOpType.get(),
17+
)
18+
with InsertionPoint(sequence.body):
19+
f(sequence.bodyTarget)
20+
transform.YieldOp()
21+
print(module)
22+
return f
23+
24+
25+
@run
26+
def testDebugEmitParamAsRemark(target):
27+
i0 = IntegerAttr.get(IntegerType.get_signless(32), 0)
28+
i0_param = transform.ParamConstantOp(transform.AnyParamType.get(), i0)
29+
debug.emit_param_as_remark(i0_param)
30+
debug.emit_param_as_remark(i0_param, anchor=target, message="some text")
31+
# CHECK-LABEL: TEST: testDebugEmitParamAsRemark
32+
# CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op):
33+
# CHECK: %[[PARAM:.*]] = transform.param.constant
34+
# CHECK: transform.debug.emit_param_as_remark %[[PARAM]]
35+
# CHECK: transform.debug.emit_param_as_remark %[[PARAM]]
36+
# CHECK-SAME: "some text"
37+
# CHECK-SAME: at %[[ARG0]]
38+
39+
40+
@run
41+
def testDebugEmitRemarkAtOp(target):
42+
i0 = IntegerAttr.get(IntegerType.get_signless(32), 0)
43+
i0_param = transform.ParamConstantOp(transform.AnyParamType.get(), i0)
44+
debug.emit_remark_at(target, "some text")
45+
# CHECK-LABEL: TEST: testDebugEmitRemarkAtOp
46+
# CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op):
47+
# CHECK: transform.debug.emit_remark_at %[[ARG0]], "some text"

0 commit comments

Comments
 (0)