@@ -358,40 +358,6 @@ util.func public @sink_through_expand_shape(%arg0: tensor<?x?x?xf32>) -> tensor<
358358
359359// -----
360360
361- util.func public @sink_non_involution_through_expand_shape (%arg0 : tensor <2 x3 x4 xf32 >) -> tensor <1 x3 x4 x2 xf32 > {
362- %empty = tensor.empty (): tensor <3 x4 x2 xf32 >
363- %transposed = linalg.transpose ins (%arg0 : tensor <2 x3 x4 xf32 >)
364- outs (%empty : tensor <3 x4 x2 xf32 >) permutation = [1 , 2 , 0 ]
365- %expanded = tensor.expand_shape %transposed [[0 , 1 ], [2 ], [3 ]] output_shape [1 , 3 , 4 , 2 ] : tensor <3 x4 x2 xf32 > into tensor <1 x3 x4 x2 xf32 >
366- util.return %expanded : tensor <1 x3 x4 x2 xf32 >
367- }
368- // SINK-LABEL: util.func public @sink_non_involution_through_expand_shape
369- // SINK: %[[EXP:.+]] = tensor.expand_shape {{.*}} {{\[\[}}0], [1, 2], [3]]
370- // SINK-SAME: tensor<2x3x4xf32> into tensor<2x1x3x4xf32>
371- // SINK: %[[RES:.+]] = linalg.transpose ins(%[[EXP]] : tensor<2x1x3x4xf32>
372- // SINK-SAME: outs({{.*}} : tensor<1x3x4x2xf32>)
373- // SINK-SAME: permutation = [1, 2, 3, 0]
374- // SINK: util.return %[[RES]] : tensor<1x3x4x2xf32>
375-
376- // -----
377-
378- util.func public @bubble_non_involution_through_collapse_shape (%arg0 : tensor <1 x2 x3 x5 x7 x11 xf32 >) -> tensor <35 x11 x6 xf32 > {
379- %collapsed = tensor.collapse_shape %arg0 [[0 , 1 , 2 ], [3 , 4 ], [5 ]] : tensor <1 x2 x3 x5 x7 x11 xf32 > into tensor <6 x35 x11 xf32 >
380- %empty = tensor.empty (): tensor <35 x11 x6 xf32 >
381- %transposed = linalg.transpose ins (%collapsed : tensor <6 x35 x11 xf32 >)
382- outs (%empty : tensor <35 x11 x6 xf32 >) permutation = [1 , 2 , 0 ]
383- util.return %transposed : tensor <35 x11 x6 xf32 >
384- }
385- // BUBBLE-LABEL: util.func public @bubble_non_involution_through_collapse_shape
386- // BUBBLE: %[[T:.+]] = linalg.transpose ins(%{{.*}} : tensor<1x2x3x5x7x11xf32>
387- // BUBBLE-SAME: outs({{.*}} : tensor<5x7x11x1x2x3xf32>)
388- // BUBBLE-SAME: permutation = [3, 4, 5, 0, 1, 2]
389- // BUBBLE: %[[COL:.+]] = tensor.collapse_shape %[[T]] {{\[\[}}0, 1], [2], [3, 4, 5]]
390- // BUBBLE-SAME: tensor<5x7x11x1x2x3xf32> into tensor<35x11x6xf32>
391- // BUBBLE: util.return %[[COL]] : tensor<35x11x6xf32>
392-
393- // -----
394-
395361util.func public @propagate_transpose_through_unary_elementwise (%arg0 : tensor <2 x3 x4 xf32 >) -> tensor <3 x4 x2 xf32 > {
396362 %empty = tensor.empty (): tensor <3 x4 x2 xf32 >
397363 %transposed = linalg.transpose ins (%arg0 : tensor <2 x3 x4 xf32 >)
@@ -799,3 +765,31 @@ util.func public @dont_reshape_reduction(%arg0: tensor<16x4x4xf32>, %arg1: tenso
799765// APROP: %[[V1:.+]] = tensor.collapse_shape %[[V0]]
800766// APROP: %[[V2:.+]] = linalg.matmul ins(%[[V1]]
801767// APROP: util.return %[[V2]]
768+
769+ // -----
770+
771+ util.func @dont_propagate_edge_reshapes (%arg0: tensor <10 x10 x10 xi32 >) -> tensor <10 x100 xi32 > {
772+ %collapsed = tensor.collapse_shape %arg0 [[0 , 1 ], [2 ]] : tensor <10 x10 x10 xi32 > into tensor <100 x10 xi32 >
773+ %empty = tensor.empty () : tensor <10 x100 xi32 >
774+ %transpose = linalg.transpose ins (%collapsed : tensor <100 x10 xi32 >) outs (%empty : tensor <10 x100 xi32 >) permutation = [1 , 0 ]
775+ util.return %transpose : tensor <10 x100 xi32 >
776+ }
777+ // CHECK-LABEL: util.func public @dont_propagate_edge_reshapes
778+ // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]
779+ // CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]]
780+ // CHECK: %[[VAL:.+]] = linalg.transpose ins(%[[COLLAPSED]]
781+ // CHECK: util.return %[[VAL]]
782+
783+ // -----
784+
785+ util.func public @dont_sink_through_edge_expand_shape (%arg0 : tensor <2 x3 x4 xf32 >) -> tensor <1 x3 x4 x2 xf32 > {
786+ %empty = tensor.empty (): tensor <3 x4 x2 xf32 >
787+ %transposed = linalg.transpose ins (%arg0 : tensor <2 x3 x4 xf32 >)
788+ outs (%empty : tensor <3 x4 x2 xf32 >) permutation = [1 , 2 , 0 ]
789+ %expanded = tensor.expand_shape %transposed [[0 , 1 ], [2 ], [3 ]] output_shape [1 , 3 , 4 , 2 ] : tensor <3 x4 x2 xf32 > into tensor <1 x3 x4 x2 xf32 >
790+ util.return %expanded : tensor <1 x3 x4 x2 xf32 >
791+ }
792+ // SINK-LABEL: util.func public @dont_sink_through_edge_expand_shape
793+ // SINK: %[[TRANSPOSE:.+]] = linalg.transpose
794+ // SINK: %[[RES:.+]] = tensor.expand_shape %[[TRANSPOSE]]
795+ // SINK: util.return %[[RES]] : tensor<1x3x4x2xf32>
0 commit comments