Commit a48269c
[mlir][bufferization] Add enforce-immutable-func-args pass
Add a pass that allocates a new buffer for each input argument of the function it operates on that is written to, and copies the original argument into the allocated buffer with a `memref.copy`.
1 parent ba1255d commit a48269c

File tree: 5 files changed (+368, -0 lines)

mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h

Lines changed: 4 additions & 0 deletions
@@ -229,6 +229,10 @@ createPromoteBuffersToStackPass(std::function<bool(Value)> isSmallAlloc);
 /// insert_slice ops.
 std::unique_ptr<Pass> createEmptyTensorEliminationPass();
 
+/// Create a pass that enforces read-only buffers for the
+/// relevant function arguments.
+std::unique_ptr<Pass> createEnforceImmutableFuncArgsPass();
+
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td

Lines changed: 14 additions & 0 deletions
@@ -595,4 +595,18 @@ def EmptyTensorElimination : Pass<"eliminate-empty-tensors"> {
   let constructor = "mlir::bufferization::createEmptyTensorEliminationPass()";
 }
 
+def EnforceImmutableFuncArgs : Pass<"enforce-immutable-func-args", "func::FuncOp"> {
+  let summary = "Enforce function argument immutability by inserting allocs and copies";
+  let description = [{
+    This pass allocates a new buffer for each input argument of the function
+    that is written to, and copies the argument into the allocated buffer.
+    This avoids in-place memory updates of the function's arguments and makes
+    them immutable/read-only buffers.
+  }];
+  let constructor = "mlir::bufferization::createEnforceImmutableFuncArgsPass()";
+  let dependentDialects = ["memref::MemRefDialect"];
+}
+
 #endif // MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_PASSES
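
The TableGen definition anchors the pass on `func::FuncOp` and exposes it through `createEnforceImmutableFuncArgsPass()`. As a rough usage sketch (not part of this commit, assuming the standard `mlir::PassManager` API; `buildImmutableArgsPipeline` is an illustrative name), a client pipeline would nest the pass on functions:

#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/PassManager.h"

// Illustrative sketch: schedule the new pass on every func.func in the module.
static void buildImmutableArgsPipeline(mlir::PassManager &pm) {
  // The pass is defined on func::FuncOp, so it is nested on functions rather
  // than added at the module level.
  pm.addNestedPass<mlir::func::FuncOp>(
      mlir::bufferization::createEnforceImmutableFuncArgsPass());
}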

mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -16,6 +16,7 @@ add_mlir_dialect_library(MLIRBufferizationTransforms
   OwnershipBasedBufferDeallocation.cpp
   TensorCopyInsertion.cpp
   OptimizeAllocationLiveness.cpp
+  EnforceImmutableFuncArgs.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Bufferization

mlir/lib/Dialect/Bufferization/Transforms/EnforceImmutableFuncArgs.cpp

Lines changed: 101 additions & 0 deletions
@@ -0,0 +1,101 @@
//===- EnforceImmutableFuncArgs.cpp - Enforce immutable function args ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass that enforces the immutability of function
// arguments. For every memref argument that is written to, the pass allocates
// a new buffer, redirects all uses of the argument to it, and copies the
// original argument into the allocation.
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Operation.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "enforce-immutable-func-args"
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")

using namespace mlir;

namespace mlir {
namespace bufferization {
#define GEN_PASS_DEF_ENFORCEIMMUTABLEFUNCARGS
#include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
} // namespace bufferization
} // namespace mlir

// Returns true if any operation writes into `buffer`, either directly or
// through a view-like op derived from it.
// This method assumes `buffer` has `MemRefType`.
static bool isWrittenTo(Value buffer);

namespace {
/// This pass allocates a new buffer for each input argument of the function
/// that is written to, and copies the argument into the allocated buffer.
/// This avoids in-place memory updates of the function's arguments and makes
/// them immutable/read-only buffers.
struct EnforceImmutableFuncArgsPass
    : public bufferization::impl::EnforceImmutableFuncArgsBase<
          EnforceImmutableFuncArgsPass> {
  void runOnOperation() final;
};
} // end anonymous namespace

static bool isWrittenTo(Value buffer) {
  assert(isa<MemRefType>(buffer.getType()));

  for (auto user : buffer.getUsers()) {
    if (hasEffect<MemoryEffects::Write>(user, buffer))
      return true;
    // Writes through views (subview, collapse_shape, cast, ...) also mutate
    // the underlying buffer, so follow the view's result recursively.
    if (auto viewLikeOp = dyn_cast<ViewLikeOpInterface>(user)) {
      assert(viewLikeOp->getNumResults() == 1);
      if (isWrittenTo(viewLikeOp->getResult(0)))
        return true;
    }
  }
  return false;
}

void EnforceImmutableFuncArgsPass::runOnOperation() {
  func::FuncOp funcOp = getOperation();

  LDBG("enforcing immutable function arguments in func " << funcOp.getName());

  IRRewriter rewriter(funcOp->getContext());
  rewriter.setInsertionPointToStart(&funcOp.getBody().front());
  for (auto argument : funcOp.getArguments()) {
    auto argType = dyn_cast<MemRefType>(argument.getType());
    if (!argType) {
      emitError(argument.getLoc(),
                "function has argument with non-memref type");
      return signalPassFailure();
    }

    if (!isWrittenTo(argument))
      continue;

    LDBG("found a function argument that is written to: " << argument);
    // Allocate a private copy at the start of the function, redirect all uses
    // of the argument to it, and copy the original contents in.
    Value allocatedMemref =
        rewriter.create<memref::AllocOp>(funcOp.getLoc(), argType);
    rewriter.replaceAllUsesWith(argument, allocatedMemref);
    rewriter.create<memref::CopyOp>(funcOp.getLoc(), argument, allocatedMemref);
  }
}

//===----------------------------------------------------------------------===//
// EnforceImmutableFuncArgs construction
//===----------------------------------------------------------------------===//

std::unique_ptr<Pass>
mlir::bufferization::createEnforceImmutableFuncArgsPass() {
  return std::make_unique<EnforceImmutableFuncArgsPass>();
}
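
A note on `isWrittenTo`: the recursion through `ViewLikeOpInterface` is what lets the pass catch writes that reach the argument through `memref.subview`, `memref.collapse_shape`, or `memref.cast` rather than through the argument itself. The same check could be phrased iteratively; the sketch below is not part of the commit and only restates that logic with an explicit worklist:

#include "mlir/IR/Value.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "llvm/ADT/SmallVector.h"

// Sketch only: report whether any user writes to `root` or to a view-like
// alias derived from it.
static bool isWrittenToIteratively(mlir::Value root) {
  llvm::SmallVector<mlir::Value> worklist{root};
  while (!worklist.empty()) {
    mlir::Value buffer = worklist.pop_back_val();
    for (mlir::Operation *user : buffer.getUsers()) {
      if (mlir::hasEffect<mlir::MemoryEffects::Write>(user, buffer))
        return true;
      // View-like ops alias their source buffer, so keep following them.
      if (auto view = llvm::dyn_cast<mlir::ViewLikeOpInterface>(user))
        worklist.push_back(view->getResult(0));
    }
  }
  return false;
}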
Lines changed: 248 additions & 0 deletions
@@ -0,0 +1,248 @@
// RUN: mlir-opt --split-input-file --enforce-immutable-func-args %s -o - | FileCheck %s


// CHECK-LABEL: func.func @func_no_input() {
// CHECK: return
// CHECK: }

func.func @func_no_input() {
  return
}

// -----

// CHECK-LABEL: func.func private @func_with_returned_argument(
// CHECK-SAME: %[[VAL_0:.*]]: memref<1x13x21x3xf32>) -> memref<1x13x21x3xf32> {
// CHECK: return %[[VAL_0]] : memref<1x13x21x3xf32>
// CHECK: }

func.func private @func_with_returned_argument(%arg0: memref<1x13x21x3xf32>) -> (memref<1x13x21x3xf32>) {
  return %arg0 : memref<1x13x21x3xf32>
}

// -----

// CHECK-LABEL: func.func private @func_with_modified_argument_directly(
// CHECK-SAME: %[[VAL_0:.*]]: memref<1x13x21x3xf32>, %[[VAL_1:.*]]: memref<1x13x21x3xf32>) -> memref<1x13x21x3xf32> {
// CHECK: %[[VAL_2:.*]] = memref.alloc() : memref<1x13x21x3xf32>
// CHECK: memref.copy %[[VAL_0]], %[[VAL_2]] : memref<1x13x21x3xf32> to memref<1x13x21x3xf32>
// CHECK: %[[VAL_3:.*]] = memref.alloc() {alignment = 64 : i64} : memref<1x13x21x3xf32>
// CHECK: linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[VAL_2]], %[[VAL_1]] : memref<1x13x21x3xf32>, memref<1x13x21x3xf32>) outs(%[[VAL_2]] : memref<1x13x21x3xf32>) {
// CHECK: ^bb0(%[[VAL_4:.*]]: f32, %[[VAL_5:.*]]: f32, %[[VAL_6:.*]]: f32):
// CHECK: %[[VAL_7:.*]] = arith.addf %[[VAL_4]], %[[VAL_5]] : f32
// CHECK: linalg.yield %[[VAL_7]] : f32
// CHECK: }
// CHECK: return %[[VAL_3]] : memref<1x13x21x3xf32>
// CHECK: }

func.func private @func_with_modified_argument_directly(%arg0: memref<1x13x21x3xf32>, %arg1: memref<1x13x21x3xf32>) -> (memref<1x13x21x3xf32>) {
  %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x13x21x3xf32>
  linalg.generic {
      indexing_maps = [
        affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
        affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
        affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
      ],
      iterator_types = ["parallel", "parallel", "parallel", "parallel"]
    }
    ins(%arg0, %arg1 : memref<1x13x21x3xf32>, memref<1x13x21x3xf32>)
    outs(%arg0 : memref<1x13x21x3xf32>) {
  ^bb0(%in: f32, %in_0: f32, %out: f32):
    %0 = arith.addf %in, %in_0 : f32
    linalg.yield %0 : f32
  }
  return %alloc : memref<1x13x21x3xf32>
}

// -----

// CHECK-LABEL: func.func private @func_with_modified_argument_directly_and_returned(
// CHECK-SAME: %[[VAL_0:.*]]: memref<1x13x21x3xf32>, %[[VAL_1:.*]]: memref<1x13x21x3xf32>) -> memref<1x13x21x3xf32> {
// CHECK: %[[VAL_2:.*]] = memref.alloc() : memref<1x13x21x3xf32>
// CHECK: memref.copy %[[VAL_0]], %[[VAL_2]] : memref<1x13x21x3xf32> to memref<1x13x21x3xf32>
// CHECK: linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[VAL_2]], %[[VAL_1]] : memref<1x13x21x3xf32>, memref<1x13x21x3xf32>) outs(%[[VAL_2]] : memref<1x13x21x3xf32>) {
// CHECK: ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32, %[[VAL_5:.*]]: f32):
// CHECK: %[[VAL_6:.*]] = arith.addf %[[VAL_3]], %[[VAL_4]] : f32
// CHECK: linalg.yield %[[VAL_6]] : f32
// CHECK: }
// CHECK: return %[[VAL_2]] : memref<1x13x21x3xf32>
// CHECK: }

func.func private @func_with_modified_argument_directly_and_returned(%arg0: memref<1x13x21x3xf32>, %arg1: memref<1x13x21x3xf32>) -> (memref<1x13x21x3xf32>) {
  linalg.generic {
      indexing_maps = [
        affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
        affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
        affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
      ],
      iterator_types = ["parallel", "parallel", "parallel", "parallel"]
    }
    ins(%arg0, %arg1 : memref<1x13x21x3xf32>, memref<1x13x21x3xf32>)
    outs(%arg0 : memref<1x13x21x3xf32>) {
  ^bb0(%in: f32, %in_0: f32, %out: f32):
    %0 = arith.addf %in, %in_0 : f32
    linalg.yield %0 : f32
  }
  return %arg0 : memref<1x13x21x3xf32>
}

// -----

// CHECK-LABEL: func.func private @func_with_modified_argument_directly_twice(
// CHECK-SAME: %[[VAL_0:.*]]: memref<1x13x21x3xf32>, %[[VAL_1:.*]]: memref<1x13x21x3xf32>) -> memref<1x13x21x3xf32> {
// CHECK: %[[VAL_2:.*]] = memref.alloc() : memref<1x13x21x3xf32>
// CHECK: memref.copy %[[VAL_0]], %[[VAL_2]] : memref<1x13x21x3xf32> to memref<1x13x21x3xf32>
// CHECK: %[[VAL_3:.*]] = memref.alloc() {alignment = 64 : i64} : memref<1x13x21x3xf32>
// CHECK: linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[VAL_2]], %[[VAL_1]] : memref<1x13x21x3xf32>, memref<1x13x21x3xf32>) outs(%[[VAL_2]] : memref<1x13x21x3xf32>) {
// CHECK: ^bb0(%[[VAL_4:.*]]: f32, %[[VAL_5:.*]]: f32, %[[VAL_6:.*]]: f32):
// CHECK: %[[VAL_7:.*]] = arith.addf %[[VAL_4]], %[[VAL_5]] : f32
// CHECK: linalg.yield %[[VAL_7]] : f32
// CHECK: }
// CHECK: linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[VAL_2]], %[[VAL_1]] : memref<1x13x21x3xf32>, memref<1x13x21x3xf32>) outs(%[[VAL_2]] : memref<1x13x21x3xf32>) {
// CHECK: ^bb0(%[[VAL_8:.*]]: f32, %[[VAL_9:.*]]: f32, %[[VAL_10:.*]]: f32):
// CHECK: %[[VAL_11:.*]] = arith.addf %[[VAL_8]], %[[VAL_9]] : f32
// CHECK: linalg.yield %[[VAL_11]] : f32
// CHECK: }
// CHECK: return %[[VAL_3]] : memref<1x13x21x3xf32>
// CHECK: }

func.func private @func_with_modified_argument_directly_twice(%arg0: memref<1x13x21x3xf32>, %arg1: memref<1x13x21x3xf32>) -> (memref<1x13x21x3xf32>) {
  %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x13x21x3xf32>
  linalg.generic {
      indexing_maps = [
        affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
        affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
        affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
      ],
      iterator_types = ["parallel", "parallel", "parallel", "parallel"]
    }
    ins(%arg0, %arg1 : memref<1x13x21x3xf32>, memref<1x13x21x3xf32>)
    outs(%arg0 : memref<1x13x21x3xf32>) {
  ^bb0(%in: f32, %in_0: f32, %out: f32):
    %0 = arith.addf %in, %in_0 : f32
    linalg.yield %0 : f32
  }
  linalg.generic {
      indexing_maps = [
        affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
        affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
        affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
      ],
      iterator_types = ["parallel", "parallel", "parallel", "parallel"]
    }
    ins(%arg0, %arg1 : memref<1x13x21x3xf32>, memref<1x13x21x3xf32>)
    outs(%arg0 : memref<1x13x21x3xf32>) {
  ^bb0(%in: f32, %in_0: f32, %out: f32):
    %0 = arith.addf %in, %in_0 : f32
    linalg.yield %0 : f32
  }
  return %alloc : memref<1x13x21x3xf32>
}

// -----

// CHECK-LABEL: func.func private @func_with_modified_argument_directly(
// CHECK-SAME: %[[VAL_0:.*]]: memref<5xi32, 1>, %[[VAL_1:.*]]: memref<5xi32, 1>, %[[VAL_2:.*]]: memref<5xi32, 1>) -> memref<5xi32, 1> {
// CHECK: %[[VAL_3:.*]] = memref.alloc() : memref<5xi32, 1>
// CHECK: memref.copy %[[VAL_2]], %[[VAL_3]] : memref<5xi32, 1> to memref<5xi32, 1>
// CHECK: %[[VAL_4:.*]] = arith.constant 1 : index
// CHECK: %[[VAL_5:.*]] = arith.constant 5 : index
// CHECK: %[[VAL_6:.*]] = arith.constant 0 : index
// CHECK: scf.for %[[VAL_7:.*]] = %[[VAL_6]] to %[[VAL_5]] step %[[VAL_4]] {
// CHECK: %[[VAL_8:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_7]]] : memref<5xi32, 1>
// CHECK: %[[VAL_9:.*]] = arith.index_cast %[[VAL_8]] : i32 to index
// CHECK: %[[VAL_10:.*]] = memref.load %[[VAL_1]]{{\[}}%[[VAL_7]]] : memref<5xi32, 1>
// CHECK: %[[VAL_11:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_9]]] : memref<5xi32, 1>
// CHECK: %[[VAL_12:.*]] = arith.addi %[[VAL_10]], %[[VAL_11]] : i32
// CHECK: memref.store %[[VAL_12]], %[[VAL_3]]{{\[}}%[[VAL_9]]] : memref<5xi32, 1>
// CHECK: }
// CHECK: %[[VAL_13:.*]] = memref.alloc() : memref<5xi32, 1>
// CHECK: memref.copy %[[VAL_3]], %[[VAL_13]] : memref<5xi32, 1> to memref<5xi32, 1>
// CHECK: return %[[VAL_13]] : memref<5xi32, 1>
// CHECK: }

func.func private @func_with_modified_argument_directly(%arg0: memref<5xi32, 1>, %arg1: memref<5xi32, 1>, %arg2: memref<5xi32, 1>) -> (memref<5xi32, 1>) {
  %c1 = arith.constant 1 : index
  %c5 = arith.constant 5 : index
  %c0 = arith.constant 0 : index
  scf.for %arg3 = %c0 to %c5 step %c1 {
    %0 = memref.load %arg0[%arg3] : memref<5xi32, 1>
    %1 = arith.index_cast %0 : i32 to index
    %2 = memref.load %arg1[%arg3] : memref<5xi32, 1>
    %3 = memref.load %arg2[%1] : memref<5xi32, 1>
    %4 = arith.addi %2, %3 : i32
    memref.store %4, %arg2[%1] : memref<5xi32, 1>
  }
  %alloc = memref.alloc() : memref<5xi32, 1>
  memref.copy %arg2, %alloc : memref<5xi32, 1> to memref<5xi32, 1>
  return %alloc : memref<5xi32, 1>
}

// -----

// CHECK-LABEL: func.func private @func_with_modified_argument_indirectly(
// CHECK-SAME: %[[VAL_0:.*]]: memref<3x3x4xf32, 1>) -> memref<3x3x4xf32, 1> {
// CHECK: %[[VAL_1:.*]] = memref.alloc() : memref<3x3x4xf32, 1>
// CHECK: memref.copy %[[VAL_0]], %[[VAL_1]] : memref<3x3x4xf32, 1> to memref<3x3x4xf32, 1>
// CHECK: %[[VAL_2:.*]] = memref.collapse_shape %[[VAL_1]] {{\[\[}}0, 1], [2]] : memref<3x3x4xf32, 1> into memref<9x4xf32, 1>
// CHECK: %[[VAL_3:.*]] = memref.expand_shape %[[VAL_2]] {{\[\[}}0, 1], [2]] output_shape [3, 3, 4] : memref<9x4xf32, 1> into memref<3x3x4xf32, 1>
// CHECK: linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel", "parallel"]} outs(%[[VAL_3]] : memref<3x3x4xf32, 1>) {
// CHECK: ^bb0(%[[VAL_4:.*]]: f32):
// CHECK: %[[VAL_5:.*]] = arith.addf %[[VAL_4]], %[[VAL_4]] : f32
// CHECK: linalg.yield %[[VAL_5]] : f32
// CHECK: }
// CHECK: return %[[VAL_3]] : memref<3x3x4xf32, 1>
// CHECK: }

func.func private @func_with_modified_argument_indirectly(%arg0: memref<3x3x4xf32, 1>) -> (memref<3x3x4xf32, 1>) {
  %collapse_arg = memref.collapse_shape %arg0 [[0, 1], [2]] : memref<3x3x4xf32, 1> into memref<9x4xf32, 1>
  %expand_arg = memref.expand_shape %collapse_arg [[0, 1], [2]] output_shape [3, 3, 4] : memref<9x4xf32, 1> into memref<3x3x4xf32, 1>
  linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
      iterator_types = ["parallel", "parallel", "parallel"]
    }
    outs(%expand_arg : memref<3x3x4xf32, 1>) {
  ^bb0(%out: f32):
    %0 = arith.addf %out, %out : f32
    linalg.yield %0 : f32
  }
  return %expand_arg : memref<3x3x4xf32, 1>
}

// -----

// CHECK-LABEL: func.func private @func_with_modified_argument_subview(
// CHECK-SAME: %[[VAL_0:.*]]: memref<2x4x4xi32, 1>) -> memref<4x4xi32, 1> {
// CHECK: %[[VAL_1:.*]] = memref.alloc() : memref<2x4x4xi32, 1>
// CHECK: memref.copy %[[VAL_0]], %[[VAL_1]] : memref<2x4x4xi32, 1> to memref<2x4x4xi32, 1>
// CHECK: %[[VAL_2:.*]] = memref.subview %[[VAL_1]][0, 0, 0] [1, 4, 4] [1, 1, 1] : memref<2x4x4xi32, 1> to memref<1x4x4xi32, strided<[16, 4, 1]>, 1>
// CHECK: %[[VAL_3:.*]] = memref.collapse_shape %[[VAL_2]] {{\[\[}}0, 1], [2]] : memref<1x4x4xi32, strided<[16, 4, 1]>, 1> into memref<4x4xi32, strided<[4, 1]>, 1>
// CHECK: %[[VAL_4:.*]] = memref.cast %[[VAL_3]] : memref<4x4xi32, strided<[4, 1]>, 1> to memref<4x4xi32, 1>
// CHECK: linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]} outs(%[[VAL_4]] : memref<4x4xi32, 1>) {
// CHECK: ^bb0(%[[VAL_5:.*]]: i32):
// CHECK: %[[VAL_6:.*]] = arith.addi %[[VAL_5]], %[[VAL_5]] : i32
// CHECK: linalg.yield %[[VAL_6]] : i32
// CHECK: }
// CHECK: %[[VAL_7:.*]] = memref.alloc() : memref<4x4xi32, 1>
// CHECK: memref.copy %[[VAL_4]], %[[VAL_7]] : memref<4x4xi32, 1> to memref<4x4xi32, 1>
// CHECK: return %[[VAL_7]] : memref<4x4xi32, 1>
// CHECK: }

func.func private @func_with_modified_argument_subview(%arg0: memref<2x4x4xi32, 1>) -> (memref<4x4xi32, 1>) {
  %subview = memref.subview %arg0[0, 0, 0] [1, 4, 4] [1, 1, 1] : memref<2x4x4xi32, 1> to memref<1x4x4xi32, strided<[16, 4, 1]>, 1>
  %collapse_shape = memref.collapse_shape %subview [[0, 1], [2]] : memref<1x4x4xi32, strided<[16, 4, 1]>, 1> into memref<4x4xi32, strided<[4, 1]>, 1>
  %cast = memref.cast %collapse_shape : memref<4x4xi32, strided<[4, 1]>, 1> to memref<4x4xi32, 1>
  linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel"]
    }
    outs(%cast : memref<4x4xi32, 1>) {
  ^bb0(%out: i32):
    %0 = arith.addi %out, %out : i32
    linalg.yield %0 : i32
  }
  %alloc = memref.alloc() : memref<4x4xi32, 1>
  memref.copy %cast, %alloc : memref<4x4xi32, 1> to memref<4x4xi32, 1>
  return %alloc : memref<4x4xi32, 1>
}
