Skip to content

Commit fa019c9

Browse files
authored
[Codegen] Add patterns to fold reshapes into load_from/store_to_memref (#20881)
Adds reshape folding patterns for iree_codegen.load_from_memref and iree_codegen.store_to_memref. The patterns move the reshape ops from the tensor result/operand to the buffer operand. The patterns are also added to CleanupBufferAllocViewPass. --------- Signed-off-by: Max Dawkins <[email protected]>
1 parent b56c2c4 commit fa019c9

File tree

8 files changed

+389
-4
lines changed

8 files changed

+389
-4
lines changed

compiler/src/iree/compiler/Codegen/Common/CleanupBufferAllocViewPass.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ struct CleanupBufferAllocViewPass final
3535
void runOnOperation() override {
3636
RewritePatternSet patterns(&getContext());
3737
populateReshapeToInterfaceTensorPatterns(patterns);
38+
populateFoldTensorReshapeIntoBufferPatterns(patterns);
3839
populateRemoveDeadMemAllocPatterns(patterns);
3940
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
4041
return signalPassFailure();

compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ iree_lit_test_suite(
2525
"bubble_up_ordinal_ops.mlir",
2626
"bufferize_copy_only_dispatches.mlir",
2727
"bufferize_dispatch_tensor_load_store.mlir",
28+
"canonicalize_early_bufferization_ops.mlir",
2829
"canonicalize_interface_load_store.mlir",
2930
"check_for_config.mlir",
3031
"combine_layout_transformation.mlir",

compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ iree_lit_test_suite(
2121
"bubble_up_ordinal_ops.mlir"
2222
"bufferize_copy_only_dispatches.mlir"
2323
"bufferize_dispatch_tensor_load_store.mlir"
24+
"canonicalize_early_bufferization_ops.mlir"
2425
"canonicalize_interface_load_store.mlir"
2526
"check_for_config.mlir"
2627
"combine_layout_transformation.mlir"
Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(func.func(iree-codegen-cleanup-buffer-alloc-view))" %s | FileCheck %s
2+
3+
#pipeline_layout = #hal.pipeline.layout<bindings = [
4+
#hal.pipeline.binding<storage_buffer>,
5+
#hal.pipeline.binding<storage_buffer>
6+
]>
7+
func.func @fold_reshape_load() {
8+
%cst = arith.constant 0.000000e+00 : f32
9+
%0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) : memref<3x3x1x96xf32, #hal.descriptor_type<storage_buffer>>
10+
%1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) : memref<3x3x96xf32, #hal.descriptor_type<storage_buffer>>
11+
%2 = iree_codegen.load_from_memref %0 : memref<3x3x1x96xf32, #hal.descriptor_type<storage_buffer>> -> tensor<3x3x1x96xf32>
12+
%collapsed = tensor.collapse_shape %2 [[0, 1, 2, 3]] : tensor<3x3x1x96xf32> into tensor<864xf32>
13+
%expanded = tensor.expand_shape %collapsed [[0, 1, 2]] output_shape [3, 3, 96] : tensor<864xf32> into tensor<3x3x96xf32>
14+
%barrier = util.optimization_barrier %expanded : tensor<3x3x96xf32>
15+
iree_codegen.store_to_memref %barrier, %1 : tensor<3x3x96xf32> into memref<3x3x96xf32, #hal.descriptor_type<storage_buffer>>
16+
return
17+
}
18+
// CHECK-LABEL: @fold_reshape_load
19+
// CHECK-DAG: %[[SRC_SUBSPAN:.+]] = hal.interface.binding.subspan{{.*}} binding(0){{.*}} memref<3x3x1x96xf32
20+
// CHECK-DAG: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC_SUBSPAN]]{{.*}} into memref<864xf32
21+
// CHECK-DAG: %[[EXPAND:.+]] = memref.expand_shape %[[COLLAPSE]]{{.*}} into memref<3x3x96xf32
22+
// CHECK-DAG: %[[DEST_SUBSPAN:.+]] = hal.interface.binding.subspan{{.*}} binding(1)
23+
// CHECK: %[[LOAD:.+]] = iree_codegen.load_from_memref %[[EXPAND]]
24+
// CHECK-SAME: memref<3x3x96xf32, #hal.descriptor_type<storage_buffer>> -> tensor<3x3x96xf32>
25+
// CHECK: %[[BARRIER:.+]] = util.optimization_barrier %[[LOAD]]
26+
// CHECK: iree_codegen.store_to_memref %[[BARRIER]], %[[DEST_SUBSPAN]]
27+
28+
// -----
29+
30+
#pipeline_layout = #hal.pipeline.layout<bindings = [
31+
#hal.pipeline.binding<storage_buffer>,
32+
#hal.pipeline.binding<storage_buffer>
33+
]>
34+
func.func @fold_reshape_store() {
35+
%0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) : memref<3x3x1x96xf32, #hal.descriptor_type<storage_buffer>>
36+
%1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) : memref<3x3x96xf32, #hal.descriptor_type<storage_buffer>>
37+
%2 = iree_codegen.load_from_memref %0 : memref<3x3x1x96xf32, #hal.descriptor_type<storage_buffer>> -> tensor<3x3x1x96xf32>
38+
%barrier = util.optimization_barrier %2 : tensor<3x3x1x96xf32>
39+
%collapsed = tensor.collapse_shape %barrier [[0, 1, 2, 3]] : tensor<3x3x1x96xf32> into tensor<864xf32>
40+
%expanded = tensor.expand_shape %collapsed [[0, 1, 2]] output_shape [3, 3, 96] : tensor<864xf32> into tensor<3x3x96xf32>
41+
iree_codegen.store_to_memref %expanded, %1 : tensor<3x3x96xf32> into memref<3x3x96xf32, #hal.descriptor_type<storage_buffer>>
42+
return
43+
}
44+
// CHECK-LABEL: @fold_reshape_store
45+
// CHECK-DAG: %[[SRC_SUBSPAN:.+]] = hal.interface.binding.subspan{{.*}} binding(0)
46+
// CHECK-DAG: %[[DEST_SUBSPAN:.+]] = hal.interface.binding.subspan{{.*}} binding(1){{.*}} memref<3x3x96xf32
47+
// CHECK-DAG: %[[COLLAPSE:.+]] = memref.collapse_shape %[[DEST_SUBSPAN]]{{.*}} into memref<864xf32
48+
// CHECK-DAG: %[[EXPAND:.+]] = memref.expand_shape %[[COLLAPSE]]{{.*}} into memref<3x3x1x96xf32
49+
// CHECK: %[[LOAD:.+]] = iree_codegen.load_from_memref %[[SRC_SUBSPAN]]
50+
// CHECK: %[[BARRIER:.+]] = util.optimization_barrier %[[LOAD]]
51+
// CHECK: iree_codegen.store_to_memref %[[BARRIER]], %[[EXPAND]]
52+
// CHECK-SAME: tensor<3x3x1x96xf32> into memref<3x3x1x96xf32, #hal.descriptor_type<storage_buffer>>
53+
54+
// -----
55+
56+
#pipeline_layout = #hal.pipeline.layout<bindings = [
57+
#hal.pipeline.binding<storage_buffer>,
58+
#hal.pipeline.binding<storage_buffer>
59+
]>
60+
func.func @fold_reshape_with_slice_load() {
61+
%0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) : memref<6x3x1x96xf32, #hal.descriptor_type<storage_buffer>>
62+
%1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) : memref<3x3x96xf32, #hal.descriptor_type<storage_buffer>>
63+
%subview = memref.subview %0[3, 0, 0, 0] [3, 3, 1, 96] [1, 1, 1, 1] : memref<6x3x1x96xf32, #hal.descriptor_type<storage_buffer>> to memref<3x3x1x96xf32, strided<[288, 96, 96, 1], offset: 864>, #hal.descriptor_type<storage_buffer>>
64+
%2 = iree_codegen.load_from_memref %subview : memref<3x3x1x96xf32, strided<[288, 96, 96, 1], offset: 864>, #hal.descriptor_type<storage_buffer>> -> tensor<3x3x1x96xf32>
65+
%collapsed = tensor.collapse_shape %2 [[0, 1, 2, 3]] : tensor<3x3x1x96xf32> into tensor<864xf32>
66+
%expanded = tensor.expand_shape %collapsed [[0, 1, 2]] output_shape [3, 3, 96] : tensor<864xf32> into tensor<3x3x96xf32>
67+
%barrier = util.optimization_barrier %expanded : tensor<3x3x96xf32>
68+
iree_codegen.store_to_memref %barrier, %1 : tensor<3x3x96xf32> into memref<3x3x96xf32, #hal.descriptor_type<storage_buffer>>
69+
return
70+
}
71+
// CHECK-LABEL: @fold_reshape_with_slice_load
72+
// CHECK-DAG: %[[SRC_SUBSPAN:.+]] = hal.interface.binding.subspan{{.*}} binding(0)
73+
// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[SRC_SUBSPAN]]
74+
// CHECK-SAME: memref<3x3x1x96xf32, strided<[288, 96, 96, 1], offset: 864>
75+
// CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SUBVIEW]]
76+
// CHECK-SAME: into memref<864xf32, strided<[1], offset: 864>
77+
// CHECK: %[[EXPAND:.+]] = memref.expand_shape %[[COLLAPSE]]
78+
// CHECK-SAME: into memref<3x3x96xf32, strided<[288, 96, 1], offset: 864>
79+
// CHECK: iree_codegen.load_from_memref %[[EXPAND]]
80+
81+
// -----
82+
83+
#pipeline_layout = #hal.pipeline.layout<bindings = [
84+
#hal.pipeline.binding<storage_buffer>,
85+
#hal.pipeline.binding<storage_buffer>
86+
]>
87+
func.func @fold_reshape_with_slice_store() {
88+
%0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) : memref<3x3x1x96xf32, #hal.descriptor_type<storage_buffer>>
89+
%1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) : memref<6x3x96xf32, #hal.descriptor_type<storage_buffer>>
90+
%subview = memref.subview %1[3, 0, 0] [3, 3, 96] [1, 1, 1] : memref<6x3x96xf32, #hal.descriptor_type<storage_buffer>> to memref<3x3x96xf32, strided<[288, 96, 1], offset: 864>, #hal.descriptor_type<storage_buffer>>
91+
%2 = iree_codegen.load_from_memref %0 : memref<3x3x1x96xf32, #hal.descriptor_type<storage_buffer>> -> tensor<3x3x1x96xf32>
92+
%barrier = util.optimization_barrier %2 : tensor<3x3x1x96xf32>
93+
%collapsed = tensor.collapse_shape %barrier [[0, 1, 2, 3]] : tensor<3x3x1x96xf32> into tensor<864xf32>
94+
%expanded = tensor.expand_shape %collapsed [[0, 1, 2]] output_shape [3, 3, 96] : tensor<864xf32> into tensor<3x3x96xf32>
95+
iree_codegen.store_to_memref %expanded, %subview : tensor<3x3x96xf32> into memref<3x3x96xf32, strided<[288, 96, 1], offset: 864>, #hal.descriptor_type<storage_buffer>>
96+
return
97+
}
98+
// CHECK-LABEL: @fold_reshape_with_slice_store
99+
// CHECK-DAG: %[[DEST_SUBSPAN:.+]] = hal.interface.binding.subspan{{.*}} binding(1)
100+
// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[DEST_SUBSPAN]]
101+
// CHECK-SAME: memref<3x3x96xf32, strided<[288, 96, 1], offset: 864>
102+
// CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SUBVIEW]]
103+
// CHECK-SAME: into memref<864xf32, strided<[1], offset: 864>
104+
// CHECK: %[[EXPAND:.+]] = memref.expand_shape %[[COLLAPSE]]
105+
// CHECK-SAME: into memref<3x3x1x96xf32, strided<[288, 96, 96, 1], offset: 864>
106+
// CHECK: iree_codegen.store_to_memref {{.*}}, %[[EXPAND]]
107+
108+
// -----
109+
110+
#pipeline_layout = #hal.pipeline.layout<constants = 4, bindings = [
111+
#hal.pipeline.binding<storage_buffer>,
112+
#hal.pipeline.binding<storage_buffer>
113+
]>
114+
func.func @fold_dynamic_reshape_load() {
115+
%c0 = arith.constant 0 : index
116+
%0 = hal.interface.constant.load layout(#pipeline_layout) ordinal(0) : index
117+
%1 = hal.interface.constant.load layout(#pipeline_layout) ordinal(1) : index
118+
%2 = hal.interface.constant.load layout(#pipeline_layout) ordinal(2) : index
119+
%3 = hal.interface.constant.load layout(#pipeline_layout) ordinal(3) : index
120+
%4 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) : memref<?x?xf32, #hal.descriptor_type<storage_buffer>>{%0, %1}
121+
memref.assume_alignment %4, 1 : memref<?x?xf32, #hal.descriptor_type<storage_buffer>>
122+
%5 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) : memref<?x?xf32, #hal.descriptor_type<storage_buffer>>{%2, %3}
123+
memref.assume_alignment %5, 1 : memref<?x?xf32, #hal.descriptor_type<storage_buffer>>
124+
%6 = iree_codegen.load_from_memref %4 : memref<?x?xf32, #hal.descriptor_type<storage_buffer>> -> tensor<?x?xf32>
125+
%collapsed = tensor.collapse_shape %6 [[0, 1]] : tensor<?x?xf32> into tensor<?xf32>
126+
%expanded = tensor.expand_shape %collapsed [[0, 1]] output_shape [%2, %3] : tensor<?xf32> into tensor<?x?xf32>
127+
%barrier = util.optimization_barrier %expanded : tensor<?x?xf32>
128+
iree_codegen.store_to_memref %barrier, %5 : tensor<?x?xf32> into memref<?x?xf32, #hal.descriptor_type<storage_buffer>>
129+
return
130+
}
131+
// CHECK-LABEL: @fold_dynamic_reshape_load
132+
// CHECK-DAG: %[[D0:.+]] = hal.interface.constant.load{{.*}} ordinal(2) : index
133+
// CHECK-DAG: %[[D1:.+]] = hal.interface.constant.load{{.*}} ordinal(3) : index
134+
// CHECK-DAG: %[[SRC_SUBSPAN:.+]] = hal.interface.binding.subspan{{.*}} binding(0)
135+
// CHECK-DAG: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC_SUBSPAN]]
136+
// CHECK-DAG: %[[EXPAND:.+]] = memref.expand_shape %[[COLLAPSE]]
137+
// CHECK-SAME: output_shape [%[[D0]], %[[D1]]]
138+
// CHECK-DAG: %[[LOAD:.+]] = iree_codegen.load_from_memref %[[EXPAND]]
139+
// CHECK-DAG: %[[DEST_SUBSPAN:.+]] = hal.interface.binding.subspan{{.*}} binding(1)
140+
// CHECK-DAG: %[[BARRIER:.+]] = util.optimization_barrier %[[LOAD]]
141+
// CHECK: iree_codegen.store_to_memref %[[BARRIER]], %[[DEST_SUBSPAN]]
142+
143+
// -----
144+
145+
#pipeline_layout = #hal.pipeline.layout<constants = 4, bindings = [
146+
#hal.pipeline.binding<storage_buffer>,
147+
#hal.pipeline.binding<storage_buffer>
148+
]>
149+
func.func @fold_dynamic_reshape_store() {
150+
%c0 = arith.constant 0 : index
151+
%0 = hal.interface.constant.load layout(#pipeline_layout) ordinal(0) : index
152+
%1 = hal.interface.constant.load layout(#pipeline_layout) ordinal(1) : index
153+
%2 = hal.interface.constant.load layout(#pipeline_layout) ordinal(2) : index
154+
%3 = hal.interface.constant.load layout(#pipeline_layout) ordinal(3) : index
155+
%4 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) : memref<?x?xf32, #hal.descriptor_type<storage_buffer>>{%0, %1}
156+
memref.assume_alignment %4, 1 : memref<?x?xf32, #hal.descriptor_type<storage_buffer>>
157+
%5 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) : memref<?x?xf32, #hal.descriptor_type<storage_buffer>>{%2, %3}
158+
memref.assume_alignment %5, 1 : memref<?x?xf32, #hal.descriptor_type<storage_buffer>>
159+
%6 = iree_codegen.load_from_memref %4 : memref<?x?xf32, #hal.descriptor_type<storage_buffer>> -> tensor<?x?xf32>
160+
%barrier = util.optimization_barrier %6 : tensor<?x?xf32>
161+
%collapsed = tensor.collapse_shape %barrier [[0, 1]] : tensor<?x?xf32> into tensor<?xf32>
162+
%expanded = tensor.expand_shape %collapsed [[0, 1]] output_shape [%2, %3] : tensor<?xf32> into tensor<?x?xf32>
163+
iree_codegen.store_to_memref %expanded, %5 : tensor<?x?xf32> into memref<?x?xf32, #hal.descriptor_type<storage_buffer>>
164+
return
165+
}
166+
// CHECK-LABEL: @fold_dynamic_reshape_store
167+
// CHECK-DAG: %[[C0:.+]] = arith.constant 0
168+
// CHECK-DAG: %[[C1:.+]] = arith.constant 1
169+
// CHECK-DAG: %[[SRC_SUBSPAN:.+]] = hal.interface.binding.subspan{{.*}} binding(0)
170+
// CHECK-DAG: %[[LOAD:.+]] = iree_codegen.load_from_memref %[[SRC_SUBSPAN]]
171+
// CHECK-DAG: %[[BARRIER:.+]] = util.optimization_barrier %[[LOAD]]
172+
// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[BARRIER]], %[[C0]]
173+
// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[BARRIER]], %[[C1]]
174+
// CHECK-DAG: %[[DEST_SUBSPAN:.+]] = hal.interface.binding.subspan{{.*}} binding(1)
175+
// CHECK-DAG: %[[COLLAPSE:.+]] = memref.collapse_shape %[[DEST_SUBSPAN]]
176+
// CHECK-DAG: %[[EXPAND:.+]] = memref.expand_shape %[[COLLAPSE]]
177+
// CHECK-SAME: output_shape [%[[D0]], %[[D1]]]
178+
// CHECK: iree_codegen.store_to_memref %[[BARRIER]], %[[EXPAND]]

0 commit comments

Comments
 (0)