Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions mlir/include/mlir/Transforms/RegionUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "llvm/ADT/SetVector.h"

namespace mlir {
class DominanceInfo;
class RewriterBase;

/// Check if all values in the provided range are defined above the `limit`
Expand Down Expand Up @@ -69,6 +70,15 @@ SmallVector<Value> makeRegionIsolatedFromAbove(
llvm::function_ref<bool(Operation *)> cloneOperationIntoRegion =
[](Operation *) { return false; });

/// Move SSA values used within an operation before an insertion point,
/// so that the operation itself (or its replacement) can be moved to
/// the insertion point.
LogicalResult moveOperationDependencies(RewriterBase &rewriter, Operation *op,
Operation *insertionPoint,
DominanceInfo &dominance);
LogicalResult moveOperationDependencies(RewriterBase &rewriter, Operation *op,
Operation *insertionPoint);

/// Run a set of structural simplifications over the given regions. This
/// includes transformations like unreachable block elimination, dead argument
/// elimination, as well as some other DCE. This function returns success if any
Expand Down
65 changes: 65 additions & 0 deletions mlir/lib/Transforms/Utils/RegionUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,12 @@
//===----------------------------------------------------------------------===//

#include "mlir/Transforms/RegionUtils.h"

#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
Expand Down Expand Up @@ -1054,3 +1057,65 @@ LogicalResult mlir::simplifyRegions(RewriterBase &rewriter,
return success(eliminatedBlocks || eliminatedOpsOrArgs ||
mergedIdenticalBlocks || droppedRedundantArguments);
}

//===---------------------------------------------------------------------===//
// Move operation dependencies
//===---------------------------------------------------------------------===//

LogicalResult mlir::moveOperationDependencies(RewriterBase &rewriter,
Operation *op,
Operation *insertionPoint,
DominanceInfo &dominance) {
// Currently unsupported case where the op and insertion point are
// in different basic blocks.
if (op->getBlock() != insertionPoint->getBlock()) {
return rewriter.notifyMatchFailure(
op, "unsupported caes where operation and insertion point are not in "
"the sme basic block");
}

// Find the backward slice of operation for each `Value` the operation
// depends on. Prune the slice to only include operations not already
// dominated by the `insertionPoint`
BackwardSliceOptions options;
options.inclusive = true;
options.filter = [&](Operation *sliceBoundaryOp) {
return !dominance.properlyDominates(sliceBoundaryOp, insertionPoint);
};
llvm::SetVector<Operation *> slice;

// Get the defined slice for operands.
for (Value operand : op->getOperands()) {
getBackwardSlice(operand, &slice, options);
}
auto regions = op->getRegions();
if (!regions.empty()) {
// If op has region, get the defined slice for all captured values.
llvm::SetVector<Value> capturedVals;
mlir::getUsedValuesDefinedAbove(regions, capturedVals);
for (auto value : capturedVals) {
getBackwardSlice(value, &slice, options);
}
}

// If the slice contains `insertionPoint` cannot move the dependencies.
if (slice.contains(insertionPoint)) {
return rewriter.notifyMatchFailure(
op,
"cannot move dependencies before operation in backward slice of op");
}

// Sort the slice topologically, ad move in topological order.
mlir::topologicalSort(slice);
for (auto op : slice) {
rewriter.moveOpBefore(op, insertionPoint);
}
return success();
}

LogicalResult mlir::moveOperationDependencies(RewriterBase &rewriter,
Operation *op,
Operation *insertionPoint) {
DominanceInfo dominance(op);
return moveOperationDependencies(rewriter, op, insertionPoint, dominance);
}
113 changes: 113 additions & 0 deletions mlir/test/Transforms/move-operation-deps.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
// RUN: mlir-opt --allow-unregistered-dialect --transform-interpreter --split-input-file %s | FileCheck %s

// Check simple move of dependent operation before insertion.
func.func @simple_move() -> f32 {
%0 = "before"() : () -> (f32)
%1 = "moved_op"() : () -> (f32)
%2 = "foo"(%1) : (f32) -> (f32)
return %2 : f32
}
// CHECK-LABEL: func @simple_move()
// CHECK: %[[MOVED:.+]] = "moved_op"
// CHECK: %[[BEFORE:.+]] = "before"
// CHECK: %[[FOO:.+]] = "foo"(%[[MOVED]])
// CHECK: return %[[FOO]]

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
%op1 = transform.structured.match ops{["foo"]} in %arg0
: (!transform.any_op) -> !transform.any_op
%op2 = transform.structured.match ops{["before"]} in %arg0
: (!transform.any_op) -> !transform.any_op
transform.test.move_operand_deps %op1 before %op2
: !transform.any_op, !transform.any_op
transform.yield
}
}

