
Commit 5e441a5

support dynamic tensors

1 parent 6ec0040 commit 5e441a5
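
This change extends the bufferization of `tensor.concat` (ConcatOpInterface in BufferizableOpInterfaceImpl.cpp) to operands with dynamic extents: dynamic sizes are now queried with `memref.dim`, and when the concatenation dimension itself is dynamic, the running write offset is accumulated with `arith.addi` instead of being folded to a static index attribute. A minimal example of the kind of op this enables bufferizing (it mirrors the new test below; the function name is illustrative):

    func.func @concat_dyn(%f: tensor<8x?xf32>, %g: tensor<8x?xf32>) -> tensor<8x?xf32> {
      // Concatenate along dimension 1, whose extent is unknown at compile time.
      %t = tensor.concat dim(1) %f, %g
          : (tensor<8x?xf32>, tensor<8x?xf32>) -> tensor<8x?xf32>
      return %t : tensor<8x?xf32>
    }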

File tree

2 files changed (+63 -6 lines)

mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/test/Dialect/Tensor/bufferize.mlir


mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
Lines changed: 37 additions & 6 deletions

@@ -1099,17 +1099,35 @@ struct ConcatOpInterface
 
     // Extract the dimension for the concat op
     uint64_t concatDim = concatOp.getDim();
+    bool dynamicConcatDim = false;
 
     SmallVector<OpFoldResult> offsets(tensorType.getRank(),
                                       rewriter.getIndexAttr(0));
     SmallVector<OpFoldResult> strides(tensorType.getRank(),
                                       rewriter.getIndexAttr(1));
     SmallVector<OpFoldResult> sizes;
-    for (auto dimSize : tensorType.getShape()) {
-      sizes.push_back(rewriter.getIndexAttr(dimSize));
+
+    for (const auto &[dimIdx, dimSize] :
+         llvm::enumerate(tensorType.getShape())) {
+      if (dimSize == ShapedType::kDynamic) {
+        auto dimOp = rewriter.create<memref::DimOp>(loc, dstBuffer, dimIdx);
+        sizes.push_back(dimOp.getResult());
+        if (dimIdx == concatDim)
+          dynamicConcatDim = true;
+      } else {
+        sizes.push_back(rewriter.getIndexAttr(dimSize));
+      }
     }
 
     int64_t concatDimOffset = 0;
+    std::optional<Value> dynamicOffset;
+    std::optional<Value> dynamicSize;
+    if (dynamicConcatDim) {
+      // One or more operands have dynamic size, so we must accumulate the
+      // offset with arith ops.
+      dynamicOffset = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    }
+
     for (auto operand : concatOp.getInputs()) {
       // Get the buffer for the operand.
       FailureOr<Value> srcBuffer = getBuffer(rewriter, operand, options);
@@ -1120,9 +1138,17 @@ struct ConcatOpInterface
       // so the offset on that axis must accumulate through the loop, and the
       // size must change to the size of the current operand.
       auto operandTensorType = cast<RankedTensorType>(operand.getType());
-      int operandConcatDimSize = operandTensorType.getDimSize(concatDim);
-      sizes[concatDim] = rewriter.getIndexAttr(operandConcatDimSize);
-      offsets[concatDim] = rewriter.getIndexAttr(concatDimOffset);
+      int64_t operandConcatDimSize = operandTensorType.getDimSize(concatDim);
+
+      if (dynamicConcatDim) {
+        offsets[concatDim] = dynamicOffset.value();
+        dynamicSize = rewriter.create<memref::DimOp>(loc, *srcBuffer, concatDim)
+                          .getResult();
+        sizes[concatDim] = dynamicSize.value();
+      } else {
+        sizes[concatDim] = rewriter.getIndexAttr(operandConcatDimSize);
+        offsets[concatDim] = rewriter.getIndexAttr(concatDimOffset);
+      }
 
       // Create a subview of the destination buffer.
       auto dstMemrefType = cast<MemRefType>(memrefType);
@@ -1137,7 +1163,12 @@ struct ConcatOpInterface
       if (failed(options.createMemCpy(rewriter, loc, *srcBuffer, subview)))
         return failure();
 
-      concatDimOffset += operandConcatDimSize;
+      if (dynamicConcatDim) {
+        dynamicOffset = rewriter.create<arith::AddIOp>(
+            loc, dynamicOffset.value(), dynamicSize.value());
+      } else {
+        concatDimOffset += operandConcatDimSize;
+      }
     }
 
     replaceOpWithBufferizedValues(rewriter, op, dstBuffer);
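
For reference, a hand-written sketch of the buffer-level IR this new path produces for a two-operand concat along a dynamic dim(1). This is a sketch under assumptions, not pass output: SSA names are illustrative, `%alloc` stands for the destination buffer allocated earlier, and the computation of its dynamic size is elided:

    %c0 = arith.constant 0 : index                // running offset (dynamicOffset)
    %c1 = arith.constant 1 : index
    %df = memref.dim %f, %c1 : memref<8x?xf32>    // dynamic size of first operand
    %sv0 = memref.subview %alloc[0, %c0] [8, %df] [1, 1]
        : memref<8x?xf32> to memref<8x?xf32, strided<[?, 1], offset: ?>>
    memref.copy %f, %sv0
        : memref<8x?xf32> to memref<8x?xf32, strided<[?, 1], offset: ?>>
    %off = arith.addi %c0, %df : index            // advance past the first operand
    %dg = memref.dim %g, %c1 : memref<8x?xf32>    // dynamic size of second operand
    %sv1 = memref.subview %alloc[0, %off] [8, %dg] [1, 1]
        : memref<8x?xf32> to memref<8x?xf32, strided<[?, 1], offset: ?>>
    memref.copy %g, %sv1
        : memref<8x?xf32> to memref<8x?xf32, strided<[?, 1], offset: ?>>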

mlir/test/Dialect/Tensor/bufferize.mlir
Lines changed: 26 additions & 0 deletions

@@ -653,6 +653,32 @@ func.func @tensor.concat_different_shapes(%f: tensor<8x4xf32>, %g: tensor<8x5xf3
 
 // -----
 
+// CHECK-LABEL: func @tensor.concat_dynamic(
+// CHECK-SAME: %[[F:.*]]: tensor<8x?xf32>,
+// CHECK-SAME: %[[G:.*]]: tensor<8x?xf32>
+// CHECK-DAG: %[[F_MEMREF:.*]] = bufferization.to_memref %[[F]]
+// CHECK-DAG: %[[G_MEMREF:.*]] = bufferization.to_memref %[[G]]
+// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[F_DIM:.*]] = memref.dim %[[F_MEMREF]], %[[c1]]
+// CHECK-DAG: %[[G_DIM:.*]] = memref.dim %[[G_MEMREF]], %[[c1]]
+// CHECK: %[[ALLOC:.*]] = memref.alloc
+// CHECK-SAME: memref<8x?xf32>
+// CHECK-DAG: %[[OFFSET:.*]] = arith.constant 0 : index
+// CHECK: %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0, %[[OFFSET]]] [8, %[[F_DIM]]] [1, 1]
+// CHECK: memref.copy %[[F_MEMREF]], %[[SUBVIEW1]]
+// CHECK: %[[OFFSET_2:.*]] = arith.addi %[[OFFSET]], %[[F_DIM]] : index
+// CHECK: %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][0, %[[OFFSET_2]]] [8, %[[G_DIM]]] [1, 1]
+// CHECK: memref.copy %[[G_MEMREF]], %[[SUBVIEW2]]
+// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[ALLOC]]
+// CHECK: return %[[RET]]
+// CHECK: }
+func.func @tensor.concat_dynamic(%f: tensor<8x?xf32>, %g: tensor<8x?xf32>) -> tensor<8x?xf32> {
+  %t = tensor.concat dim(1) %f, %g : (tensor<8x?xf32>, tensor<8x?xf32>) -> tensor<8x?xf32>
+  return %t : tensor<8x?xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @tensor.splat_dynamic(
 // CHECK-SAME: %[[F:[a-zA-Z0-9_]+]]: f32
 // CHECK-SAME: %[[M:[a-zA-Z0-9_]+]]: index
