@@ -482,6 +482,32 @@ func.func @generic_op_reshape_consumer_fusion_projected(%arg0 : tensor<?x?xf32>,

 // -----

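+// Checks that the collapse_shape is pushed below the reduction generic: per the CHECK lines, the fused generic reads %arg0 in its original 3-D shape and the collapse is applied to its result.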
+func.func @fuse_collapse_reduction(%arg0: tensor<10x10x20xf32>) -> tensor<100xf32> {
+  %c0 = arith.constant 0 : index
+  %c_0 = arith.constant 0.0 : f32
+  %0 = tensor.collapse_shape %arg0 [[0, 1], [2]] : tensor<10x10x20xf32> into tensor<100x20xf32>
+  %2 = tensor.empty() : tensor<100xf32>
+  %3 = linalg.fill ins(%c_0 : f32) outs(%2 : tensor<100xf32>) -> tensor<100xf32>
+  %4 = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
+    iterator_types = ["parallel", "reduction"]}
+    ins(%0 : tensor<100x20xf32>) outs(%3 : tensor<100xf32>) {
+  ^bb0(%arg1 : f32, %arg2 : f32):
+    %5 = arith.addf %arg1, %arg2 : f32
+    linalg.yield %5 : f32
+  } -> tensor<100xf32>
+  return %4 : tensor<100xf32>
+}
+
+// CHECK: func @fuse_collapse_reduction
+// CHECK-SAME: %[[ARG0:.+]]: tensor<10x10x20xf32>
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG0]] : tensor<10x10x20xf32>)
+// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[GENERIC]]
+// CHECK: return %[[COLLAPSE]]
+// -----
+
 func.func @no_fuse_dynamic_dims(%arg0: tensor<?x?xf32>) -> tensor<?xf32> {
   %c0 = arith.constant 0 : index
   %0 = tensor.collapse_shape %arg0 [[0, 1]] : tensor<?x?xf32> into tensor<?xf32>