
Commit 039f883

[mlir][tensor] Fix bug in ConcatOpInterface (#168676)
This PR fixes an issue in `ConcatOpInterface` where bufferization of `tensor.concat` fails when the concat dimension is dynamic but the result type is static. The fix unifies the offset computation by using `OpFoldResult`, avoiding the need to handle dynamic and static dimension values separately. Fixes #162776.
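For reference, the failing pattern is distilled from the regression test added in this commit (the function name here is illustrative): the inputs are dynamic along the concat dimension while the concat result type is fully static.

```mlir
// Previously failed to bufferize: dim 1 is dynamic on %f and %g, but the
// result type of the concat is the static tensor<8x10xf32>.
func.func @repro(%f: tensor<8x?xf32>, %g: tensor<8x?xf32>,
                 %h: tensor<8x2xf32>) -> tensor<8x10xf32> {
  %0 = tensor.concat dim(1) %f, %g, %h
      : (tensor<8x?xf32>, tensor<8x?xf32>, tensor<8x2xf32>) -> tensor<8x10xf32>
  return %0 : tensor<8x10xf32>
}
```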

File tree: 2 files changed, +48 −48 lines

mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 15 additions & 41 deletions
```diff
@@ -1132,35 +1132,22 @@ struct ConcatOpInterface
 
     // Extract the dimension for the concat op
     uint64_t concatDim = concatOp.getDim();
-    bool dynamicConcatDim = false;
 
     SmallVector<OpFoldResult> offsets(tensorType.getRank(),
                                       rewriter.getIndexAttr(0));
     SmallVector<OpFoldResult> strides(tensorType.getRank(),
                                       rewriter.getIndexAttr(1));
-    SmallVector<OpFoldResult> sizes;
-
-    for (const auto &[dimIdx, dimSize] :
-         llvm::enumerate(tensorType.getShape())) {
-      if (dimSize == ShapedType::kDynamic) {
-        auto dimOp = memref::DimOp::create(rewriter, loc, dstBuffer, dimIdx);
-        sizes.push_back(dimOp.getResult());
-        if (dimIdx == concatDim)
-          dynamicConcatDim = true;
-      } else {
-        sizes.push_back(rewriter.getIndexAttr(dimSize));
-      }
-    }
-
-    int64_t concatDimOffset = 0;
-    std::optional<Value> dynamicOffset;
-    std::optional<Value> dynamicSize;
-    if (dynamicConcatDim) {
-      // One or more operands have dynamic size, so we must accumulate the
-      // offset with arith ops.
-      dynamicOffset = arith::ConstantIndexOp::create(rewriter, loc, 0);
-    }
+    SmallVector<OpFoldResult> sizes =
+        memref::getMixedSizes(rewriter, loc, dstBuffer);
+
+    AffineExpr s0, s1;
+    bindSymbols(rewriter.getContext(), s0, s1);
+    auto sum = [&](OpFoldResult v1, OpFoldResult v2) {
+      return affine::makeComposedFoldedAffineApply(rewriter, loc, s0 + s1,
+                                                   {v1, v2});
+    };
 
+    OpFoldResult concatDimOffset = rewriter.getIndexAttr(0);
     for (auto operand : concatOp.getInputs()) {
       // Get the buffer for the operand.
       FailureOr<Value> srcBuffer = getBuffer(rewriter, operand, options, state);
@@ -1171,18 +1158,10 @@ struct ConcatOpInterface
       // so the offset on that axis must accumulate through the loop, and the
       // size must change to the size of the current operand.
       auto operandTensorType = cast<RankedTensorType>(operand.getType());
-      int64_t operandConcatDimSize = operandTensorType.getDimSize(concatDim);
-
-      if (dynamicConcatDim) {
-        offsets[concatDim] = dynamicOffset.value();
-        dynamicSize =
-            memref::DimOp::create(rewriter, loc, *srcBuffer, concatDim)
-                .getResult();
-        sizes[concatDim] = dynamicSize.value();
-      } else {
-        sizes[concatDim] = rewriter.getIndexAttr(operandConcatDimSize);
-        offsets[concatDim] = rewriter.getIndexAttr(concatDimOffset);
-      }
+      offsets[concatDim] = concatDimOffset;
+      OpFoldResult concatDimSize =
+          memref::getMixedSize(rewriter, loc, *srcBuffer, concatDim);
+      sizes[concatDim] = concatDimSize;
 
       // Create a subview of the destination buffer.
       auto dstMemrefType = cast<MemRefType>(memrefType);
@@ -1197,12 +1176,7 @@ struct ConcatOpInterface
       if (failed(options.createMemCpy(rewriter, loc, *srcBuffer, subview)))
         return failure();
 
-      if (dynamicConcatDim) {
-        dynamicOffset = arith::AddIOp::create(
-            rewriter, loc, dynamicOffset.value(), dynamicSize.value());
-      } else {
-        concatDimOffset += operandConcatDimSize;
-      }
+      concatDimOffset = sum(concatDimOffset, concatDimSize);
     }
 
     replaceOpWithBufferizedValues(rewriter, op, dstBuffer);
```
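The `sum` lambda is the crux of the unification: `affine::makeComposedFoldedAffineApply` returns an `OpFoldResult`, so adding two constant index attributes folds to an attribute without creating any IR, `0 + s0` composes away to the dynamic size value itself, and an `affine.apply` is materialized only once two dynamic sizes actually need to be added. Fully static concats therefore still get constant subview offsets, with no stray `arith` ops. A sketch of the bufferized output for the mixed static/dynamic case (value names illustrative; it mirrors the new CHECK lines in the test update below):

```mlir
#sum = affine_map<()[s0, s1] -> (s0 + s1)>

func.func @bufferized_sketch(%alloc: memref<8x10xf32>, %f: memref<8x?xf32>,
                             %g: memref<8x?xf32>, %h: memref<8x2xf32>) {
  %c1 = arith.constant 1 : index
  // First offset is the folded attribute 0: no arith.constant is emitted.
  %f_dim = memref.dim %f, %c1 : memref<8x?xf32>
  %sv1 = memref.subview %alloc[0, 0] [8, %f_dim] [1, 1]
      : memref<8x10xf32> to memref<8x?xf32, strided<[10, 1]>>
  memref.copy %f, %sv1 : memref<8x?xf32> to memref<8x?xf32, strided<[10, 1]>>
  // Second offset is %f_dim itself: 0 + s0 composes and folds to s0.
  %g_dim = memref.dim %g, %c1 : memref<8x?xf32>
  %sv2 = memref.subview %alloc[0, %f_dim] [8, %g_dim] [1, 1]
      : memref<8x10xf32> to memref<8x?xf32, strided<[10, 1], offset: ?>>
  memref.copy %g, %sv2
      : memref<8x?xf32> to memref<8x?xf32, strided<[10, 1], offset: ?>>
  // Third offset needs a real affine.apply over the two dynamic sizes.
  %off = affine.apply #sum()[%f_dim, %g_dim]
  %sv3 = memref.subview %alloc[0, %off] [8, 2] [1, 1]
      : memref<8x10xf32> to memref<8x2xf32, strided<[10, 1], offset: ?>>
  memref.copy %h, %sv3
      : memref<8x2xf32> to memref<8x2xf32, strided<[10, 1], offset: ?>>
  return
}
```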

mlir/test/Dialect/Tensor/bufferize.mlir

Lines changed: 33 additions & 7 deletions
```diff
@@ -678,11 +678,9 @@ func.func @tensor.concat_different_shapes(%f: tensor<8x4xf32>, %g: tensor<8x5xf3
 // CHECK-DAG: %[[G_DIM:.*]] = memref.dim %[[G_MEMREF]], %[[c1]]
 // CHECK: %[[ALLOC:.*]] = memref.alloc
 // CHECK-SAME: memref<8x?xf32>
-// CHECK-DAG: %[[OFFSET:.*]] = arith.constant 0 : index
-// CHECK: %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0, %[[OFFSET]]] [8, %[[F_DIM]]] [1, 1]
+// CHECK: %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0, 0] [8, %[[F_DIM]]] [1, 1]
 // CHECK: memref.copy %[[F_MEMREF]], %[[SUBVIEW1]]
-// CHECK: %[[OFFSET_2:.*]] = arith.addi %[[OFFSET]], %[[F_DIM]] : index
-// CHECK: %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][0, %[[OFFSET_2]]] [8, %[[G_DIM]]] [1, 1]
+// CHECK: %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][0, %[[F_DIM]]] [8, %[[G_DIM]]] [1, 1]
 // CHECK: memref.copy %[[G_MEMREF]], %[[SUBVIEW2]]
 // CHECK: %[[RET:.*]] = bufferization.to_tensor %[[ALLOC]]
 // CHECK: return %[[RET]]
@@ -706,10 +704,9 @@ func.func @tensor.concat_dynamic(%f: tensor<8x?xf32>, %g: tensor<8x?xf32>) -> te
 // CHECK: %[[ALLOC:.*]] = memref.alloc
 // CHECK-SAME: memref<?x?xf32>
 // CHECK-DAG: %[[NON_CONCAT_DIM:.*]] = memref.dim %[[ALLOC]], %[[c0]]
-// CHECK: %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0, %[[c0]]] [%[[NON_CONCAT_DIM]], %[[F_DIM]]] [1, 1]
+// CHECK: %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0, 0] [%[[NON_CONCAT_DIM]], %[[F_DIM]]] [1, 1]
 // CHECK: memref.copy %[[F_MEMREF]], %[[SUBVIEW1]]
-// CHECK: %[[OFFSET_2:.*]] = arith.addi %[[c0]], %[[F_DIM]] : index
-// CHECK: %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][0, %[[OFFSET_2]]] [%[[NON_CONCAT_DIM]], %[[G_DIM]]] [1, 1]
+// CHECK: %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][0, %[[F_DIM]]] [%[[NON_CONCAT_DIM]], %[[G_DIM]]] [1, 1]
 // CHECK: memref.copy %[[G_MEMREF]], %[[SUBVIEW2]]
 // CHECK: %[[RET:.*]] = bufferization.to_tensor %[[ALLOC]]
 // CHECK: return %[[RET]]
@@ -721,6 +718,35 @@ func.func @tensor.concat_dynamic_nonconcat_dim(%f: tensor<?x?xf32>, %g: tensor<?
 
 // -----
 
+// CHECK: #[[$sum_map:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
+
+// CHECK-LABEL: func @tensor.concat_mixed_dynamic_static(
+// CHECK-SAME: %[[F:.*]]: tensor<8x?xf32>, %[[G:.*]]: tensor<8x?xf32>,
+// CHECK-SAME: %[[H:.*]]: tensor<8x2xf32>)
+// CHECK-DAG: %[[F_MEMREF:.*]] = bufferization.to_buffer %[[F]]
+// CHECK-DAG: %[[G_MEMREF:.*]] = bufferization.to_buffer %[[G]]
+// CHECK-DAG: %[[H_MEMREF:.*]] = bufferization.to_buffer %[[H]]
+// CHECK-DAG: %[[ALLOC:.*]] = memref.alloc() {alignment = 64 : i64} : memref<8x10xf32>
+// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
+// CHECK: %[[F_DIM:.*]] = memref.dim %[[F_MEMREF]], %[[c1]]
+// CHECK: %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0, 0] [8, %[[F_DIM]]] [1, 1]
+// CHECK: memref.copy %[[F_MEMREF]], %[[SUBVIEW1]]
+// CHECK: %[[G_DIM:.*]] = memref.dim %[[G_MEMREF]], %[[c1]]
+// CHECK: %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][0, %[[F_DIM]]] [8, %[[G_DIM]]] [1, 1]
+// CHECK: memref.copy %[[G_MEMREF]], %[[SUBVIEW2]]
+// CHECK: %[[OFFSET:.*]] = affine.apply #[[$sum_map]]()[%[[F_DIM]], %[[G_DIM]]]
+// CHECK: %[[SUBVIEW3:.*]] = memref.subview %[[ALLOC]][0, %[[OFFSET]]] [8, 2] [1, 1]
+// CHECK: memref.copy %[[H_MEMREF]], %[[SUBVIEW3]]
+// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[ALLOC]]
+// CHECK: return %[[RET]]
+// CHECK: }
+func.func @tensor.concat_mixed_dynamic_static(%f: tensor<8x?xf32>, %g: tensor<8x?xf32>, %h: tensor<8x2xf32>) -> tensor<8x10xf32> {
+  %0 = tensor.concat dim(1) %f, %g, %h : (tensor<8x?xf32>, tensor<8x?xf32>, tensor<8x2xf32>) -> tensor<8x10xf32>
+  return %0 : tensor<8x10xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @tensor.splat_dynamic(
 // CHECK-SAME: %[[F:[a-zA-Z0-9_]+]]: f32
 // CHECK-SAME: %[[M:[a-zA-Z0-9_]+]]: index
```