// -----

// Move operands that are implicitly captured by the op
func.func @move_region_dependencies() -> f32 {
%0 = "before"() : () -> (f32)
%1 = "moved_op"() : () -> (f32)
%2 = "foo"() ({
%3 = "inner_op"(%1) : (f32) -> (f32)
"yield"(%3) : (f32) -> ()
}) : () -> (f32)
return %2 : f32
}
// CHECK-LABEL: func @move_region_dependencies()
// CHECK: %[[MOVED:.+]] = "moved_op"
// CHECK: %[[BEFORE:.+]] = "before"
// CHECK: %[[FOO:.+]] = "foo"
// CHECK: return %[[FOO]]

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
%op1 = transform.structured.match ops{["foo"]} in %arg0
: (!transform.any_op) -> !transform.any_op
%op2 = transform.structured.match ops{["before"]} in %arg0
: (!transform.any_op) -> !transform.any_op
transform.test.move_operand_deps %op1 before %op2
: !transform.any_op, !transform.any_op
transform.yield
}
}

// -----

// Move operations in toplogical sort order
func.func @move_in_topological_sort_order() -> f32 {
%0 = "before"() : () -> (f32)
%1 = "moved_op_1"() : () -> (f32)
%2 = "moved_op_2"() : () -> (f32)
%3 = "moved_op_3"(%1) : (f32) -> (f32)
%4 = "moved_op_4"(%1, %3) : (f32, f32) -> (f32)
%5 = "moved_op_5"(%2) : (f32) -> (f32)
%6 = "foo"(%4, %5) : (f32, f32) -> (f32)
return %6 : f32
}
// CHECK-LABEL: func @move_in_topological_sort_order()
// CHECK: %[[MOVED_1:.+]] = "moved_op_1"
// CHECK-DAG: %[[MOVED_2:.+]] = "moved_op_3"(%[[MOVED_1]])
// CHECK-DAG: %[[MOVED_3:.+]] = "moved_op_4"(%[[MOVED_1]], %[[MOVED_2]])
// CHECK-DAG: %[[MOVED_4:.+]] = "moved_op_2"
// CHECK-DAG: %[[MOVED_5:.+]] = "moved_op_5"(%[[MOVED_4]])
// CHECK: %[[BEFORE:.+]] = "before"
// CHECK: %[[FOO:.+]] = "foo"(%[[MOVED_3]], %[[MOVED_5]])
// CHECK: return %[[FOO]]

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
%op1 = transform.structured.match ops{["foo"]} in %arg0
: (!transform.any_op) -> !transform.any_op
%op2 = transform.structured.match ops{["before"]} in %arg0
: (!transform.any_op) -> !transform.any_op
transform.test.move_operand_deps %op1 before %op2
: !transform.any_op, !transform.any_op
transform.yield
}
}

// -----

// Fail when the "before" operation is part of the operation slice.
func.func @do_not_move_slice() -> f32 {
%0 = "before"() : () -> (f32)
%1 = "moved_op"(%0) : (f32) -> (f32)
%2 = "foo"(%1) : (f32) -> (f32)
return %2 : f32
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
%op1 = transform.structured.match ops{["foo"]} in %arg0
: (!transform.any_op) -> !transform.any_op
%op2 = transform.structured.match ops{["before"]} in %arg0
: (!transform.any_op) -> !transform.any_op
transform.test.move_operand_deps %op1 before %op2
: !transform.any_op, !transform.any_op
transform.yield
}
}
8 changes: 8 additions & 0 deletions mlir/test/lib/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
set(LLVM_TARGET_DEFINITIONS TestTransformsOps.td)
mlir_tablegen(TestTransformsOps.h.inc -gen-op-decls)
mlir_tablegen(TestTransformsOps.cpp.inc -gen-op-defs)
add_public_tablegen_target(MLIRTestTransformsOpsIncGen)

