@@ -1048,6 +1048,134 @@ struct SplatOpInterface
   }
 };
 
+/// Bufferization of tensor.concat. Bufferizes to a new allocation that is
+/// filled with copy ops. Similar to tensor.from_elements, but using
+/// memref.copy on subviews instead of memref.store.
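+///
+/// Illustrative lowering sketch (SSA names and buffer types are schematic,
+/// not the exact output):
+///
+///   %r = tensor.concat dim(1) %a, %b
+///       : (tensor<2x3xf32>, tensor<2x5xf32>) -> tensor<2x8xf32>
+///
+/// becomes, roughly:
+///
+///   %alloc = memref.alloc() : memref<2x8xf32>
+///   %sv0 = memref.subview %alloc[0, 0] [2, 3] [1, 1]
+///   memref.copy %a_buf, %sv0
+///   %sv1 = memref.subview %alloc[0, 3] [2, 5] [1, 1]
+///   memref.copy %b_buf, %sv1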
+struct ConcatOpInterface
+    : public BufferizableOpInterface::ExternalModel<ConcatOpInterface,
+                                                    tensor::ConcatOp> {
+
+  bool bufferizesToAllocation(Operation *op, Value value) const { return true; }
+
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+                               const AnalysisState &state) const {
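+    // The operand buffers are only read; all writes go into the newly
+    // allocated buffer.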
+    return false;
+  }
+
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+                              const AnalysisState &state) const {
+    return true;
+  }
+
+  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
+                                      const AnalysisState &state) const {
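+    // The result lives in a fresh allocation, so it aliases none of the
+    // operands.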
+    return {};
+  }
+
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+                          const BufferizationOptions &options) const {
+    OpBuilder::InsertionGuard g(rewriter);
+    auto concatOp = cast<tensor::ConcatOp>(op);
+
+    // Allocate memory.
+    Location loc = op->getLoc();
+    FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
+        rewriter, loc, concatOp.getResult(), options,
+        /*copy=*/false);
+    if (failed(tensorAlloc))
+      return failure();
+    auto tensorType = cast<RankedTensorType>(tensorAlloc->getType());
+
+    // TODO: Implement memory space for this op.
+    if (options.defaultMemorySpaceFn(tensorType) != Attribute())
+      return op->emitError("memory space not implemented yet");
+
+    MemRefLayoutAttrInterface layout;
+    MemRefType memrefType =
+        MemRefType::get(concatOp.getResultType().getShape(),
+                        concatOp.getResultType().getElementType(), layout);
+    Value dstBuffer = rewriter.create<bufferization::ToBufferOp>(
+        op->getLoc(), memrefType, *tensorAlloc);
+
+    // Extract the dimension for the concat op.
+    uint64_t concatDim = concatOp.getDim();
+    bool dynamicConcatDim = false;
+
+    SmallVector<OpFoldResult> offsets(tensorType.getRank(),
+                                      rewriter.getIndexAttr(0));
+    SmallVector<OpFoldResult> strides(tensorType.getRank(),
+                                      rewriter.getIndexAttr(1));
+    SmallVector<OpFoldResult> sizes;
+
+    for (const auto &[dimIdx, dimSize] :
+         llvm::enumerate(tensorType.getShape())) {
+      if (dimSize == ShapedType::kDynamic) {
+        auto dimOp = rewriter.create<memref::DimOp>(loc, dstBuffer, dimIdx);
+        sizes.push_back(dimOp.getResult());
+        if (dimIdx == concatDim)
+          dynamicConcatDim = true;
+      } else {
+        sizes.push_back(rewriter.getIndexAttr(dimSize));
+      }
+    }
+
+    int64_t concatDimOffset = 0;
+    std::optional<Value> dynamicOffset;
+    std::optional<Value> dynamicSize;
+    if (dynamicConcatDim) {
+      // One or more operands have dynamic size, so we must accumulate the
+      // offset with arith ops.
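+      // E.g., the second operand starts where the first one ends: its offset
+      // is the runtime size of operand 0 along the concat dim, the third
+      // operand's offset is the sum of the first two sizes, and so on.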
+      dynamicOffset = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    }
+
+    for (auto operand : concatOp.getInputs()) {
+      // Get the buffer for the operand.
+      FailureOr<Value> srcBuffer = getBuffer(rewriter, operand, options);
+      if (failed(srcBuffer))
+        return failure();
+
+      // Each operand may have a different size along the concat dimension,
+      // so the offset on that axis must accumulate through the loop, and the
+      // size must change to the size of the current operand.
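+      // E.g., concatenating 2x3 and 2x5 tensors along dim 1 copies into
+      // subviews with offsets [0, 0] and [0, 3] and sizes [2, 3] and [2, 5].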
+      auto operandTensorType = cast<RankedTensorType>(operand.getType());
+      int64_t operandConcatDimSize = operandTensorType.getDimSize(concatDim);
+
+      if (dynamicConcatDim) {
+        offsets[concatDim] = dynamicOffset.value();
+        dynamicSize = rewriter.create<memref::DimOp>(loc, *srcBuffer, concatDim)
+                          .getResult();
+        sizes[concatDim] = dynamicSize.value();
+      } else {
+        sizes[concatDim] = rewriter.getIndexAttr(operandConcatDimSize);
+        offsets[concatDim] = rewriter.getIndexAttr(concatDimOffset);
+      }
+
+      // Create a subview of the destination buffer.
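+      // Note: every tensor.concat input has the same rank as the result, so
+      // no dimensions are actually rank-reduced away here; the helper is
+      // mainly used to infer the strided layout of the view.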
+      auto dstMemrefType = cast<MemRefType>(memrefType);
+      MemRefType subviewMemRefType =
+          memref::SubViewOp::inferRankReducedResultType(
+              operandTensorType.getShape(), dstMemrefType, offsets, sizes,
+              strides);
+      Value subview = rewriter.create<memref::SubViewOp>(
+          loc, subviewMemRefType, dstBuffer, offsets, sizes, strides);
+
+      // Copy the source buffer into the destination subview.
+      if (failed(options.createMemCpy(rewriter, loc, *srcBuffer, subview)))
+        return failure();
+
+      if (dynamicConcatDim) {
+        dynamicOffset = rewriter.create<arith::AddIOp>(
+            loc, dynamicOffset.value(), dynamicSize.value());
+      } else {
+        concatDimOffset += operandConcatDimSize;
+      }
+    }
+
+    replaceOpWithBufferizedValues(rewriter, op, dstBuffer);
+    return success();
+  }
+};
+
 } // namespace
 } // namespace tensor
 } // namespace mlir
@@ -1057,6 +1185,7 @@ void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
   registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
     CastOp::attachInterface<CastOpInterface>(*ctx);
     CollapseShapeOp::attachInterface<CollapseShapeOpInterface>(*ctx);
+    ConcatOp::attachInterface<ConcatOpInterface>(*ctx);
     DimOp::attachInterface<DimOpInterface>(*ctx);
     EmptyOp::attachInterface<EmptyOpInterface>(*ctx);
     ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);