From 29209330d68e877cb92eca9f48b782244a97ecf9 Mon Sep 17 00:00:00 2001 From: Ian Wood Date: Wed, 23 Oct 2024 16:46:02 +0000 Subject: [PATCH 1/3] Add `omitUsesFromAbove` to getBackwardsSlice `getBackwardsSlice` should track values captured by each op's region that it traverses, and follow those defs. However, there might be logic that depends on not traversing captured values so this change preserves the default behavior by hiding this logic behind the `omitUsesFromAbove` flag. --- mlir/include/mlir/Analysis/SliceAnalysis.h | 5 +++++ mlir/lib/Analysis/SliceAnalysis.cpp | 14 ++++++++++++++ 2 files changed, 19 insertions(+) diff --git a/mlir/include/mlir/Analysis/SliceAnalysis.h b/mlir/include/mlir/Analysis/SliceAnalysis.h index 99279fdfe427c..a4f5d937cd51d 100644 --- a/mlir/include/mlir/Analysis/SliceAnalysis.h +++ b/mlir/include/mlir/Analysis/SliceAnalysis.h @@ -47,6 +47,11 @@ struct BackwardSliceOptions : public SliceOptions { /// backward slice computation traverses block arguments and asserts that the /// parent op has a single region with a single block. bool omitBlockArguments = false; + + /// When omitUsesFromAbove is true, the backward slice computation omits + /// traversing values that are captured from above. + /// TODO: this should default to `false` after users have been updated. + bool omitUsesFromAbove = true; }; using ForwardSliceOptions = SliceOptions; diff --git a/mlir/lib/Analysis/SliceAnalysis.cpp b/mlir/lib/Analysis/SliceAnalysis.cpp index 2b1cf411ceeee..d07ae7b3ffa2c 100644 --- a/mlir/lib/Analysis/SliceAnalysis.cpp +++ b/mlir/lib/Analysis/SliceAnalysis.cpp @@ -16,6 +16,7 @@ #include "mlir/IR/Operation.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Support/LLVM.h" +#include "mlir/Transforms/RegionUtils.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallPtrSet.h" @@ -115,6 +116,19 @@ static void getBackwardSliceImpl(Operation *op, } } + // Visit values that are defined above. + if (!options.omitUsesFromAbove) { + visitUsedValuesDefinedAbove(op->getRegions(), [&](OpOperand *operand) { + if (Operation *definingOp = operand->get().getDefiningOp()) { + getBackwardSliceImpl(definingOp, backwardSlice, options); + return; + } + Operation *bbAargOwner = + cast(operand->get()).getOwner()->getParentOp(); + getBackwardSliceImpl(bbAargOwner, backwardSlice, options); + }); + } + backwardSlice->insert(op); } From 3e8b47a6e9d7153baa70838d796f1593a12984d8 Mon Sep 17 00:00:00 2001 From: Ian Wood Date: Wed, 30 Oct 2024 03:51:23 +0000 Subject: [PATCH 2/3] Cleanup impl & add test --- mlir/lib/Analysis/SliceAnalysis.cpp | 21 +++++++-------------- mlir/test/IR/slice.mlir | 28 +++++++++++++++++++++++++++- mlir/test/lib/IR/TestSlicing.cpp | 2 ++ 3 files changed, 36 insertions(+), 15 deletions(-) diff --git a/mlir/lib/Analysis/SliceAnalysis.cpp b/mlir/lib/Analysis/SliceAnalysis.cpp index d07ae7b3ffa2c..cd0dc25adf1ca 100644 --- a/mlir/lib/Analysis/SliceAnalysis.cpp +++ b/mlir/lib/Analysis/SliceAnalysis.cpp @@ -92,7 +92,13 @@ static void getBackwardSliceImpl(Operation *op, if (options.filter && !options.filter(op)) return; - for (const auto &en : llvm::enumerate(op->getOperands())) { + auto operands = op->getOperands(); + SetVector valuesToFollow(operands.begin(), operands.end()); + if (!options.omitUsesFromAbove) { + getUsedValuesDefinedAbove(op->getRegions(), valuesToFollow); + } + + for (const auto &en : llvm::enumerate(valuesToFollow)) { auto operand = en.value(); if (auto *definingOp = operand.getDefiningOp()) { if (backwardSlice->count(definingOp) == 0) @@ -116,19 +122,6 @@ static void getBackwardSliceImpl(Operation *op, } } - // Visit values that are defined above. - if (!options.omitUsesFromAbove) { - visitUsedValuesDefinedAbove(op->getRegions(), [&](OpOperand *operand) { - if (Operation *definingOp = operand->get().getDefiningOp()) { - getBackwardSliceImpl(definingOp, backwardSlice, options); - return; - } - Operation *bbAargOwner = - cast(operand->get()).getOwner()->getParentOp(); - getBackwardSliceImpl(bbAargOwner, backwardSlice, options); - }); - } - backwardSlice->insert(op); } diff --git a/mlir/test/IR/slice.mlir b/mlir/test/IR/slice.mlir index 0a32a0f231baf..87d446c8f415a 100644 --- a/mlir/test/IR/slice.mlir +++ b/mlir/test/IR/slice.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -slice-analysis-test %s | FileCheck %s +// RUN: mlir-opt -slice-analysis-test -split-input-file %s | FileCheck %s func.func @slicing_linalg_op(%arg0 : index, %arg1 : index, %arg2 : index) { %a = memref.alloc(%arg0, %arg2) : memref @@ -33,3 +33,29 @@ func.func @slicing_linalg_op(%arg0 : index, %arg1 : index, %arg2 : index) { // CHECK-DAG: %[[B:.+]] = memref.alloc(%[[ARG2]], %[[ARG1]]) : memref // CHECK-DAG: %[[C:.+]] = memref.alloc(%[[ARG0]], %[[ARG1]]) : memref // CHECK: return + +// ----- + +#map = affine_map<(d0, d1) -> (d0, d1)> +func.func @slice_use_from_above(%arg0: tensor<5x5xf32>, %arg1: tensor<5x5xf32>) { + %0 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<5x5xf32>) outs(%arg1 : tensor<5x5xf32>) { + ^bb0(%in: f32, %out: f32): + %2 = arith.addf %in, %in : f32 + linalg.yield %2 : f32 + } -> tensor<5x5xf32> + %collapsed = tensor.collapse_shape %0 [[0, 1]] : tensor<5x5xf32> into tensor<25xf32> + %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%0 : tensor<5x5xf32>) outs(%arg1 : tensor<5x5xf32>) { + ^bb0(%in: f32, %out: f32): + %c2 = arith.constant 2 : index + %extracted = tensor.extract %collapsed[%c2] : tensor<25xf32> + %2 = arith.addf %extracted, %extracted : f32 + linalg.yield %2 : f32 + } -> tensor<5x5xf32> + return +} + +// CHECK-LABEL: func @slice_use_from_above__backward_slice__0 +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor +// CHECK: %[[A:.+]] = linalg.generic {{.*}} ins(%[[ARG0]] +// CHECK: %[[B:.+]] = tensor.collapse_shape %[[A]] +// CHECK: return diff --git a/mlir/test/lib/IR/TestSlicing.cpp b/mlir/test/lib/IR/TestSlicing.cpp index c3d0d151c6d24..e99d5976d6d9d 100644 --- a/mlir/test/lib/IR/TestSlicing.cpp +++ b/mlir/test/lib/IR/TestSlicing.cpp @@ -39,6 +39,8 @@ static LogicalResult createBackwardSliceFunction(Operation *op, SetVector slice; BackwardSliceOptions options; options.omitBlockArguments = omitBlockArguments; + // TODO: Make this default. + options.omitUsesFromAbove = false; getBackwardSlice(op, &slice, options); for (Operation *slicedOp : slice) builder.clone(*slicedOp, mapper); From 9f7ea3c6021a94af42101834a95238fb8fb2492f Mon Sep 17 00:00:00 2001 From: Ian Wood Date: Wed, 30 Oct 2024 20:26:10 +0000 Subject: [PATCH 3/3] Use lambda to reuse logic --- mlir/lib/Analysis/SliceAnalysis.cpp | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/mlir/lib/Analysis/SliceAnalysis.cpp b/mlir/lib/Analysis/SliceAnalysis.cpp index cd0dc25adf1ca..7ec999fa0370f 100644 --- a/mlir/lib/Analysis/SliceAnalysis.cpp +++ b/mlir/lib/Analysis/SliceAnalysis.cpp @@ -17,6 +17,7 @@ #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Support/LLVM.h" #include "mlir/Transforms/RegionUtils.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallPtrSet.h" @@ -92,20 +93,13 @@ static void getBackwardSliceImpl(Operation *op, if (options.filter && !options.filter(op)) return; - auto operands = op->getOperands(); - SetVector valuesToFollow(operands.begin(), operands.end()); - if (!options.omitUsesFromAbove) { - getUsedValuesDefinedAbove(op->getRegions(), valuesToFollow); - } - - for (const auto &en : llvm::enumerate(valuesToFollow)) { - auto operand = en.value(); - if (auto *definingOp = operand.getDefiningOp()) { + auto processValue = [&](Value value) { + if (auto *definingOp = value.getDefiningOp()) { if (backwardSlice->count(definingOp) == 0) getBackwardSliceImpl(definingOp, backwardSlice, options); - } else if (auto blockArg = dyn_cast(operand)) { + } else if (auto blockArg = dyn_cast(value)) { if (options.omitBlockArguments) - continue; + return; Block *block = blockArg.getOwner(); Operation *parentOp = block->getParentOp(); @@ -120,7 +114,14 @@ static void getBackwardSliceImpl(Operation *op, } else { llvm_unreachable("No definingOp and not a block argument."); } + }; + + if (!options.omitUsesFromAbove) { + visitUsedValuesDefinedAbove(op->getRegions(), [&](OpOperand *operand) { + processValue(operand->get()); + }); } + llvm::for_each(op->getOperands(), processValue); backwardSlice->insert(op); }