set(LLVM_OPTIONAL_SOURCES
TestDialectConversion.cpp)
set(MLIRTestTransformsPDLDep)
Expand Down Expand Up @@ -25,6 +30,7 @@ add_mlir_library(MLIRTestTransforms
TestControlFlowSink.cpp
TestInlining.cpp
TestMakeIsolatedFromAbove.cpp
TestTransformsOps.cpp
${MLIRTestTransformsPDLSrc}

EXCLUDE_FROM_LIBMLIR
Expand All @@ -34,6 +40,7 @@ add_mlir_library(MLIRTestTransforms

DEPENDS
${MLIRTestTransformsPDLDep}
MLIRTestTransformsOpsIncGen

LINK_LIBS PUBLIC
MLIRTestDialect
Expand All @@ -43,6 +50,7 @@ mlir_target_link_libraries(MLIRTestTransforms PUBLIC
MLIRFuncDialect
MLIRInferIntRangeInterface
MLIRTransforms
MLIRTransformDialect
)

target_include_directories(MLIRTestTransforms
Expand Down
66 changes: 66 additions & 0 deletions mlir/test/lib/Transforms/TestTransformsOps.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
//===- TestTransformsOps.cpp - Test Transforms ----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines transform dialect operations for testing MLIR
// transformations
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Transform/IR/TransformAttrs.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/Transforms/RegionUtils.h"

#define GET_OP_CLASSES
#include "TestTransformsOps.h.inc"

using namespace mlir;
using namespace mlir::transform;

#define GET_OP_CLASSES
#include "TestTransformsOps.cpp.inc"

DiagnosedSilenceableFailure
transform::TestMoveOperandDeps::apply(TransformRewriter &rewriter,
TransformResults &TransformResults,
TransformState &state) {
Operation *op = *state.getPayloadOps(getOp()).begin();
Operation *moveBefore = *state.getPayloadOps(getInsertionPoint()).begin();
if (failed(moveOperationDependencies(rewriter, op, moveBefore))) {
auto listener = cast<ErrorCheckingTrackingListener>(rewriter.getListener());
std::string errorMsg = listener->checkAndResetError().getMessage();
return emitSilenceableFailure(op, errorMsg);
}
return DiagnosedSilenceableFailure::success();
}

namespace {

class TestTransformsDialectExtension
: public transform::TransformDialectExtension<
TestTransformsDialectExtension> {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTransformsDialectExtension)

using Base::Base;

void init() {
registerTransformOps<
#define GET_OP_LIST
#include "TestTransformsOps.cpp.inc"
>();
}
};
} // namespace

namespace test {
void registerTestTransformsTransformDialectExtension(
DialectRegistry &registry) {
registry.addExtensions<TestTransformsDialectExtension>();
}
} // namespace test
41 changes: 41 additions & 0 deletions mlir/test/lib/Transforms/TestTransformsOps.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
//===- TestTransformOps.td ---------------------------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef TEST_TRANSFORM_OPS
#define TEST_TRANSFORM_OPS

include "mlir/Dialect/Transform/IR/TransformDialect.td"
include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
include "mlir/Dialect/Transform/IR/TransformTypes.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/OpBase.td"

/// Transform dialect perations for testing transformations in MLIR

def TestMoveOperandDeps :
Op<Transform_Dialect, "test.move_operand_deps",
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
DeclareOpInterfaceMethods<TransformOpInterface>,
ReportTrackingListenerFailuresOpTrait]> {
let description = [{
Moves all dependencies of on operation before another operation.
}];

let arguments =
(ins TransformHandleTypeInterface:$op,
TransformHandleTypeInterface:$insertion_point);

let results = (outs);

let assemblyFormat = [{
$op `before` $insertion_point attr-dict
`:` type($op) `,` type($insertion_point)
}];
}

#endif // TEST_TRANSFORM_OPS
2 changes: 2 additions & 0 deletions mlir/tools/mlir-opt/mlir-opt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ void registerTestDialect(DialectRegistry &);
void registerTestDynDialect(DialectRegistry &);
void registerTestTilingInterfaceTransformDialectExtension(DialectRegistry &);
void registerTestTransformDialectExtension(DialectRegistry &);
void registerTestTransformsTransformDialectExtension(DialectRegistry &);
} // namespace test

#ifdef MLIR_INCLUDE_TESTS
Expand Down Expand Up @@ -323,6 +324,7 @@ int main(int argc, char **argv) {
#ifdef MLIR_INCLUDE_TESTS
::test::registerTestDialect(registry);
::test::registerTestTransformDialectExtension(registry);
::test::registerTestTransformsTransformDialectExtension(registry);
::test::registerTestTilingInterfaceTransformDialectExtension(registry);
::test::registerTestDynDialect(registry);
#endif
Expand Down
Loading