Skip to content

Commit 0201daa

Browse files
committed
Add test to check reduction reshape fusion
Signed-off-by: Ian Wood <[email protected]>
1 parent f31cf5a commit 0201daa

File tree

1 file changed

+25
-0
lines changed

1 file changed

+25
-0
lines changed

mlir/test/Dialect/Linalg/reshape_fusion.mlir

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -482,6 +482,31 @@ func.func @generic_op_reshape_consumer_fusion_projected(%arg0 : tensor<?x?xf32>,
482482

483483
// -----
484484

485+
func.func @fuse_collapse_reduction(%arg0: tensor<10x10x20xf32>) -> tensor<100xf32> {
486+
%c0 = arith.constant 0 : index
487+
%c_0 = arith.constant 0.0 : f32
488+
%0 = tensor.collapse_shape %arg0 [[0, 1], [2]] : tensor<10x10x20xf32> into tensor<100x20xf32>
489+
%2 = tensor.empty() : tensor<100xf32>
490+
%3 = linalg.fill ins(%c_0 : f32) outs(%2 : tensor<100xf32>) -> tensor<100xf32>
491+
%4 = linalg.generic {
492+
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
493+
iterator_types = ["parallel", "reduction"]}
494+
ins(%0 : tensor<100x20xf32>) outs(%3 : tensor<100xf32>) {
495+
^bb0(%arg1 : f32, %arg2: f32):
496+
%4 = arith.addf %arg1, %arg2 : f32
497+
linalg.yield %4 : f32
498+
} -> tensor<100xf32>
499+
return %4 : tensor<100xf32>
500+
}
501+
502+
// CHECK: func @fuse_collapse_reduction
503+
// CHECK-SAME: %[[ARG0:.+]]: tensor<10x10x20xf32>
504+
// CHECK: %[[GENERIC:.+]] = linalg.generic
505+
// CHECK-SAME: ins(%[[ARG0]] : tensor<10x10x20xf32>)
506+
// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[GENERIC]]
507+
// CHECK: return %[[COLLAPSE]]
508+
// -----
509+
485510
func.func @no_fuse_dynamic_dims(%arg0: tensor<?x?xf32>) -> tensor<?xf32> {
486511
%c0 = arith.constant 0 : index
487512
%0 = tensor.collapse_shape %arg0 [[0, 1]] : tensor<?x?xf32> into tensor<?xf32>

0 commit comments

Comments
 (0)