@@ -1048,6 +1048,103 @@ struct SplatOpInterface
10481048 }
10491049};
10501050
1051+ // / Bufferization of tensor.concat. Bufferizes to a new allocation that is
1052+ // / filled with copy ops. Similar to tensor.from_elements, but using memref.copy
1053+ // / on subviews instead of memref.store.
1054+ struct ConcatOpInterface
1055+ : public BufferizableOpInterface::ExternalModel<ConcatOpInterface,
1056+ tensor::ConcatOp> {
1057+
1058+ bool bufferizesToAllocation (Operation *op, Value value) const { return true ; }
1059+
1060+ bool bufferizesToMemoryWrite (Operation *op, OpOperand &opOperand,
1061+ const AnalysisState &state) const {
1062+ return true ;
1063+ }
1064+
1065+ bool bufferizesToMemoryRead (Operation *op, OpOperand &opOperand,
1066+ const AnalysisState &state) const {
1067+ return true ;
1068+ }
1069+
1070+ AliasingValueList getAliasingValues (Operation *op, OpOperand &opOperand,
1071+ const AnalysisState &state) const {
1072+ return {{op->getResult (0 ), BufferRelation::Equivalent}};
1073+ }
1074+
1075+ LogicalResult bufferize (Operation *op, RewriterBase &rewriter,
1076+ const BufferizationOptions &options) const {
1077+ OpBuilder::InsertionGuard g (rewriter);
1078+ auto concatOp = cast<tensor::ConcatOp>(op);
1079+
1080+ // Allocate memory.
1081+ Location loc = op->getLoc ();
1082+ FailureOr<Value> tensorAlloc = allocateTensorForShapedValue (
1083+ rewriter, loc, concatOp.getResult (), options,
1084+ /* copy=*/ false );
1085+ if (failed (tensorAlloc))
1086+ return failure ();
1087+ auto tensorType = cast<RankedTensorType>(tensorAlloc->getType ());
1088+
1089+ // TODO: Implement memory space for this op.
1090+ if (options.defaultMemorySpaceFn (tensorType) != Attribute ())
1091+ return op->emitError (" memory space not implemented yet" );
1092+
1093+ MemRefLayoutAttrInterface layout;
1094+ MemRefType memrefType =
1095+ MemRefType::get (concatOp.getResultType ().getShape (),
1096+ concatOp.getResultType ().getElementType (), layout);
1097+ Value dstBuffer = rewriter.create <bufferization::ToMemrefOp>(
1098+ op->getLoc (), memrefType, *tensorAlloc);
1099+
1100+ // Extract the dimension for the concat op
1101+ uint64_t concatDim = concatOp.getDim ();
1102+
1103+ SmallVector<OpFoldResult> offsets (tensorType.getRank (),
1104+ rewriter.getIndexAttr (0 ));
1105+ SmallVector<OpFoldResult> strides (tensorType.getRank (),
1106+ rewriter.getIndexAttr (1 ));
1107+ SmallVector<OpFoldResult> sizes;
1108+ for (auto dimSize : tensorType.getShape ()) {
1109+ sizes.push_back (rewriter.getIndexAttr (dimSize));
1110+ }
1111+
1112+ int concatDimOffset = 0 ;
1113+ for (auto operand : concatOp.getInputs ()) {
1114+ // Get the buffer for the operand.
1115+ FailureOr<Value> srcBuffer = getBuffer (rewriter, operand, options);
1116+ if (failed (srcBuffer))
1117+ return failure ();
1118+
1119+ // Each operand may have a different size along the concat dimension,
1120+ // so the offset on that axis must accumulate through the loop, and the
1121+ // size must change to the size of the current operand.
1122+ auto operandTensorType = cast<RankedTensorType>(operand.getType ());
1123+ int operandConcatDimSize = operandTensorType.getDimSize (concatDim);
1124+ sizes[concatDim] = rewriter.getIndexAttr (operandConcatDimSize);
1125+ offsets[concatDim] = rewriter.getIndexAttr (concatDimOffset);
1126+
1127+ // Create a subview of the destination buffer.
1128+ auto dstMemrefType = cast<MemRefType>(memrefType);
1129+ MemRefType subviewMemRefType =
1130+ memref::SubViewOp::inferRankReducedResultType (
1131+ operandTensorType.getShape (), dstMemrefType, offsets, sizes,
1132+ strides);
1133+ Value subview = rewriter.create <memref::SubViewOp>(
1134+ loc, subviewMemRefType, dstBuffer, offsets, sizes, strides);
1135+
1136+ // Copy the source buffer into the destination subview.
1137+ if (failed (options.createMemCpy (rewriter, loc, *srcBuffer, subview)))
1138+ return failure ();
1139+
1140+ concatDimOffset += operandConcatDimSize;
1141+ }
1142+
1143+ replaceOpWithBufferizedValues (rewriter, op, dstBuffer);
1144+ return success ();
1145+ }
1146+ };
1147+
10511148} // namespace
10521149} // namespace tensor
10531150} // namespace mlir
@@ -1057,6 +1154,7 @@ void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
10571154 registry.addExtension (+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
10581155 CastOp::attachInterface<CastOpInterface>(*ctx);
10591156 CollapseShapeOp::attachInterface<CollapseShapeOpInterface>(*ctx);
1157+ ConcatOp::attachInterface<ConcatOpInterface>(*ctx);
10601158 DimOp::attachInterface<DimOpInterface>(*ctx);
10611159 EmptyOp::attachInterface<EmptyOpInterface>(*ctx);
10621160 ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
0 commit comments