From f364860b81b52e7dc1539ac71030fe4ebe581b97 Mon Sep 17 00:00:00 2001 From: MaheshRavishankar Date: Mon, 3 Mar 2025 12:50:20 -0600 Subject: [PATCH 1/6] [mlir] Add a utility method to move operation dependencies. The added utility method moves all SSA values that an operation depends upon before an insertion point. This is useful during transformations where such movements might make transformations (like fusion) more powerful. To test the operation add a transform dialect op that calls the move operation. Signed-off-by: MaheshRavishankar --- mlir/include/mlir/Transforms/RegionUtils.h | 10 ++ mlir/lib/Transforms/Utils/RegionUtils.cpp | 65 ++++++++++ mlir/test/Transforms/move-operation-deps.mlir | 113 ++++++++++++++++++ mlir/test/lib/Transforms/CMakeLists.txt | 8 ++ .../test/lib/Transforms/TestTransformsOps.cpp | 66 ++++++++++ mlir/test/lib/Transforms/TestTransformsOps.td | 41 +++++++ mlir/tools/mlir-opt/mlir-opt.cpp | 2 + 7 files changed, 305 insertions(+) create mode 100644 mlir/test/Transforms/move-operation-deps.mlir create mode 100644 mlir/test/lib/Transforms/TestTransformsOps.cpp create mode 100644 mlir/test/lib/Transforms/TestTransformsOps.td diff --git a/mlir/include/mlir/Transforms/RegionUtils.h b/mlir/include/mlir/Transforms/RegionUtils.h index 5c57dd5b7532a..4acc8528efe97 100644 --- a/mlir/include/mlir/Transforms/RegionUtils.h +++ b/mlir/include/mlir/Transforms/RegionUtils.h @@ -15,6 +15,7 @@ #include "llvm/ADT/SetVector.h" namespace mlir { +class DominanceInfo; class RewriterBase; /// Check if all values in the provided range are defined above the `limit` @@ -69,6 +70,15 @@ SmallVector makeRegionIsolatedFromAbove( llvm::function_ref cloneOperationIntoRegion = [](Operation *) { return false; }); +/// Move SSA values used within an operation before an insertion point, +/// so that the operation itself (or its replacement) can be moved to +/// the insertion point. +LogicalResult moveOperationDependencies(RewriterBase &rewriter, Operation *op, + Operation *insertionPoint, + DominanceInfo &dominance); +LogicalResult moveOperationDependencies(RewriterBase &rewriter, Operation *op, + Operation *insertionPoint); + /// Run a set of structural simplifications over the given regions. This /// includes transformations like unreachable block elimination, dead argument /// elimination, as well as some other DCE. This function returns success if any diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp index e55ef6eb66b9c..7040243bed83b 100644 --- a/mlir/lib/Transforms/Utils/RegionUtils.cpp +++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp @@ -7,9 +7,12 @@ //===----------------------------------------------------------------------===// #include "mlir/Transforms/RegionUtils.h" + +#include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Analysis/TopologicalSortUtils.h" #include "mlir/IR/Block.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Dominance.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/Operation.h" #include "mlir/IR/PatternMatch.h" @@ -1054,3 +1057,65 @@ LogicalResult mlir::simplifyRegions(RewriterBase &rewriter, return success(eliminatedBlocks || eliminatedOpsOrArgs || mergedIdenticalBlocks || droppedRedundantArguments); } + +//===---------------------------------------------------------------------===// +// Move operation dependencies +//===---------------------------------------------------------------------===// + +LogicalResult mlir::moveOperationDependencies(RewriterBase &rewriter, + Operation *op, + Operation *insertionPoint, + DominanceInfo &dominance) { + // Currently unsupported case where the op and insertion point are + // in different basic blocks. + if (op->getBlock() != insertionPoint->getBlock()) { + return rewriter.notifyMatchFailure( + op, "unsupported caes where operation and insertion point are not in " + "the sme basic block"); + } + + // Find the backward slice of operation for each `Value` the operation + // depends on. Prune the slice to only include operations not already + // dominated by the `insertionPoint` + BackwardSliceOptions options; + options.inclusive = true; + options.filter = [&](Operation *sliceBoundaryOp) { + return !dominance.properlyDominates(sliceBoundaryOp, insertionPoint); + }; + llvm::SetVector slice; + + // Get the defined slice for operands. + for (Value operand : op->getOperands()) { + getBackwardSlice(operand, &slice, options); + } + auto regions = op->getRegions(); + if (!regions.empty()) { + // If op has region, get the defined slice for all captured values. + llvm::SetVector capturedVals; + mlir::getUsedValuesDefinedAbove(regions, capturedVals); + for (auto value : capturedVals) { + getBackwardSlice(value, &slice, options); + } + } + + // If the slice contains `insertionPoint` cannot move the dependencies. + if (slice.contains(insertionPoint)) { + return rewriter.notifyMatchFailure( + op, + "cannot move dependencies before operation in backward slice of op"); + } + + // Sort the slice topologically, ad move in topological order. + mlir::topologicalSort(slice); + for (auto op : slice) { + rewriter.moveOpBefore(op, insertionPoint); + } + return success(); +} + +LogicalResult mlir::moveOperationDependencies(RewriterBase &rewriter, + Operation *op, + Operation *insertionPoint) { + DominanceInfo dominance(op); + return moveOperationDependencies(rewriter, op, insertionPoint, dominance); +} \ No newline at end of file diff --git a/mlir/test/Transforms/move-operation-deps.mlir b/mlir/test/Transforms/move-operation-deps.mlir new file mode 100644 index 0000000000000..90c66a0f14938 --- /dev/null +++ b/mlir/test/Transforms/move-operation-deps.mlir @@ -0,0 +1,113 @@ +// RUN: mlir-opt --allow-unregistered-dialect --transform-interpreter --split-input-file %s | FileCheck %s + +// Check simple move of dependent operation before insertion. +func.func @simple_move() -> f32 { + %0 = "before"() : () -> (f32) + %1 = "moved_op"() : () -> (f32) + %2 = "foo"(%1) : (f32) -> (f32) + return %2 : f32 +} +// CHECK-LABEL: func @simple_move() +// CHECK: %[[MOVED:.+]] = "moved_op" +// CHECK: %[[BEFORE:.+]] = "before" +// CHECK: %[[FOO:.+]] = "foo"(%[[MOVED]]) +// CHECK: return %[[FOO]] + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) { + %op1 = transform.structured.match ops{["foo"]} in %arg0 + : (!transform.any_op) -> !transform.any_op + %op2 = transform.structured.match ops{["before"]} in %arg0 + : (!transform.any_op) -> !transform.any_op + transform.test.move_operand_deps %op1 before %op2 + : !transform.any_op, !transform.any_op + transform.yield + } +} + +// ----- + +// Move operands that are implicitly captured by the op +func.func @move_region_dependencies() -> f32 { + %0 = "before"() : () -> (f32) + %1 = "moved_op"() : () -> (f32) + %2 = "foo"() ({ + %3 = "inner_op"(%1) : (f32) -> (f32) + "yield"(%3) : (f32) -> () + }) : () -> (f32) + return %2 : f32 +} +// CHECK-LABEL: func @move_region_dependencies() +// CHECK: %[[MOVED:.+]] = "moved_op" +// CHECK: %[[BEFORE:.+]] = "before" +// CHECK: %[[FOO:.+]] = "foo" +// CHECK: return %[[FOO]] + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) { + %op1 = transform.structured.match ops{["foo"]} in %arg0 + : (!transform.any_op) -> !transform.any_op + %op2 = transform.structured.match ops{["before"]} in %arg0 + : (!transform.any_op) -> !transform.any_op + transform.test.move_operand_deps %op1 before %op2 + : !transform.any_op, !transform.any_op + transform.yield + } +} + +// ----- + +// Move operations in toplogical sort order +func.func @move_in_topological_sort_order() -> f32 { + %0 = "before"() : () -> (f32) + %1 = "moved_op_1"() : () -> (f32) + %2 = "moved_op_2"() : () -> (f32) + %3 = "moved_op_3"(%1) : (f32) -> (f32) + %4 = "moved_op_4"(%1, %3) : (f32, f32) -> (f32) + %5 = "moved_op_5"(%2) : (f32) -> (f32) + %6 = "foo"(%4, %5) : (f32, f32) -> (f32) + return %6 : f32 +} +// CHECK-LABEL: func @move_in_topological_sort_order() +// CHECK: %[[MOVED_1:.+]] = "moved_op_1" +// CHECK-DAG: %[[MOVED_2:.+]] = "moved_op_3"(%[[MOVED_1]]) +// CHECK-DAG: %[[MOVED_3:.+]] = "moved_op_4"(%[[MOVED_1]], %[[MOVED_2]]) +// CHECK-DAG: %[[MOVED_4:.+]] = "moved_op_2" +// CHECK-DAG: %[[MOVED_5:.+]] = "moved_op_5"(%[[MOVED_4]]) +// CHECK: %[[BEFORE:.+]] = "before" +// CHECK: %[[FOO:.+]] = "foo"(%[[MOVED_3]], %[[MOVED_5]]) +// CHECK: return %[[FOO]] + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) { + %op1 = transform.structured.match ops{["foo"]} in %arg0 + : (!transform.any_op) -> !transform.any_op + %op2 = transform.structured.match ops{["before"]} in %arg0 + : (!transform.any_op) -> !transform.any_op + transform.test.move_operand_deps %op1 before %op2 + : !transform.any_op, !transform.any_op + transform.yield + } +} + +// ----- + +// Fail when the "before" operation is part of the operation slice. +func.func @do_not_move_slice() -> f32 { + %0 = "before"() : () -> (f32) + %1 = "moved_op"(%0) : (f32) -> (f32) + %2 = "foo"(%1) : (f32) -> (f32) + return %2 : f32 +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) { + %op1 = transform.structured.match ops{["foo"]} in %arg0 + : (!transform.any_op) -> !transform.any_op + %op2 = transform.structured.match ops{["before"]} in %arg0 + : (!transform.any_op) -> !transform.any_op + transform.test.move_operand_deps %op1 before %op2 + : !transform.any_op, !transform.any_op + transform.yield + } +} diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt index 1b9b9bffa5279..c053fd4b20473 100644 --- a/mlir/test/lib/Transforms/CMakeLists.txt +++ b/mlir/test/lib/Transforms/CMakeLists.txt @@ -1,3 +1,8 @@ +set(LLVM_TARGET_DEFINITIONS TestTransformsOps.td) +mlir_tablegen(TestTransformsOps.h.inc -gen-op-decls) +mlir_tablegen(TestTransformsOps.cpp.inc -gen-op-defs) +add_public_tablegen_target(MLIRTestTransformsOpsIncGen) + set(LLVM_OPTIONAL_SOURCES TestDialectConversion.cpp) set(MLIRTestTransformsPDLDep) @@ -25,6 +30,7 @@ add_mlir_library(MLIRTestTransforms TestControlFlowSink.cpp TestInlining.cpp TestMakeIsolatedFromAbove.cpp + TestTransformsOps.cpp ${MLIRTestTransformsPDLSrc} EXCLUDE_FROM_LIBMLIR @@ -34,6 +40,7 @@ add_mlir_library(MLIRTestTransforms DEPENDS ${MLIRTestTransformsPDLDep} + MLIRTestTransformsOpsIncGen LINK_LIBS PUBLIC MLIRTestDialect @@ -43,6 +50,7 @@ mlir_target_link_libraries(MLIRTestTransforms PUBLIC MLIRFuncDialect MLIRInferIntRangeInterface MLIRTransforms + MLIRTransformDialect ) target_include_directories(MLIRTestTransforms diff --git a/mlir/test/lib/Transforms/TestTransformsOps.cpp b/mlir/test/lib/Transforms/TestTransformsOps.cpp new file mode 100644 index 0000000000000..427930b0c7ed1 --- /dev/null +++ b/mlir/test/lib/Transforms/TestTransformsOps.cpp @@ -0,0 +1,66 @@ +//===- TestTransformsOps.cpp - Test Transforms ----------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines transform dialect operations for testing MLIR +// transformations +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Transform/IR/TransformAttrs.h" +#include "mlir/Dialect/Transform/IR/TransformDialect.h" +#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" +#include "mlir/Transforms/RegionUtils.h" + +#define GET_OP_CLASSES +#include "TestTransformsOps.h.inc" + +using namespace mlir; +using namespace mlir::transform; + +#define GET_OP_CLASSES +#include "TestTransformsOps.cpp.inc" + +DiagnosedSilenceableFailure +transform::TestMoveOperandDeps::apply(TransformRewriter &rewriter, + TransformResults &TransformResults, + TransformState &state) { + Operation *op = *state.getPayloadOps(getOp()).begin(); + Operation *moveBefore = *state.getPayloadOps(getInsertionPoint()).begin(); + if (failed(moveOperationDependencies(rewriter, op, moveBefore))) { + auto listener = cast(rewriter.getListener()); + std::string errorMsg = listener->checkAndResetError().getMessage(); + return emitSilenceableFailure(op, errorMsg); + } + return DiagnosedSilenceableFailure::success(); +} + +namespace { + +class TestTransformsDialectExtension + : public transform::TransformDialectExtension< + TestTransformsDialectExtension> { +public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTransformsDialectExtension) + + using Base::Base; + + void init() { + registerTransformOps< +#define GET_OP_LIST +#include "TestTransformsOps.cpp.inc" + >(); + } +}; +} // namespace + +namespace test { +void registerTestTransformsTransformDialectExtension( + DialectRegistry ®istry) { + registry.addExtensions(); +} +} // namespace test \ No newline at end of file diff --git a/mlir/test/lib/Transforms/TestTransformsOps.td b/mlir/test/lib/Transforms/TestTransformsOps.td new file mode 100644 index 0000000000000..ef19d00f999c3 --- /dev/null +++ b/mlir/test/lib/Transforms/TestTransformsOps.td @@ -0,0 +1,41 @@ +//===- TestTransformOps.td ---------------------------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef TEST_TRANSFORM_OPS +#define TEST_TRANSFORM_OPS + +include "mlir/Dialect/Transform/IR/TransformDialect.td" +include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td" +include "mlir/Dialect/Transform/IR/TransformTypes.td" +include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/IR/OpBase.td" + +/// Transform dialect perations for testing transformations in MLIR + +def TestMoveOperandDeps : + Op, + ReportTrackingListenerFailuresOpTrait]> { + let description = [{ + Moves all dependencies of on operation before another operation. + }]; + + let arguments = + (ins TransformHandleTypeInterface:$op, + TransformHandleTypeInterface:$insertion_point); + + let results = (outs); + + let assemblyFormat = [{ + $op `before` $insertion_point attr-dict + `:` type($op) `,` type($insertion_point) + }]; +} + +#endif // TEST_TRANSFORM_OPS diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp index f18ad45dfb708..d06ff8070e7cf 100644 --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -170,6 +170,7 @@ void registerTestDialect(DialectRegistry &); void registerTestDynDialect(DialectRegistry &); void registerTestTilingInterfaceTransformDialectExtension(DialectRegistry &); void registerTestTransformDialectExtension(DialectRegistry &); +void registerTestTransformsTransformDialectExtension(DialectRegistry &); } // namespace test #ifdef MLIR_INCLUDE_TESTS @@ -323,6 +324,7 @@ int main(int argc, char **argv) { #ifdef MLIR_INCLUDE_TESTS ::test::registerTestDialect(registry); ::test::registerTestTransformDialectExtension(registry); + ::test::registerTestTransformsTransformDialectExtension(registry); ::test::registerTestTilingInterfaceTransformDialectExtension(registry); ::test::registerTestDynDialect(registry); #endif From d13fc99c128864420b3bd907ec8bd4bfb1224221 Mon Sep 17 00:00:00 2001 From: MaheshRavishankar Date: Thu, 6 Mar 2025 12:57:21 -0800 Subject: [PATCH 2/6] Address comments (round 1). Signed-off-by: MaheshRavishankar --- mlir/lib/Transforms/Utils/RegionUtils.cpp | 11 ++-- mlir/test/Transforms/move-operation-deps.mlir | 58 +++++++++++++++++++ .../test/lib/Transforms/TestTransformsOps.cpp | 2 +- mlir/test/lib/Transforms/TestTransformsOps.td | 2 +- 4 files changed, 66 insertions(+), 7 deletions(-) diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp index 7040243bed83b..3a17c97b3c982 100644 --- a/mlir/lib/Transforms/Utils/RegionUtils.cpp +++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp @@ -1071,7 +1071,7 @@ LogicalResult mlir::moveOperationDependencies(RewriterBase &rewriter, if (op->getBlock() != insertionPoint->getBlock()) { return rewriter.notifyMatchFailure( op, "unsupported caes where operation and insertion point are not in " - "the sme basic block"); + "the same basic block"); } // Find the backward slice of operation for each `Value` the operation @@ -1079,6 +1079,7 @@ LogicalResult mlir::moveOperationDependencies(RewriterBase &rewriter, // dominated by the `insertionPoint` BackwardSliceOptions options; options.inclusive = true; + options.omitUsesFromAbove = false; options.filter = [&](Operation *sliceBoundaryOp) { return !dominance.properlyDominates(sliceBoundaryOp, insertionPoint); }; @@ -1093,7 +1094,7 @@ LogicalResult mlir::moveOperationDependencies(RewriterBase &rewriter, // If op has region, get the defined slice for all captured values. llvm::SetVector capturedVals; mlir::getUsedValuesDefinedAbove(regions, capturedVals); - for (auto value : capturedVals) { + for (Value value : capturedVals) { getBackwardSlice(value, &slice, options); } } @@ -1105,9 +1106,9 @@ LogicalResult mlir::moveOperationDependencies(RewriterBase &rewriter, "cannot move dependencies before operation in backward slice of op"); } - // Sort the slice topologically, ad move in topological order. + // Sort the slice topologically, and move in topological order. mlir::topologicalSort(slice); - for (auto op : slice) { + for (Operation *op : slice) { rewriter.moveOpBefore(op, insertionPoint); } return success(); @@ -1118,4 +1119,4 @@ LogicalResult mlir::moveOperationDependencies(RewriterBase &rewriter, Operation *insertionPoint) { DominanceInfo dominance(op); return moveOperationDependencies(rewriter, op, insertionPoint, dominance); -} \ No newline at end of file +} diff --git a/mlir/test/Transforms/move-operation-deps.mlir b/mlir/test/Transforms/move-operation-deps.mlir index 90c66a0f14938..97f9f6a95cc84 100644 --- a/mlir/test/Transforms/move-operation-deps.mlir +++ b/mlir/test/Transforms/move-operation-deps.mlir @@ -92,6 +92,39 @@ module attributes {transform.with_named_sequence} { // ----- +func.func @move_region_dependencies() -> f32 { + %0 = "before"() : () -> (f32) + %1 = "moved_op_1"() : () -> (f32) + %2 = "moved_op"() ({ + "yield"(%1) : (f32) -> () + }) : () -> (f32) + %3 = "foo"() ({ + "yield"(%2) : (f32) -> () + }) : () -> (f32) + return %3 : f32 +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) { + %op1 = transform.structured.match ops{["foo"]} in %arg0 + : (!transform.any_op) -> !transform.any_op + %op2 = transform.structured.match ops{["before"]} in %arg0 + : (!transform.any_op) -> !transform.any_op + transform.test.move_operand_deps %op1 before %op2 + : !transform.any_op, !transform.any_op + transform.yield + } +} +// CHECK-LABEL: func @move_region_dependencies() +// CHECK: %[[MOVED_1:.+]] = "moved_op_1" +// CHECK: %[[MOVED_2:.+]] = "moved_op_2" +// CHECK: "yield"(%[[MOVED_1]]) +// CHECK: "before" +// CHECK: %[[FOO:.+]] = "foo" +// CHECK: return %[[FOO]] + +// ----- + // Fail when the "before" operation is part of the operation slice. func.func @do_not_move_slice() -> f32 { %0 = "before"() : () -> (f32) @@ -111,3 +144,28 @@ module attributes {transform.with_named_sequence} { transform.yield } } + +// ----- + +func.func @move_region_dependencies() -> f32 { + %0 = "before"() : () -> (f32) + %1 = "moved_op"() ({ + "yield"(%0) : (f32) -> () + }) : () -> (f32) + %2 = "foo"() ({ + "yield"(%1) : (f32) -> () + }) : () -> (f32) + return %2 : f32 +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) { + %op1 = transform.structured.match ops{["foo"]} in %arg0 + : (!transform.any_op) -> !transform.any_op + %op2 = transform.structured.match ops{["before"]} in %arg0 + : (!transform.any_op) -> !transform.any_op + transform.test.move_operand_deps %op1 before %op2 + : !transform.any_op, !transform.any_op + transform.yield + } +} diff --git a/mlir/test/lib/Transforms/TestTransformsOps.cpp b/mlir/test/lib/Transforms/TestTransformsOps.cpp index 427930b0c7ed1..de3e0163f5c45 100644 --- a/mlir/test/lib/Transforms/TestTransformsOps.cpp +++ b/mlir/test/lib/Transforms/TestTransformsOps.cpp @@ -63,4 +63,4 @@ void registerTestTransformsTransformDialectExtension( DialectRegistry ®istry) { registry.addExtensions(); } -} // namespace test \ No newline at end of file +} // namespace test diff --git a/mlir/test/lib/Transforms/TestTransformsOps.td b/mlir/test/lib/Transforms/TestTransformsOps.td index ef19d00f999c3..f514702cef5bc 100644 --- a/mlir/test/lib/Transforms/TestTransformsOps.td +++ b/mlir/test/lib/Transforms/TestTransformsOps.td @@ -15,7 +15,7 @@ include "mlir/Dialect/Transform/IR/TransformTypes.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/IR/OpBase.td" -/// Transform dialect perations for testing transformations in MLIR +/// Transform dialect operations for testing transformations in MLIR def TestMoveOperandDeps : Op Date: Thu, 6 Mar 2025 16:18:52 -0800 Subject: [PATCH 3/6] Modify `TransformRewriter` listener to get the match failure remark and use it to test failure in the op. Signed-off-by: MaheshRavishankar --- .../Interfaces/TransformInterfaces.h | 11 ++++++ mlir/include/mlir/Transforms/RegionUtils.h | 3 +- .../Interfaces/TransformInterfaces.cpp | 11 ++++++ mlir/lib/Transforms/Utils/RegionUtils.cpp | 8 +++++ mlir/test/Transforms/move-operation-deps.mlir | 35 +++++++++++++++++-- .../test/lib/Transforms/TestTransformsOps.cpp | 4 +-- 6 files changed, 66 insertions(+), 6 deletions(-) diff --git a/mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.h index e51aac02936b5..28046d1b8f2b0 100644 --- a/mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.h +++ b/mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.h @@ -1074,10 +1074,18 @@ class ErrorCheckingTrackingListener : public TrackingListener { /// resets the error state to "success". DiagnosedSilenceableFailure checkAndResetError(); + /// Return the latest match notification message. + std::string getLatestMatchFailureMessage(); + /// Return "true" if this tracking listener had a failure. bool failed() const; protected: + + void + notifyMatchFailure(Location loc, + function_ref reasonCallback) override; + void notifyPayloadReplacementNotFound(Operation *op, ValueRange values, DiagnosedSilenceableFailure &&diag) override; @@ -1089,6 +1097,9 @@ class ErrorCheckingTrackingListener : public TrackingListener { /// The number of errors that have been encountered. int64_t errorCounter = 0; + + /// Latest message from match failure notification. + std::string matchFailureMsg = ""; }; /// This is a special rewriter to be used in transform op implementations, diff --git a/mlir/include/mlir/Transforms/RegionUtils.h b/mlir/include/mlir/Transforms/RegionUtils.h index 4acc8528efe97..e6b928d8ebecc 100644 --- a/mlir/include/mlir/Transforms/RegionUtils.h +++ b/mlir/include/mlir/Transforms/RegionUtils.h @@ -72,7 +72,8 @@ SmallVector makeRegionIsolatedFromAbove( /// Move SSA values used within an operation before an insertion point, /// so that the operation itself (or its replacement) can be moved to -/// the insertion point. +/// the insertion point. Current support is only for movement of +/// dependencies of `op` before `insertionPoint` in the same basic block. LogicalResult moveOperationDependencies(RewriterBase &rewriter, Operation *op, Operation *insertionPoint, DominanceInfo &dominance); diff --git a/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp index 1e0ef5add358e..55cac471fb14c 100644 --- a/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp +++ b/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp @@ -1390,6 +1390,17 @@ void transform::ErrorCheckingTrackingListener::notifyPayloadReplacementNotFound( ++errorCounter; } +std::string transform::ErrorCheckingTrackingListener::getLatestMatchFailureMessage() { + return matchFailureMsg; +} + +void transform::ErrorCheckingTrackingListener::notifyMatchFailure( + Location loc, function_ref reasonCallback) { + Diagnostic diag(loc, DiagnosticSeverity::Remark); + reasonCallback(diag); + matchFailureMsg = diag.str(); +} + //===----------------------------------------------------------------------===// // TransformRewriter //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp index 3a17c97b3c982..516a773b03109 100644 --- a/mlir/lib/Transforms/Utils/RegionUtils.cpp +++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp @@ -1073,6 +1073,11 @@ LogicalResult mlir::moveOperationDependencies(RewriterBase &rewriter, op, "unsupported caes where operation and insertion point are not in " "the same basic block"); } + // If `insertionPoint` does not dominate `op`, do nothing + if (!dominance.properlyDominates(insertionPoint, op)) { + return rewriter.notifyMatchFailure(op, + "insertion point does not dominate op"); + } // Find the backward slice of operation for each `Value` the operation // depends on. Prune the slice to only include operations not already @@ -1080,6 +1085,9 @@ LogicalResult mlir::moveOperationDependencies(RewriterBase &rewriter, BackwardSliceOptions options; options.inclusive = true; options.omitUsesFromAbove = false; + // Since current support is to only move within a same basic block, + // the slices dont need to look past block arguments. + options.omitBlockArguments = true; options.filter = [&](Operation *sliceBoundaryOp) { return !dominance.properlyDominates(sliceBoundaryOp, insertionPoint); }; diff --git a/mlir/test/Transforms/move-operation-deps.mlir b/mlir/test/Transforms/move-operation-deps.mlir index 97f9f6a95cc84..7b81f843a430d 100644 --- a/mlir/test/Transforms/move-operation-deps.mlir +++ b/mlir/test/Transforms/move-operation-deps.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt --allow-unregistered-dialect --transform-interpreter --split-input-file %s | FileCheck %s +// RUN: mlir-opt --allow-unregistered-dialect --transform-interpreter --split-input-file --verify-diagnostics %s | FileCheck %s // Check simple move of dependent operation before insertion. func.func @simple_move() -> f32 { @@ -95,7 +95,7 @@ module attributes {transform.with_named_sequence} { func.func @move_region_dependencies() -> f32 { %0 = "before"() : () -> (f32) %1 = "moved_op_1"() : () -> (f32) - %2 = "moved_op"() ({ + %2 = "moved_op_2"() ({ "yield"(%1) : (f32) -> () }) : () -> (f32) %3 = "foo"() ({ @@ -139,6 +139,7 @@ module attributes {transform.with_named_sequence} { : (!transform.any_op) -> !transform.any_op %op2 = transform.structured.match ops{["before"]} in %arg0 : (!transform.any_op) -> !transform.any_op + // expected-remark@+1{{cannot move dependencies before operation in backward slice of op}} transform.test.move_operand_deps %op1 before %op2 : !transform.any_op, !transform.any_op transform.yield @@ -147,7 +148,9 @@ module attributes {transform.with_named_sequence} { // ----- -func.func @move_region_dependencies() -> f32 { +// Fail when the "before" operation is part of the operation slice (computed +// when looking through implicitly captured values). +func.func @do_not_move_slice() -> f32 { %0 = "before"() : () -> (f32) %1 = "moved_op"() ({ "yield"(%0) : (f32) -> () @@ -164,6 +167,32 @@ module attributes {transform.with_named_sequence} { : (!transform.any_op) -> !transform.any_op %op2 = transform.structured.match ops{["before"]} in %arg0 : (!transform.any_op) -> !transform.any_op + // expected-remark@+1{{cannot move dependencies before operation in backward slice of op}} + transform.test.move_operand_deps %op1 before %op2 + : !transform.any_op, !transform.any_op + transform.yield + } +} + +// ----- + +// Dont move ops when insertion point does not dominate the op +func.func @do_not_move() -> f32 { + %1 = "moved_op"() : () -> (f32) + %2 = "foo"() ({ + "yield"(%1) : (f32) -> () + }) : () -> (f32) + %3 = "before"() : () -> f32 + return %2 : f32 +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) { + %op1 = transform.structured.match ops{["foo"]} in %arg0 + : (!transform.any_op) -> !transform.any_op + %op2 = transform.structured.match ops{["before"]} in %arg0 + : (!transform.any_op) -> !transform.any_op + // expected-remark@+1{{insertion point does not dominate op}} transform.test.move_operand_deps %op1 before %op2 : !transform.any_op, !transform.any_op transform.yield diff --git a/mlir/test/lib/Transforms/TestTransformsOps.cpp b/mlir/test/lib/Transforms/TestTransformsOps.cpp index de3e0163f5c45..aaa566d9938a3 100644 --- a/mlir/test/lib/Transforms/TestTransformsOps.cpp +++ b/mlir/test/lib/Transforms/TestTransformsOps.cpp @@ -33,8 +33,8 @@ transform::TestMoveOperandDeps::apply(TransformRewriter &rewriter, Operation *moveBefore = *state.getPayloadOps(getInsertionPoint()).begin(); if (failed(moveOperationDependencies(rewriter, op, moveBefore))) { auto listener = cast(rewriter.getListener()); - std::string errorMsg = listener->checkAndResetError().getMessage(); - return emitSilenceableFailure(op, errorMsg); + std::string errorMsg = listener->getLatestMatchFailureMessage(); + (void)emitRemark(errorMsg); } return DiagnosedSilenceableFailure::success(); } From b0ddc6c010ca27e0f24ff10e73fcbf2d45a0ff5f Mon Sep 17 00:00:00 2001 From: MaheshRavishankar Date: Thu, 6 Mar 2025 16:28:14 -0800 Subject: [PATCH 4/6] Simplify slice computation. Signed-off-by: MaheshRavishankar --- mlir/lib/Transforms/Utils/RegionUtils.cpp | 21 +++-------- mlir/test/Transforms/move-operation-deps.mlir | 36 +++++++++++++++++++ 2 files changed, 40 insertions(+), 17 deletions(-) diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp index 516a773b03109..da0d486f0fdcb 100644 --- a/mlir/lib/Transforms/Utils/RegionUtils.cpp +++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp @@ -1083,7 +1083,7 @@ LogicalResult mlir::moveOperationDependencies(RewriterBase &rewriter, // depends on. Prune the slice to only include operations not already // dominated by the `insertionPoint` BackwardSliceOptions options; - options.inclusive = true; + options.inclusive = false; options.omitUsesFromAbove = false; // Since current support is to only move within a same basic block, // the slices dont need to look past block arguments. @@ -1092,20 +1092,7 @@ LogicalResult mlir::moveOperationDependencies(RewriterBase &rewriter, return !dominance.properlyDominates(sliceBoundaryOp, insertionPoint); }; llvm::SetVector slice; - - // Get the defined slice for operands. - for (Value operand : op->getOperands()) { - getBackwardSlice(operand, &slice, options); - } - auto regions = op->getRegions(); - if (!regions.empty()) { - // If op has region, get the defined slice for all captured values. - llvm::SetVector capturedVals; - mlir::getUsedValuesDefinedAbove(regions, capturedVals); - for (Value value : capturedVals) { - getBackwardSlice(value, &slice, options); - } - } + getBackwardSlice(op, &slice, options); // If the slice contains `insertionPoint` cannot move the dependencies. if (slice.contains(insertionPoint)) { @@ -1114,8 +1101,8 @@ LogicalResult mlir::moveOperationDependencies(RewriterBase &rewriter, "cannot move dependencies before operation in backward slice of op"); } - // Sort the slice topologically, and move in topological order. - mlir::topologicalSort(slice); + // We should move the slice in topological order, but `getBackwardSlice` + // already does that. So no need to sort again. for (Operation *op : slice) { rewriter.moveOpBefore(op, insertionPoint); } diff --git a/mlir/test/Transforms/move-operation-deps.mlir b/mlir/test/Transforms/move-operation-deps.mlir index 7b81f843a430d..37637152938f6 100644 --- a/mlir/test/Transforms/move-operation-deps.mlir +++ b/mlir/test/Transforms/move-operation-deps.mlir @@ -125,6 +125,42 @@ module attributes {transform.with_named_sequence} { // ----- +// Current implementation omits following basic block argument when +// computing slices. Verify that this gives expected result. +func.func @ignore_basic_block_arguments() -> f32 { +^bb0(): + %8 = "test"() : () -> (f32) + return %8: f32 +^bb1(%bbArg : f32): + %0 = "before"() : () -> (f32) + %1 = "moved_op"() ({ + "yield"(%bbArg) : (f32) -> () + }) : () -> (f32) + %2 = "foo"() ({ + "yield"(%1) : (f32) -> () + }) : () -> (f32) + return %2 : f32 +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) { + %op1 = transform.structured.match ops{["foo"]} in %arg0 + : (!transform.any_op) -> !transform.any_op + %op2 = transform.structured.match ops{["before"]} in %arg0 + : (!transform.any_op) -> !transform.any_op + transform.test.move_operand_deps %op1 before %op2 + : !transform.any_op, !transform.any_op + transform.yield + } +} +// CHECK-LABEL: func @ignore_basic_block_arguments() +// CHECK: %[[MOVED_1:.+]] = "moved_op" +// CHECK: "before" +// CHECK: %[[FOO:.+]] = "foo" +// CHECK: return %[[FOO]] + +// ----- + // Fail when the "before" operation is part of the operation slice. func.func @do_not_move_slice() -> f32 { %0 = "before"() : () -> (f32) From 59ce8d0d1f3153beb377c90bcaad65b1ad48b9e2 Mon Sep 17 00:00:00 2001 From: MaheshRavishankar Date: Thu, 6 Mar 2025 16:52:27 -0800 Subject: [PATCH 5/6] Fix Windows failure Signed-off-by: MaheshRavishankar --- mlir/test/lib/Transforms/lit.local.cfg | 1 + 1 file changed, 1 insertion(+) diff --git a/mlir/test/lib/Transforms/lit.local.cfg b/mlir/test/lib/Transforms/lit.local.cfg index 8ffccee1d6d79..7f4d25f1ba025 100644 --- a/mlir/test/lib/Transforms/lit.local.cfg +++ b/mlir/test/lib/Transforms/lit.local.cfg @@ -1 +1,2 @@ config.suffixes.remove(".pdll") +config.suffixes.remove(".td") From ac9e0d1320caf90feab3ebe82658ef7f45e90220 Mon Sep 17 00:00:00 2001 From: MaheshRavishankar Date: Thu, 6 Mar 2025 16:56:46 -0800 Subject: [PATCH 6/6] Fix code formatter errors. Signed-off-by: MaheshRavishankar --- .../Transform/Interfaces/TransformInterfaces.h | 8 ++++---- .../Transform/Interfaces/TransformInterfaces.cpp | 12 ++++++++---- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.h index 28046d1b8f2b0..b9f2af22e9483 100644 --- a/mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.h +++ b/mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.h @@ -1074,17 +1074,17 @@ class ErrorCheckingTrackingListener : public TrackingListener { /// resets the error state to "success". DiagnosedSilenceableFailure checkAndResetError(); - /// Return the latest match notification message. + /// Return the latest match notification message. Returns an empty string + /// when no error message was captured. std::string getLatestMatchFailureMessage(); /// Return "true" if this tracking listener had a failure. bool failed() const; protected: - void notifyMatchFailure(Location loc, - function_ref reasonCallback) override; + function_ref reasonCallback) override; void notifyPayloadReplacementNotFound(Operation *op, ValueRange values, @@ -1099,7 +1099,7 @@ class ErrorCheckingTrackingListener : public TrackingListener { int64_t errorCounter = 0; /// Latest message from match failure notification. - std::string matchFailureMsg = ""; + std::optional matchFailure; }; /// This is a special rewriter to be used in transform op implementations, diff --git a/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp index 55cac471fb14c..e0a5df0c758b3 100644 --- a/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp +++ b/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp @@ -1390,15 +1390,19 @@ void transform::ErrorCheckingTrackingListener::notifyPayloadReplacementNotFound( ++errorCounter; } -std::string transform::ErrorCheckingTrackingListener::getLatestMatchFailureMessage() { - return matchFailureMsg; +std::string +transform::ErrorCheckingTrackingListener::getLatestMatchFailureMessage() { + if (!matchFailure) { + return ""; + } + return matchFailure->str(); } void transform::ErrorCheckingTrackingListener::notifyMatchFailure( - Location loc, function_ref reasonCallback) { + Location loc, function_ref reasonCallback) { Diagnostic diag(loc, DiagnosticSeverity::Remark); reasonCallback(diag); - matchFailureMsg = diag.str(); + matchFailure = std::move(diag); } //===----------------------------------------------------------------------===//