Skip to content

Commit 4ef7b46

Browse files
committed
fix bug with non-concatted dynamic dim
1 parent 5e441a5 commit 4ef7b46

File tree

2 files changed

+28
-1
lines changed

2 files changed

+28
-1
lines changed

mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1110,7 +1110,7 @@ struct ConcatOpInterface
11101110
for (const auto &[dimIdx, dimSize] :
11111111
llvm::enumerate(tensorType.getShape())) {
11121112
if (dimSize == ShapedType::kDynamic) {
1113-
auto dimOp = rewriter.create<memref::DimOp>(loc, dstBuffer, dimSize);
1113+
auto dimOp = rewriter.create<memref::DimOp>(loc, dstBuffer, dimIdx);
11141114
sizes.push_back(dimOp.getResult());
11151115
if (dimIdx == concatDim)
11161116
dynamicConcatDim = true;

mlir/test/Dialect/Tensor/bufferize.mlir

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -679,6 +679,33 @@ func.func @tensor.concat_dynamic(%f: tensor<8x?xf32>, %g: tensor<8x?xf32>) -> te
679679

680680
// -----
681681

682+
// CHECK-LABEL: func @tensor.concat_dynamic_nonconcat_dim(
683+
// CHECK-SAME: %[[F:.*]]: tensor<?x?xf32>,
684+
// CHECK-SAME: %[[G:.*]]: tensor<?x?xf32>
685+
// CHECK-DAG: %[[F_MEMREF:.*]] = bufferization.to_memref %[[F]]
686+
// CHECK-DAG: %[[G_MEMREF:.*]] = bufferization.to_memref %[[G]]
687+
// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
688+
// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
689+
// CHECK-DAG: %[[F_DIM:.*]] = memref.dim %[[F_MEMREF]], %[[c1]]
690+
// CHECK-DAG: %[[G_DIM:.*]] = memref.dim %[[G_MEMREF]], %[[c1]]
691+
// CHECK: %[[ALLOC:.*]] = memref.alloc
692+
// CHECK-SAME: memref<?x?xf32>
693+
// CHECK-DAG: %[[NON_CONCAT_DIM:.*]] = memref.dim %[[ALLOC]], %[[c0]]
694+
// CHECK: %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0, %[[c0]]] [%[[NON_CONCAT_DIM]], %[[F_DIM]]] [1, 1]
695+
// CHECK: memref.copy %[[F_MEMREF]], %[[SUBVIEW1]]
696+
// CHECK: %[[OFFSET_2:.*]] = arith.addi %[[c0]], %[[F_DIM]] : index
697+
// CHECK: %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][0, %[[OFFSET_2]]] [%[[NON_CONCAT_DIM]], %[[G_DIM]]] [1, 1]
698+
// CHECK: memref.copy %[[G_MEMREF]], %[[SUBVIEW2]]
699+
// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[ALLOC]]
700+
// CHECK: return %[[RET]]
701+
// CHECK: }
702+
func.func @tensor.concat_dynamic_nonconcat_dim(%f: tensor<?x?xf32>, %g: tensor<?x?xf32>) -> tensor<?x?xf32> {
703+
%t = tensor.concat dim(1) %f, %g : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
704+
return %t : tensor<?x?xf32>
705+
}
706+
707+
// -----
708+
682709
// CHECK-LABEL: func @tensor.splat_dynamic(
683710
// CHECK-SAME: %[[F:[a-zA-Z0-9_]+]]: f32
684711
// CHECK-SAME: %[[M:[a-zA-Z0-9_]+]]: index

0 commit comments

Comments
 (0)