11// RUN: mlir-opt -split-input-file -transform-interpreter -cse %s | FileCheck %s
22
3- module attributes {transform.with_named_sequence } {
4- transform.named_sequence @__transform_main (%func_op: !transform.op <" func.func" > {transform.readonly }) {
5- transform.apply_patterns to %func_op {
6- transform.apply_patterns.tensor.decompose_concat
7- } : !transform.op <" func.func" >
8- transform.yield
9- }
10- }
11-
123func.func @decompose_dynamic_concat (%arg0 : tensor <8 x4 xf32 >, %arg1 : tensor <?x?xf32 >) -> tensor <?x?xf32 > {
134 %0 = tensor.concat dim (1 ) %arg0 , %arg1 : (tensor <8 x4 xf32 >, tensor <?x?xf32 >) -> tensor <?x?xf32 >
145 return %0 : tensor <?x?xf32 >
156}
16- // CHECK-DAG: #[[$MAP:.+]] = affine_map<()[s0, s1 ] -> (s0 + s1 )>
7+ // CHECK-DAG: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 + 4 )>
178// CHECK-LABEL: func @decompose_dynamic_concat(
189// CHECK-SAME: %[[ARG0:.+]]: tensor<8x4xf32>
1910// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
@@ -22,24 +13,13 @@ func.func @decompose_dynamic_concat(%arg0 : tensor<8x4xf32>, %arg1 : tensor<?x?x
2213// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
2314// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
2415// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
25- // CHECK: %[[CONCAT_SIZE:.+]] = affine.apply #[[$MAP]]()[%[[C8]], %[[ DIM]]]
16+ // CHECK: %[[CONCAT_SIZE:.+]] = affine.apply #[[$MAP]]()[%[[DIM]]]
2617// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[C8]], %[[CONCAT_SIZE]]) : tensor<?x?xf32>
2718// CHECK: %[[SLICE0:.+]] = tensor.insert_slice %[[ARG0]] into %[[EMPTY]][0, 0] [8, 4] [1, 1] : tensor<8x4xf32> into tensor<?x?xf32>
2819// CHECK: %[[OFFSET:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
2920// CHECK: %[[CONCAT:.+]] = tensor.insert_slice %[[ARG1]] into %[[SLICE0]][0, 4] [%[[OFFSET]], %[[DIM]]] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
3021// CHECK: return %[[CONCAT]] : tensor<?x?xf32>
3122
32- // -----
33-
34- module attributes {transform.with_named_sequence } {
35- transform.named_sequence @__transform_main (%func_op: !transform.op <" func.func" > {transform.readonly }) {
36- transform.apply_patterns to %func_op {
37- transform.apply_patterns.tensor.decompose_concat
38- } : !transform.op <" func.func" >
39- transform.yield
40- }
41- }
42-
4323func.func @decompose_1d_concat (%arg0 : tensor <1 xf32 >,
4424 %arg1 : tensor <2 xf32 >,
4525 %arg2 : tensor <3 xf32 >,
@@ -55,3 +35,27 @@ func.func @decompose_1d_concat(%arg0 : tensor<1xf32>,
5535// CHECK: tensor.insert_slice %{{.*}}[3] [3] [1] : tensor<3xf32> into tensor<10xf32>
5636// CHECK: %[[CONCAT:.+]] = tensor.insert_slice %{{.*}}[6] [4] [1] : tensor<4xf32> into tensor<10xf32>
5737// CHECK: return %[[CONCAT]] : tensor<10xf32>
38+
39+ func.func @decompose_static_concat_dim (%arg0 : tensor <1 x?x64 xf32 >,
40+ %arg1: tensor <1 x?x64 xf32 >) -> tensor <1 x?x128 xf32 > {
41+ %0 = tensor.concat dim (2 ) %arg0 , %arg1
42+ : (tensor <1 x?x64 xf32 >, tensor <1 x?x64 xf32 >) -> tensor <1 x?x128 xf32 >
43+ return %0 : tensor <1 x?x128 xf32 >
44+ }
45+ // CHECK-LABEL: func @decompose_static_concat_dim
46+ // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
47+ // CHECK: %[[DIM:.+]] = tensor.dim %{{.*}}, %[[C1]] : tensor<1x?x64xf32>
48+ // CHECK: tensor.empty(%[[DIM]]) : tensor<1x?x128xf32>
49+ // CHECK: tensor.insert_slice %{{.*}}[0, 0, 0] [1, %[[DIM]], 64] [1, 1, 1] : tensor<1x?x64xf32> into tensor<1x?x128xf32>
50+ // CHECK: %[[DIM1:.+]] = tensor.dim %{{.*}}, %[[C1]] : tensor<1x?x64xf32>
51+ // CHECK: %[[CONCAT:.+]] = tensor.insert_slice %{{.*}}[0, 0, 64] [1, %[[DIM1]], 64] [1, 1, 1] : tensor<1x?x64xf32> into tensor<1x?x128xf32>
52+ // CHECK: return %[[CONCAT]] : tensor<1x?x128xf32>
53+
54+ module attributes {transform.with_named_sequence } {
55+ transform.named_sequence @__transform_main (%func_op: !transform.op <" func.func" > {transform.readonly }) {
56+ transform.apply_patterns to %func_op {
57+ transform.apply_patterns.tensor.decompose_concat
58+ } : !transform.op <" func.func" >
59+ transform.yield
60+ }
61+ }
0 commit comments