Skip to content

Commit 005c833

Browse files
authored
[mlir][tensor] Fix ReifyResultShapes implementation for tensor.concat (#74157)
Without folding the result of the initial tensor.dim, the ReifyResultShapes implementation would be incorrect because it would return a dynamic shape for a static result shape.
1 parent 357b8b4 commit 005c833

File tree

2 files changed

+27
-23
lines changed

2 files changed

+27
-23
lines changed

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -605,7 +605,7 @@ ConcatOp::reifyResultShapes(OpBuilder &builder,
605605
// Take the sum of the input sizes along the concatenated dim.
606606
AffineExpr sum = builder.getAffineDimExpr(0);
607607
SmallVector<OpFoldResult> sizes = {
608-
builder.create<tensor::DimOp>(init.getLoc(), init, 0).getResult()};
608+
builder.createOrFold<tensor::DimOp>(init.getLoc(), init, dim)};
609609
for (auto [idx, input] : llvm::enumerate(inputs.drop_front())) {
610610
sum = sum + builder.getAffineDimExpr(idx + 1);
611611
sizes.push_back(
Lines changed: 26 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,10 @@
11
// RUN: mlir-opt -split-input-file -transform-interpreter -cse %s | FileCheck %s
22

3-
module attributes {transform.with_named_sequence} {
4-
transform.named_sequence @__transform_main(%func_op: !transform.op<"func.func"> {transform.readonly}) {
5-
transform.apply_patterns to %func_op {
6-
transform.apply_patterns.tensor.decompose_concat
7-
} : !transform.op<"func.func">
8-
transform.yield
9-
}
10-
}
11-
123
func.func @decompose_dynamic_concat(%arg0 : tensor<8x4xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
134
%0 = tensor.concat dim(1) %arg0, %arg1 : (tensor<8x4xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
145
return %0 : tensor<?x?xf32>
156
}
16-
// CHECK-DAG: #[[$MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
7+
// CHECK-DAG: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 + 4)>
178
// CHECK-LABEL: func @decompose_dynamic_concat(
189
// CHECK-SAME: %[[ARG0:.+]]: tensor<8x4xf32>
1910
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
@@ -22,24 +13,13 @@ func.func @decompose_dynamic_concat(%arg0 : tensor<8x4xf32>, %arg1 : tensor<?x?x
2213
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
2314
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
2415
// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
25-
// CHECK: %[[CONCAT_SIZE:.+]] = affine.apply #[[$MAP]]()[%[[C8]], %[[DIM]]]
16+
// CHECK: %[[CONCAT_SIZE:.+]] = affine.apply #[[$MAP]]()[%[[DIM]]]
2617
// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[C8]], %[[CONCAT_SIZE]]) : tensor<?x?xf32>
2718
// CHECK: %[[SLICE0:.+]] = tensor.insert_slice %[[ARG0]] into %[[EMPTY]][0, 0] [8, 4] [1, 1] : tensor<8x4xf32> into tensor<?x?xf32>
2819
// CHECK: %[[OFFSET:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
2920
// CHECK: %[[CONCAT:.+]] = tensor.insert_slice %[[ARG1]] into %[[SLICE0]][0, 4] [%[[OFFSET]], %[[DIM]]] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
3021
// CHECK: return %[[CONCAT]] : tensor<?x?xf32>
3122

32-
// -----
33-
34-
module attributes {transform.with_named_sequence} {
35-
transform.named_sequence @__transform_main(%func_op: !transform.op<"func.func"> {transform.readonly}) {
36-
transform.apply_patterns to %func_op {
37-
transform.apply_patterns.tensor.decompose_concat
38-
} : !transform.op<"func.func">
39-
transform.yield
40-
}
41-
}
42-
4323
func.func @decompose_1d_concat(%arg0 : tensor<1xf32>,
4424
%arg1 : tensor<2xf32>,
4525
%arg2 : tensor<3xf32>,
@@ -55,3 +35,27 @@ func.func @decompose_1d_concat(%arg0 : tensor<1xf32>,
5535
// CHECK: tensor.insert_slice %{{.*}}[3] [3] [1] : tensor<3xf32> into tensor<10xf32>
5636
// CHECK: %[[CONCAT:.+]] = tensor.insert_slice %{{.*}}[6] [4] [1] : tensor<4xf32> into tensor<10xf32>
5737
// CHECK: return %[[CONCAT]] : tensor<10xf32>
38+
39+
func.func @decompose_static_concat_dim(%arg0 : tensor<1x?x64xf32>,
40+
%arg1: tensor<1x?x64xf32>) -> tensor<1x?x128xf32> {
41+
%0 = tensor.concat dim(2) %arg0, %arg1
42+
: (tensor<1x?x64xf32>, tensor<1x?x64xf32>) -> tensor<1x?x128xf32>
43+
return %0 : tensor<1x?x128xf32>
44+
}
45+
// CHECK-LABEL: func @decompose_static_concat_dim
46+
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
47+
// CHECK: %[[DIM:.+]] = tensor.dim %{{.*}}, %[[C1]] : tensor<1x?x64xf32>
48+
// CHECK: tensor.empty(%[[DIM]]) : tensor<1x?x128xf32>
49+
// CHECK: tensor.insert_slice %{{.*}}[0, 0, 0] [1, %[[DIM]], 64] [1, 1, 1] : tensor<1x?x64xf32> into tensor<1x?x128xf32>
50+
// CHECK: %[[DIM1:.+]] = tensor.dim %{{.*}}, %[[C1]] : tensor<1x?x64xf32>
51+
// CHECK: %[[CONCAT:.+]] = tensor.insert_slice %{{.*}}[0, 0, 64] [1, %[[DIM1]], 64] [1, 1, 1] : tensor<1x?x64xf32> into tensor<1x?x128xf32>
52+
// CHECK: return %[[CONCAT]] : tensor<1x?x128xf32>
53+
54+
module attributes {transform.with_named_sequence} {
55+
transform.named_sequence @__transform_main(%func_op: !transform.op<"func.func"> {transform.readonly}) {
56+
transform.apply_patterns to %func_op {
57+
transform.apply_patterns.tensor.decompose_concat
58+
} : !transform.op<"func.func">
59+
transform.yield
60+
}
61+
}

0 commit comments

Comments (0)