[mlir][bufferization]-Add enforce immutable func args pass #113130
Conversation
Adding a pass that allocates a new buffer for each input argument of the function it operates on that is written to, and copies the argument into the allocated buffer with a `memref.copy`.
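For illustration, a minimal before/after sketch of the rewrite (hypothetical function and shapes, not taken from the patch):

```mlir
// Before: %arg0 is the destination of a copy, i.e. it is written in place.
func.func @example(%arg0: memref<4xf32>, %arg1: memref<4xf32>) {
  memref.copy %arg1, %arg0 : memref<4xf32> to memref<4xf32>
  return
}

// After: a fresh buffer shadows %arg0; the argument itself is only read.
func.func @example(%arg0: memref<4xf32>, %arg1: memref<4xf32>) {
  %alloc = memref.alloc() : memref<4xf32>
  memref.copy %arg0, %alloc : memref<4xf32> to memref<4xf32>
  memref.copy %arg1, %alloc : memref<4xf32> to memref<4xf32>
  return
}
```

After the rewrite the only use of `%arg0` is as the source of the inserted `memref.copy`, which is exactly the read-only property the pass enforces.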
@llvm/pr-subscribers-mlir-bufferization @llvm/pr-subscribers-mlir

Author: Amir Bishara (amirBish)

Changes

Adding a pass that allocates a new buffer for each input argument of the function it operates on that is written to, and copies the argument into the allocated buffer with a `memref.copy`.

Patch is 20.27 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/113130.diff

5 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
index 72abb5b3f1f94e..e17914fbbd5840 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
@@ -229,6 +229,10 @@ createPromoteBuffersToStackPass(std::function<bool(Value)> isSmallAlloc);
/// insert_slice ops.
std::unique_ptr<Pass> createEmptyTensorEliminationPass();
+/// Create a pass that enforces read-only (immutable) buffers for the
+/// relevant function arguments.
+std::unique_ptr<Pass> createEnforceImmutableFuncArgsPass();
+
//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
index cc5463ea968fc3..fb2b4d3a305f4a 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
@@ -595,4 +595,18 @@ def EmptyTensorElimination : Pass<"eliminate-empty-tensors"> {
let constructor = "mlir::bufferization::createEmptyTensorEliminationPass()";
}
+def EnforceImmutableFuncArgs : Pass<"enforce-immutable-func-args", "func::FuncOp"> {
+  let summary = "Enforce function argument immutability by inserting allocations and copies";
+ let description = [{
+    This pass allocates a new buffer for each input argument of the function
+    that is written to, and copies the argument into the allocated buffer.
+    This avoids in-place memory updates of the function's arguments and
+    makes them immutable/read-only buffers.
+ }];
+ let constructor = "mlir::bufferization::createEnforceImmutableFuncArgsPass()";
+ let dependentDialects = ["memref::MemRefDialect"];
+}
+
#endif // MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
index 50104e8f8346b4..25de31c179a31d 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
@@ -16,6 +16,7 @@ add_mlir_dialect_library(MLIRBufferizationTransforms
OwnershipBasedBufferDeallocation.cpp
TensorCopyInsertion.cpp
OptimizeAllocationLiveness.cpp
+ EnforceImmutableFuncArgs.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Bufferization
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/EnforceImmutableFuncArgs.cpp b/mlir/lib/Dialect/Bufferization/Transforms/EnforceImmutableFuncArgs.cpp
new file mode 100644
index 00000000000000..84f201c141a3d1
--- /dev/null
+++ b/mlir/lib/Dialect/Bufferization/Transforms/EnforceImmutableFuncArgs.cpp
@@ -0,0 +1,101 @@
+//===- EnforceImmutableFuncArgs.cpp - enforce immutable func args pass ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass that enforces the immutability of function
+// arguments: every memref argument that is written to is shadowed by a new
+// allocation, into which the original argument is copied at function entry.
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/Operation.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "enforce-immutable-func-args"
+#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
+#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
+
+using namespace mlir;
+
+namespace mlir {
+namespace bufferization {
+#define GEN_PASS_DEF_ENFORCEIMMUTABLEFUNCARGS
+#include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
+} // namespace bufferization
+} // namespace mlir
+
+// Checks if there is any operation which tries to write
+// into `buffer`.
+// This method assumes buffer has `MemRefType`.
+static bool isWrittenTo(Value buffer);
+
+namespace {
+/// This pass allocates a new buffer for each input argument of the function
+/// that is written to, and copies the argument into the allocated buffer.
+/// This avoids in-place memory updates of the function's arguments and
+/// makes them immutable/read-only buffers.
+struct EnforceImmutableFuncArgsPass
+ : public bufferization::impl::EnforceImmutableFuncArgsBase<
+ EnforceImmutableFuncArgsPass> {
+ void runOnOperation() final;
+};
+} // end anonymous namespace.
+
+static bool isWrittenTo(Value buffer) {
+ assert(isa<MemRefType>(buffer.getType()));
+
+ for (auto user : buffer.getUsers()) {
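+    // A user that declares a Write memory effect on `buffer` mutates it
+    // directly.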
+ if (hasEffect<MemoryEffects::Write>(user, buffer))
+ return true;
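+    // Writes may also happen through aliasing views (e.g. memref.subview or
+    // memref.collapse_shape), so recurse into view-like results as well.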
+ if (auto viewLikeOp = dyn_cast<ViewLikeOpInterface>(user)) {
+ assert(viewLikeOp->getNumResults() == 1);
+ if (isWrittenTo(viewLikeOp->getResult(0)))
+ return true;
+ }
+ }
+ return false;
+}
+
+void EnforceImmutableFuncArgsPass::runOnOperation() {
+
+ func::FuncOp funcOp = getOperation();
+
+ LDBG("enforcing immutable function arguments in func " << funcOp.getName());
+
+ IRRewriter rewriter(funcOp->getContext());
+ rewriter.setInsertionPointToStart(&funcOp.getBody().front());
+ for (auto argument : funcOp.getArguments()) {
+
+ auto argType = dyn_cast<MemRefType>(argument.getType());
+ if (!argType) {
+      emitError(argument.getLoc(),
+                "function has argument with non-memref type");
+ return signalPassFailure();
+ }
+
+ if (!isWrittenTo(argument))
+ continue;
+
+ LDBG("Found a function argument is being written to " << argument);
+ Value allocatedMemref =
+ rewriter.create<memref::AllocOp>(funcOp.getLoc(), argType);
+ rewriter.replaceAllUsesWith(argument, allocatedMemref);
+ rewriter.create<memref::CopyOp>(funcOp.getLoc(), argument, allocatedMemref);
+ }
+}
+
+//===----------------------------------------------------------------------===//
+// EnforceImmutableFuncArgs construction
+//===----------------------------------------------------------------------===//
+
+std::unique_ptr<Pass>
+mlir::bufferization::createEnforceImmutableFuncArgsPass() {
+ return std::make_unique<EnforceImmutableFuncArgsPass>();
+}
diff --git a/mlir/test/Dialect/Bufferization/Transforms/enforce-immutable-func-args.mlir b/mlir/test/Dialect/Bufferization/Transforms/enforce-immutable-func-args.mlir
new file mode 100644
index 00000000000000..13019d2fbf5af4
--- /dev/null
+++ b/mlir/test/Dialect/Bufferization/Transforms/enforce-immutable-func-args.mlir
@@ -0,0 +1,248 @@
+// RUN: mlir-opt --split-input-file --enforce-immutable-func-args %s -o - | FileCheck %s
+
+
+// CHECK-LABEL: func.func @func_no_input() {
+// CHECK: return
+// CHECK: }
+
+func.func @func_no_input() {
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func.func private @func_with_returned_argument(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<1x13x21x3xf32>) -> memref<1x13x21x3xf32> {
+// CHECK: return %[[VAL_0]] : memref<1x13x21x3xf32>
+// CHECK: }
+
+func.func private @func_with_returned_argument(%arg0: memref<1x13x21x3xf32>) -> (memref<1x13x21x3xf32>) {
+ return %arg0 : memref<1x13x21x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func private @func_with_modified_argument_directly(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<1x13x21x3xf32>, %[[VAL_1:.*]]: memref<1x13x21x3xf32>) -> memref<1x13x21x3xf32> {
+// CHECK: %[[VAL_2:.*]] = memref.alloc() : memref<1x13x21x3xf32>
+// CHECK: memref.copy %[[VAL_0]], %[[VAL_2]] : memref<1x13x21x3xf32> to memref<1x13x21x3xf32>
+// CHECK: %[[VAL_3:.*]] = memref.alloc() {alignment = 64 : i64} : memref<1x13x21x3xf32>
+// CHECK: linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[VAL_2]], %[[VAL_1]] : memref<1x13x21x3xf32>, memref<1x13x21x3xf32>) outs(%[[VAL_2]] : memref<1x13x21x3xf32>) {
+// CHECK: ^bb0(%[[VAL_4:.*]]: f32, %[[VAL_5:.*]]: f32, %[[VAL_6:.*]]: f32):
+// CHECK: %[[VAL_7:.*]] = arith.addf %[[VAL_4]], %[[VAL_5]] : f32
+// CHECK: linalg.yield %[[VAL_7]] : f32
+// CHECK: }
+// CHECK: return %[[VAL_3]] : memref<1x13x21x3xf32>
+// CHECK: }
+
+func.func private @func_with_modified_argument_directly(%arg0: memref<1x13x21x3xf32>, %arg1: memref<1x13x21x3xf32>) -> (memref<1x13x21x3xf32>){
+ %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x13x21x3xf32>
+ linalg.generic {
+ indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+ ],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+ }
+ ins(%arg0, %arg1 : memref<1x13x21x3xf32>, memref<1x13x21x3xf32>)
+ outs(%arg0 : memref<1x13x21x3xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %0 = arith.addf %in, %in_0 : f32
+ linalg.yield %0 : f32
+ }
+ return %alloc : memref<1x13x21x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func private @func_with_modified_argument_directly_and_returned(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<1x13x21x3xf32>, %[[VAL_1:.*]]: memref<1x13x21x3xf32>) -> memref<1x13x21x3xf32> {
+// CHECK: %[[VAL_2:.*]] = memref.alloc() : memref<1x13x21x3xf32>
+// CHECK: memref.copy %[[VAL_0]], %[[VAL_2]] : memref<1x13x21x3xf32> to memref<1x13x21x3xf32>
+// CHECK: linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[VAL_2]], %[[VAL_1]] : memref<1x13x21x3xf32>, memref<1x13x21x3xf32>) outs(%[[VAL_2]] : memref<1x13x21x3xf32>) {
+// CHECK: ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32, %[[VAL_5:.*]]: f32):
+// CHECK: %[[VAL_6:.*]] = arith.addf %[[VAL_3]], %[[VAL_4]] : f32
+// CHECK: linalg.yield %[[VAL_6]] : f32
+// CHECK: }
+// CHECK: return %[[VAL_2]] : memref<1x13x21x3xf32>
+// CHECK: }
+
+func.func private @func_with_modified_argument_directly_and_returned(%arg0: memref<1x13x21x3xf32>, %arg1: memref<1x13x21x3xf32>) -> (memref<1x13x21x3xf32>){
+ linalg.generic {
+ indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+ ],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+ }
+ ins(%arg0, %arg1 : memref<1x13x21x3xf32>, memref<1x13x21x3xf32>)
+ outs(%arg0 : memref<1x13x21x3xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %0 = arith.addf %in, %in_0 : f32
+ linalg.yield %0 : f32
+ }
+ return %arg0 : memref<1x13x21x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func private @func_with_modified_argument_directly_twice(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<1x13x21x3xf32>, %[[VAL_1:.*]]: memref<1x13x21x3xf32>) -> memref<1x13x21x3xf32> {
+// CHECK: %[[VAL_2:.*]] = memref.alloc() : memref<1x13x21x3xf32>
+// CHECK: memref.copy %[[VAL_0]], %[[VAL_2]] : memref<1x13x21x3xf32> to memref<1x13x21x3xf32>
+// CHECK: %[[VAL_3:.*]] = memref.alloc() {alignment = 64 : i64} : memref<1x13x21x3xf32>
+// CHECK: linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[VAL_2]], %[[VAL_1]] : memref<1x13x21x3xf32>, memref<1x13x21x3xf32>) outs(%[[VAL_2]] : memref<1x13x21x3xf32>) {
+// CHECK: ^bb0(%[[VAL_4:.*]]: f32, %[[VAL_5:.*]]: f32, %[[VAL_6:.*]]: f32):
+// CHECK: %[[VAL_7:.*]] = arith.addf %[[VAL_4]], %[[VAL_5]] : f32
+// CHECK: linalg.yield %[[VAL_7]] : f32
+// CHECK: }
+// CHECK: linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[VAL_2]], %[[VAL_1]] : memref<1x13x21x3xf32>, memref<1x13x21x3xf32>) outs(%[[VAL_2]] : memref<1x13x21x3xf32>) {
+// CHECK: ^bb0(%[[VAL_8:.*]]: f32, %[[VAL_9:.*]]: f32, %[[VAL_10:.*]]: f32):
+// CHECK: %[[VAL_11:.*]] = arith.addf %[[VAL_8]], %[[VAL_9]] : f32
+// CHECK: linalg.yield %[[VAL_11]] : f32
+// CHECK: }
+// CHECK: return %[[VAL_3]] : memref<1x13x21x3xf32>
+// CHECK: }
+
+func.func private @func_with_modified_argument_directly_twice(%arg0: memref<1x13x21x3xf32>, %arg1: memref<1x13x21x3xf32>) -> (memref<1x13x21x3xf32>){
+ %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x13x21x3xf32>
+ linalg.generic {
+ indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+ ],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+ }
+ ins(%arg0, %arg1 : memref<1x13x21x3xf32>, memref<1x13x21x3xf32>)
+ outs(%arg0 : memref<1x13x21x3xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %0 = arith.addf %in, %in_0 : f32
+ linalg.yield %0 : f32
+ }
+ linalg.generic {
+ indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+ ],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+ }
+ ins(%arg0, %arg1 : memref<1x13x21x3xf32>, memref<1x13x21x3xf32>)
+ outs(%arg0 : memref<1x13x21x3xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %0 = arith.addf %in, %in_0 : f32
+ linalg.yield %0 : f32
+ }
+ return %alloc : memref<1x13x21x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func private @func_with_modified_argument_directly(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<5xi32, 1>, %[[VAL_1:.*]]: memref<5xi32, 1>, %[[VAL_2:.*]]: memref<5xi32, 1>) -> memref<5xi32, 1> {
+// CHECK: %[[VAL_3:.*]] = memref.alloc() : memref<5xi32, 1>
+// CHECK: memref.copy %[[VAL_2]], %[[VAL_3]] : memref<5xi32, 1> to memref<5xi32, 1>
+// CHECK: %[[VAL_4:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_5:.*]] = arith.constant 5 : index
+// CHECK: %[[VAL_6:.*]] = arith.constant 0 : index
+// CHECK: scf.for %[[VAL_7:.*]] = %[[VAL_6]] to %[[VAL_5]] step %[[VAL_4]] {
+// CHECK: %[[VAL_8:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_7]]] : memref<5xi32, 1>
+// CHECK: %[[VAL_9:.*]] = arith.index_cast %[[VAL_8]] : i32 to index
+// CHECK: %[[VAL_10:.*]] = memref.load %[[VAL_1]]{{\[}}%[[VAL_7]]] : memref<5xi32, 1>
+// CHECK: %[[VAL_11:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_9]]] : memref<5xi32, 1>
+// CHECK: %[[VAL_12:.*]] = arith.addi %[[VAL_10]], %[[VAL_11]] : i32
+// CHECK: memref.store %[[VAL_12]], %[[VAL_3]]{{\[}}%[[VAL_9]]] : memref<5xi32, 1>
+// CHECK: }
+// CHECK: %[[VAL_13:.*]] = memref.alloc() : memref<5xi32, 1>
+// CHECK: memref.copy %[[VAL_3]], %[[VAL_13]] : memref<5xi32, 1> to memref<5xi32, 1>
+// CHECK: return %[[VAL_13]] : memref<5xi32, 1>
+// CHECK: }
+
+func.func private @func_with_modified_argument_directly(%arg0: memref<5xi32, 1>, %arg1: memref<5xi32, 1>, %arg2: memref<5xi32, 1>) -> (memref<5xi32, 1>){
+ %c1 = arith.constant 1 : index
+ %c5 = arith.constant 5 : index
+ %c0 = arith.constant 0 : index
+ scf.for %arg3 = %c0 to %c5 step %c1 {
+ %0 = memref.load %arg0[%arg3] : memref<5xi32, 1>
+ %1 = arith.index_cast %0 : i32 to index
+ %2 = memref.load %arg1[%arg3] : memref<5xi32, 1>
+ %3 = memref.load %arg2[%1] : memref<5xi32, 1>
+ %4 = arith.addi %2, %3 : i32
+ memref.store %4, %arg2[%1] : memref<5xi32, 1>
+ }
+ %alloc = memref.alloc() : memref<5xi32, 1>
+ memref.copy %arg2, %alloc : memref<5xi32, 1> to memref<5xi32, 1>
+ return %alloc : memref<5xi32, 1>
+}
+
+// -----
+
+// CHECK-LABEL: func.func private @func_with_modified_argument_indirectly(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<3x3x4xf32, 1>) -> memref<3x3x4xf32, 1> {
+// CHECK: %[[VAL_1:.*]] = memref.alloc() : memref<3x3x4xf32, 1>
+// CHECK: memref.copy %[[VAL_0]], %[[VAL_1]] : memref<3x3x4xf32, 1> to memref<3x3x4xf32, 1>
+// CHECK: %[[VAL_2:.*]] = memref.collapse_shape %[[VAL_1]] {{\[\[}}0, 1], [2]] : memref<3x3x4xf32, 1> into memref<9x4xf32, 1>
+// CHECK: %[[VAL_3:.*]] = memref.expand_shape %[[VAL_2]] {{\[\[}}0, 1], [2]] output_shape [3, 3, 4] : memref<9x4xf32, 1> into memref<3x3x4xf32, 1>
+// CHECK: linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel", "parallel"]} outs(%[[VAL_3]] : memref<3x3x4xf32, 1>) {
+// CHECK: ^bb0(%[[VAL_4:.*]]: f32):
+// CHECK: %[[VAL_5:.*]] = arith.addf %[[VAL_4]], %[[VAL_4]] : f32
+// CHECK: linalg.yield %[[VAL_5]] : f32
+// CHECK: }
+// CHECK: return %[[VAL_3]] : memref<3x3x4xf32, 1>
+// CHECK: }
+
+func.func private @func_with_modified_argument_indirectly(%arg0: memref<3x3x4xf32, 1>) -> (memref<3x3x4xf32, 1>) {
+ %collapse_arg = memref.collapse_shape %arg0 [[0, 1], [2]] : memref<3x3x4xf32, 1> into memref<9x4xf32, 1>
+ %expand_arg = memref.expand_shape %collapse_arg [[0, 1], [2]] output_shape [3, 3, 4] : memref<9x4xf32, 1> into memref<3x3x4xf32, 1>
+ linalg.generic {
+ indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
+ iterator_types = ["parallel", "parallel", "parallel"]
+ }
+ outs(%expand_arg : memref<3x3x4xf32, 1>) {
+ ^bb0(%out: f32):
+ %0 = arith.addf %out, %out : f32
+ linalg.yield %0 : f32
+ }
+ return %expand_arg: memref<3x3x4xf32, 1>
+}
+
+// -----
+
+// CHECK-LABEL: func.func private @func_with_modified_argument_subview(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<2x4x4xi32, 1>) -> memref<4x4xi32, 1> {
+// CHECK: %[[VAL_1:.*]] = memref.alloc() : memref<2x4x4xi32, 1>
+// CHECK: memref.copy %[[VAL_0]], %[[VAL_1]] : memref<2x4x4xi32, 1> to memref<2x4x4xi32, 1>
+// CHECK: %[[VAL_2:.*]] = memref.subview %[[VAL_1]][0, 0, 0] [1, 4, 4] [1, 1, 1] : memref<2x4x4xi32, 1> to memref<1x4x4xi32, strided<[16, 4, 1]>, 1>
+// CHECK: %[[VAL_3:.*]] = memref.collapse_shape %[[VAL_2]] {{\[\[}}0, 1], [2]] : memref<1x4x4xi32, strided<[16, 4, 1]>, 1> into memref<4x4xi32, strided<[4, 1]>, 1>
+// CHECK: %[[VAL_4:.*]] = memref.cast %[[VAL_3]] : memref<4x4xi32, strided<[4, 1]>, 1> to memref<4x4xi32, 1>
+// CHECK: linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]} outs(%[[VAL_4]] : memref<4x4xi32, 1>) {
+// CHECK: ^bb0(%[[VAL_5:.*]]: i32):
+// CHECK: %[[VAL_6:.*]] = arith.addi %[[VAL_5]], %[[VAL_5]] : i32
+// CHECK: linalg.yield %[[VAL_6]] : i32
+// CHECK: }
+// CHECK: %[[VAL_7:.*]] = memref.alloc() : memref<4x4xi32, 1>
+// CHECK: memref.copy %[[VAL_4]], %[[VAL_7]] : memref<4x4xi32, 1> to memref<4x4xi32, 1>
+// CHECK: return %[[VAL_7]] : memref<4x4xi32, 1>
+// CHECK: }
+
+func.func private @func_with_modified_argument_subview(%arg0: memref<2x4x4xi32, 1>) -> ( memref<4x4xi32, 1>){
+ %subview = memref.subview %arg0[0, 0, 0] [1, 4, 4] [1, 1, 1] : memref<2x4x4xi32, 1> to memref<1x4x4xi32, strided<[16, 4, 1]>, 1>
+ %collapse_shape = memref.collapse_shape %subview [[0, 1], [2]] : memref<1x4x4xi32, strided<[16, 4, 1]>, 1> into memref<4x4xi32, strided<[4, 1]>, 1>
+ %cast = memref.cast %collapse_shape : memref<4x4xi32, strided<[4, 1]>, 1> to memref<4x4xi32, 1>
+ linalg.generic {
+ indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]
+ }
+ outs(%cas...
[truncated]