Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ include "mlir/Dialect/Transform/Interfaces/MatchInterfaces.td"
include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
include "mlir/Dialect/Transform/IR/TransformDialect.td"

def DebugEmitRemarkAtOp : TransformDialectOp<"debug.emit_remark_at",
def EmitRemarkAtOp : TransformDialectOp<"debug.emit_remark_at",
[MatchOpInterface,
DeclareOpInterfaceMethods<TransformOpInterface>,
MemoryEffectsOpInterface, NavigationTransformOpTrait]> {
Expand All @@ -39,7 +39,7 @@ def DebugEmitRemarkAtOp : TransformDialectOp<"debug.emit_remark_at",
let assemblyFormat = "$at `,` $message attr-dict `:` type($at)";
}

def DebugEmitParamAsRemarkOp
def EmitParamAsRemarkOp
: TransformDialectOp<"debug.emit_param_as_remark",
[MatchOpInterface,
DeclareOpInterfaceMethods<TransformOpInterface>,
Expand Down
13 changes: 7 additions & 6 deletions mlir/lib/Dialect/Transform/DebugExtension/DebugExtensionOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ using namespace mlir;
#include "mlir/Dialect/Transform/DebugExtension/DebugExtensionOps.cpp.inc"

DiagnosedSilenceableFailure
transform::DebugEmitRemarkAtOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results,
transform::TransformState &state) {
transform::EmitRemarkAtOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results,
transform::TransformState &state) {
if (isa<TransformHandleTypeInterface>(getAt().getType())) {
auto payload = state.getPayloadOps(getAt());
for (Operation *op : payload)
Expand Down Expand Up @@ -52,9 +52,10 @@ transform::DebugEmitRemarkAtOp::apply(transform::TransformRewriter &rewriter,
return DiagnosedSilenceableFailure::success();
}

DiagnosedSilenceableFailure transform::DebugEmitParamAsRemarkOp::apply(
transform::TransformRewriter &rewriter,
transform::TransformResults &results, transform::TransformState &state) {
DiagnosedSilenceableFailure
transform::EmitParamAsRemarkOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results,
transform::TransformState &state) {
std::string str;
llvm::raw_string_ostream os(str);
if (getMessage())
Expand Down
9 changes: 9 additions & 0 deletions mlir/python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,15 @@ ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
DIALECT_NAME transform
EXTENSION_NAME transform_pdl_extension)

declare_mlir_dialect_extension_python_bindings(
ADD_TO_PARENT MLIRPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
TD_FILE dialects/TransformDebugExtensionOps.td
SOURCES
dialects/transform/debug.py
DIALECT_NAME transform
EXTENSION_NAME transform_debug_extension)

declare_mlir_dialect_python_bindings(
ADD_TO_PARENT MLIRPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
Expand Down
19 changes: 19 additions & 0 deletions mlir/python/mlir/dialects/TransformDebugExtensionOps.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
//===-- TransformDebugExtensionOps.td - Binding entry point *- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Entry point of the generated Python bindings for the Debug extension of the
// Transform dialect.
//
//===----------------------------------------------------------------------===//

#ifndef PYTHON_BINDINGS_TRANSFORM_DEBUG_EXTENSION_OPS
#define PYTHON_BINDINGS_TRANSFORM_DEBUG_EXTENSION_OPS

include "mlir/Dialect/Transform/DebugExtension/DebugExtensionOps.td"

#endif // PYTHON_BINDINGS_TRANSFORM_DEBUG_EXTENSION_OPS
81 changes: 81 additions & 0 deletions mlir/python/mlir/dialects/transform/debug.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from typing import Optional

from ...ir import Attribute, Operation, Value, StringAttr
from .._transform_debug_extension_ops_gen import *
from .._transform_pdl_extension_ops_gen import _Dialect

try:
from .._ods_common import _cext as _ods_cext
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e

from typing import Union


@_ods_cext.register_operation(_Dialect, replace=True)
class EmitParamAsRemarkOp(EmitParamAsRemarkOp):
def __init__(
self,
param: Attribute,
*,
anchor: Optional[Operation] = None,
message: Optional[Union[StringAttr, str]] = None,
loc=None,
ip=None,
):
if isinstance(message, str):
message = StringAttr.get(message)

super().__init__(
param,
anchor=anchor,
message=message,
loc=loc,
ip=ip,
)


def emit_param_as_remark(
param: Attribute,
*,
anchor: Optional[Operation] = None,
message: Optional[Union[StringAttr, str]] = None,
loc=None,
ip=None,
):
return EmitParamAsRemarkOp(param, anchor=anchor, message=message, loc=loc, ip=ip)


@_ods_cext.register_operation(_Dialect, replace=True)
class EmitRemarkAtOp(EmitRemarkAtOp):
def __init__(
self,
at: Union[Operation, Value],
message: Optional[Union[StringAttr, str]] = None,
*,
loc=None,
ip=None,
):
if isinstance(message, str):
message = StringAttr.get(message)

super().__init__(
at,
message,
loc=loc,
ip=ip,
)


def emit_remark_at(
at: Union[Operation, Value],
message: Optional[Union[StringAttr, str]] = None,
*,
loc=None,
ip=None,
):
return EmitRemarkAtOp(at, message, loc=loc, ip=ip)
45 changes: 45 additions & 0 deletions mlir/test/python/dialects/transform_debug_ext.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# RUN: %PYTHON %s | FileCheck %s

from mlir.ir import *
from mlir.dialects import transform
from mlir.dialects.transform import debug


def run(f):
print("\nTEST:", f.__name__)
with Context(), Location.unknown():
module = Module.create()
with InsertionPoint(module.body):
sequence = transform.SequenceOp(
transform.FailurePropagationMode.Propagate,
[],
transform.AnyOpType.get(),
)
with InsertionPoint(sequence.body):
f(sequence.bodyTarget)
transform.YieldOp()
print(module)
return f


@run
def testDebugEmitParamAsRemark(target):
i0 = IntegerAttr.get(IntegerType.get_signless(32), 0)
i0_param = transform.ParamConstantOp(transform.AnyParamType.get(), i0)
debug.emit_param_as_remark(i0_param)
debug.emit_param_as_remark(i0_param, anchor=target, message="some text")
# CHECK-LABEL: TEST: testDebugEmitParamAsRemark
# CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op):
# CHECK: %[[PARAM:.*]] = transform.param.constant
# CHECK: transform.debug.emit_param_as_remark %[[PARAM]]
# CHECK: transform.debug.emit_param_as_remark %[[PARAM]]
# CHECK-SAME: "some text"
# CHECK-SAME: at %[[ARG0]]


@run
def testDebugEmitRemarkAtOp(target):
debug.emit_remark_at(target, "some text")
# CHECK-LABEL: TEST: testDebugEmitRemarkAtOp
# CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op):
# CHECK: transform.debug.emit_remark_at %[[ARG0]], "some text"
Loading