@@ -30,8 +30,13 @@ func.func @simple_pad_and_pack(%input: tensor<5x1xf32>, %output: tensor<1x1x8x2x
3030// CHECK-ALL-SAME: %[[PAD_VAL:[A-Za-z0-9]+]]:
3131// CHECK-ALL: %[[PAD:.+]] = tensor.pad %[[IN]] low[0, 0] high[3, 1]
3232// CHECK-ALL: tensor.yield %[[PAD_VAL]]
33- // CHECK-ALL: %[[INSERT:.+]] = tensor.insert_slice %[[PAD]] into %[[OUT]][0, 0, 0, 0] [1, 1, 8, 2] [1, 1, 1, 1]
34- // CHECK-ALL: return %[[INSERT]]
33+
34+ // CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[PAD]] into %[[OUT]][0, 0, 0, 0] [1, 1, 8, 2] [1, 1, 1, 1]
35+ // CHECK: return %[[INSERT]]
36+
37+ // CHECK-RESHAPE: %[[EXPAND:.+]] = tensor.expand_shape %[[PAD]] {{\[}}[0, 1], [2, 3]] output_shape [1, 8, 1, 2] : tensor<8x2xf32> into tensor<1x8x1x2xf32>
38+ // CHECK-RESHAPE: %[[TRANS:.+]] = linalg.transpose ins(%[[EXPAND]] : tensor<1x8x1x2xf32>) outs(%[[OUT]] : tensor<1x1x8x2xf32>) permutation = [0, 2, 1, 3]
39+ // CHECK-RESHAPE: return %[[TRANS]]
3540
3641// -----
3742
@@ -42,8 +47,13 @@ func.func @simple_NC_to_CNnc(%arg0: tensor<32x8xf32>, %arg1: tensor<1x1x32x8xf32
4247// CHECK-ALL-LABEL: func.func @simple_NC_to_CNnc
4348// CHECK-ALL-SAME: %[[IN:[A-Za-z0-9]+]]:
4449// CHECK-ALL-SAME: %[[OUT:[A-Za-z0-9]+]]:
45- // CHECK-ALL: %[[INSERT:.+]] = tensor.insert_slice %[[IN]] into %[[OUT]][0, 0, 0, 0] [1, 1, 32, 8] [1, 1, 1, 1]
46- // CHECK-ALL: return %[[INSERT]]
50+
51+ // CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[IN]] into %[[OUT]][0, 0, 0, 0] [1, 1, 32, 8] [1, 1, 1, 1]
52+ // CHECK: return %[[INSERT]]
53+
54+ // CHECK-RESHAPE: %[[EXPAND:.+]] = tensor.expand_shape %[[IN]] {{\[}}[0, 1], [2, 3]] output_shape [1, 32, 1, 8] : tensor<32x8xf32> into tensor<1x32x1x8xf32>
55+ // CHECK-RESHAPE: %[[TRANS:.+]] = linalg.transpose ins(%[[EXPAND]] : tensor<1x32x1x8xf32>) outs(%[[OUT]] : tensor<1x1x32x8xf32>) permutation = [2, 0, 1, 3]
56+ // CHECK-RESHAPE: return %[[TRANS]]
4757
4858// -----
4959
@@ -132,8 +142,11 @@ func.func @simple_unpack_and_extract_slice(%input: tensor<1x1x8x2xf32>, %output:
132142// CHECK: %[[TILE:.+]] = tensor.extract_slice %[[IN]][0, 0, 0, 0] [1, 1, 8, 2] [1, 1, 1, 1]
133143// CHECK: %[[RES:.+]] = tensor.extract_slice %[[TILE]][0, 0] [5, 1] [1, 1]
134144
135- // CHECK-RESHAPE: %[[RES:.+]] = tensor.extract_slice %[[IN]][0, 0, 0, 0] [1, 1, 5, 1] [1, 1, 1, 1]
136-
145+ // CHECK-RESHAPE: %[[EMPTY:.+]] = tensor.empty() : tensor<1x8x1x2xf32>
146+ // CHECK-RESHAPE: %[[TRANS:.+]] = linalg.transpose ins(%[[IN]] : tensor<1x1x8x2xf32>) outs(%[[EMPTY]] : tensor<1x8x1x2xf32>) permutation = [0, 2, 1, 3]
147+ // CHECK-RESHAPE: %[[COLLAPSE:.+]] = tensor.collapse_shape
148+ // CHECK-RESHAPE: %[[SLICE:.+]] = tensor.extract_slice %[[COLLAPSE]]
149+ // CHECK-RESHAPE: %[[RES:.+]] = linalg.copy ins(%[[SLICE]] : tensor<5x1xf32>) outs(%[[OUT]] : tensor<5x1xf32>) -> tensor<5x1xf32>
137150// CHECK-ALL: return %[[RES:.+]]
138151
139152// -----
@@ -145,8 +158,13 @@ func.func @simple_CNnc_to_NC(%arg0: tensor<1x1x32x8xf32>, %arg1: tensor<32x8xf32
145158// CHECK-ALL-LABEL: func.func @simple_CNnc_to_NC
146159// CHECK-ALL-SAME: %[[IN:[A-Za-z0-9]+]]:
147160// CHECK-ALL-SAME: %[[OUT:[A-Za-z0-9]+]]:
148- // CHECK-ALL: %[[TILE:.+]] = tensor.extract_slice %[[IN]][0, 0, 0, 0] [1, 1, 32, 8] [1, 1, 1, 1]
149- // CHECK-ALL: return %[[TILE]]
161+ // CHECK: %[[RESULT:.+]] = tensor.extract_slice %[[IN]][0, 0, 0, 0] [1, 1, 32, 8] [1, 1, 1, 1]
162+
163+ // CHECK-RESHAPE: %[[EMPTY:.+]] = tensor.empty() : tensor<1x32x1x8xf32>
164+ // CHECK-RESHAPE: %[[TRANS:.+]] = linalg.transpose ins(%[[IN]] : tensor<1x1x32x8xf32>) outs(%[[EMPTY]] : tensor<1x32x1x8xf32>) permutation = [1, 2, 0, 3]
165+ // CHECK-RESHAPE: %[[COLLAPSE:.+]] = tensor.collapse_shape
166+ // CHECK-RESHAPE: %[[RESULT:.+]] = linalg.copy ins(%[[COLLAPSE]] : tensor<32x8xf32>) outs(%[[OUT]] : tensor<32x8xf32>) -> tensor<32x8xf32>
167+ // CHECK-ALL: return %[[RESULT]]
150168
151169// -----
152170
0 commit comments