1- // RUN: mlir-opt -split-input-file -transform-interpreter -cse %s | FileCheck %s
1+ // RUN: mlir-opt -split-input-file -transform-interpreter -cse --mlir-print-local-scope %s | FileCheck %s
22
33func.func @decompose_dynamic_concat (%arg0 : tensor <8 x4 xf32 >, %arg1 : tensor <?x?xf32 >) -> tensor <?x?xf32 > {
44 %0 = tensor.concat dim (1 ) %arg0 , %arg1 : (tensor <8 x4 xf32 >, tensor <?x?xf32 >) -> tensor <?x?xf32 >
55 return %0 : tensor <?x?xf32 >
66}
7- // CHECK-DAG: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 + 4)>
87// CHECK-LABEL: func @decompose_dynamic_concat(
98// CHECK-SAME: %[[ARG0:.+]]: tensor<8x4xf32>
109// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
1110
12- // CHECK-DAG: %[[C8:.+]] = arith.constant 8 : index
1311// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
1412// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
15- // CHECK: %[[DIM:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
16- // CHECK: %[[CONCAT_SIZE:.+]] = affine.apply #[[$MAP]]()[%[[DIM]]]
17- // CHECK: %[[EMPTY:.+]] = tensor.empty(%[[C8]], %[[CONCAT_SIZE]]) : tensor<?x?xf32>
18- // CHECK: %[[SLICE0:.+]] = tensor.insert_slice %[[ARG0]] into %[[EMPTY]][0, 0] [8, 4] [1, 1] : tensor<8x4xf32> into tensor<?x?xf32>
19- // CHECK: %[[OFFSET:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
20- // CHECK: %[[CONCAT:.+]] = tensor.insert_slice %[[ARG1]] into %[[SLICE0]][0, 4] [%[[OFFSET]], %[[DIM]]] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
21- // CHECK: return %[[CONCAT]] : tensor<?x?xf32>
13+ // CHECK: %[[DIM:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
14+ // CHECK: %[[DIM0:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
15+ // CHECK: %[[CONCAT_SIZE:.+]] = affine.apply affine_map<()[s0] -> (s0 + 4)>()[%[[DIM0]]]
16+ // CHECK: %[[EMPTY:.+]] = tensor.empty(%[[CONCAT_SIZE]]) : tensor<8x?xf32>
17+ // CHECK: %[[SLICE0:.+]] = tensor.insert_slice %[[ARG0]] into %[[EMPTY]][0, 0] [8, 4] [1, 1] : tensor<8x4xf32> into tensor<8x?xf32>
18+ // CHECK: %[[CONCAT:.+]] = tensor.insert_slice %[[ARG1]] into %[[SLICE0]][0, 4] [%[[DIM]], %[[DIM0]]] [1, 1] : tensor<?x?xf32> into tensor<8x?xf32>
19+ // CHECK: %[[CAST:.+]] = tensor.cast %[[CONCAT]] : tensor<8x?xf32> to tensor<?x?xf32>
20+ // CHECK: return %[[CAST]] : tensor<?x?xf32>
2221
2322func.func @decompose_1d_concat (%arg0 : tensor <1 xf32 >,
2423 %arg1 : tensor <2 xf32 >,
@@ -42,12 +41,14 @@ func.func @decompose_static_concat_dim(%arg0 : tensor<1x?x64xf32>,
4241 : (tensor <1 x?x64 xf32 >, tensor <1 x?x64 xf32 >) -> tensor <1 x?x128 xf32 >
4342 return %0 : tensor <1 x?x128 xf32 >
4443}
45- // CHECK-LABEL: func @decompose_static_concat_dim
44+ // CHECK-LABEL: func @decompose_static_concat_dim(
45+ // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<1x?x64xf32>,
46+ // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<1x?x64xf32>)
4647// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
47- // CHECK: %[[DIM:.+]] = tensor.dim %{{.*}}, %[[C1]] : tensor<1x?x64xf32>
48+ // CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<1x?x64xf32>
49+ // CHECK: %[[DIM1:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<1x?x64xf32>
4850// CHECK: tensor.empty(%[[DIM]]) : tensor<1x?x128xf32>
4951// CHECK: tensor.insert_slice %{{.*}}[0, 0, 0] [1, %[[DIM]], 64] [1, 1, 1] : tensor<1x?x64xf32> into tensor<1x?x128xf32>
50- // CHECK: %[[DIM1:.+]] = tensor.dim %{{.*}}, %[[C1]] : tensor<1x?x64xf32>
5152// CHECK: %[[CONCAT:.+]] = tensor.insert_slice %{{.*}}[0, 0, 64] [1, %[[DIM1]], 64] [1, 1, 1] : tensor<1x?x64xf32> into tensor<1x?x128xf32>
5253// CHECK: return %[[CONCAT]] : tensor<1x?x128xf32>
5354
@@ -58,19 +59,23 @@ func.func @decompose_dynamic_into_static_concat_dim(%arg0 : tensor<1x?x?xf32>,
5859 : (tensor <1 x?x?xf32 >, tensor <1 x?x?xf32 >) -> tensor <1 x?x128 xf32 >
5960 return %0 : tensor <1 x?x128 xf32 >
6061}
61- // CHECK-LABEL: func @decompose_dynamic_into_static_concat_dim
62+ // CHECK-LABEL: func @decompose_dynamic_into_static_concat_dim(
63+ // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<1x?x?xf32>,
64+ // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<1x?x?xf32>)
6265// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
6366// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
64- // CHECK: %[[T0_DIM1:.+]] = tensor.dim %{{.*}}, %[[C1]] : tensor<1x?x?xf32>
65- // CHECK: tensor.empty(%[[T0_DIM1]]) : tensor<1x?x128xf32>
66- // CHECK: %[[T0_DIM2:.+]] = tensor.dim %{{.*}}, %[[C2]] : tensor<1x?x?xf32>
67+ // CHECK: %[[T0_DIM1:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<1x?x?xf32>
68+ // CHECK: %[[T0_DIM2:.+]] = tensor.dim %[[ARG0]], %[[C2]] : tensor<1x?x?xf32>
69+ // CHECK: %[[T1_DIM1:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<1x?x?xf32>
70+ // CHECK: %[[T1_DIM2:.+]] = tensor.dim %[[ARG1]], %[[C2]] : tensor<1x?x?xf32>
71+ // CHECK: %[[CONCAT_DIM:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%[[T0_DIM2]], %[[T1_DIM2]]]
72+ // CHECK: tensor.empty(%[[T0_DIM1]], %[[CONCAT_DIM]]) : tensor<1x?x?xf32>
6773// CHECK: tensor.insert_slice %{{.*}}[0, 0, 0] [1, %[[T0_DIM1]], %[[T0_DIM2]]] [1, 1, 1]
68- // CHECK-SAME: tensor<1x?x?xf32> into tensor<1x?x128xf32>
69- // CHECK: %[[T1_DIM1:.+]] = tensor.dim %{{.*}}, %[[C1]] : tensor<1x?x?xf32>
70- // CHECK: %[[T1_DIM2:.+]] = tensor.dim %{{.*}}, %[[C2]] : tensor<1x?x?xf32>
74+ // CHECK-SAME: tensor<1x?x?xf32> into tensor<1x?x?xf32>
7175// CHECK: %[[CONCAT:.+]] = tensor.insert_slice %{{.*}}[0, 0, %[[T0_DIM2]]] [1, %[[T1_DIM1]], %[[T1_DIM2]]] [1, 1, 1]
72- // CHECK-SAME: tensor<1x?x?xf32> into tensor<1x?x128xf32>
73- // CHECK: return %[[CONCAT]] : tensor<1x?x128xf32>
76+ // CHECK-SAME: tensor<1x?x?xf32> into tensor<1x?x?xf32>
77+ // CHECK: %[[CAST:.+]] = tensor.cast %[[CONCAT]] : tensor<1x?x?xf32> to tensor<1x?x128xf32>
78+ // CHECK: return %[[CAST]] : tensor<1x?x128xf32>
7479
7580module attributes {transform.with_named_sequence } {
7681 transform.named_sequence @__transform_main (%root: !transform.any_op {transform.readonly }) {
0 commit comments