@@ -1048,134 +1048,6 @@ struct SplatOpInterface
   }
 };
 
-/// Bufferization of tensor.concat. Bufferizes to a new allocation that is
-/// filled with copy ops. Similar to tensor.from_elements, but using memref.copy
-/// on subviews instead of memref.store.
-struct ConcatOpInterface
-    : public BufferizableOpInterface::ExternalModel<ConcatOpInterface,
-                                                    tensor::ConcatOp> {
-
-  bool bufferizesToAllocation(Operation *op, Value value) const { return true; }
-
-  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
-                               const AnalysisState &state) const {
-    return false;
-  }
-
-  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
-                              const AnalysisState &state) const {
-    return true;
-  }
-
-  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
-                                      const AnalysisState &state) const {
-    return {};
-  }
-
-  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          const BufferizationOptions &options) const {
-    OpBuilder::InsertionGuard g(rewriter);
-    auto concatOp = cast<tensor::ConcatOp>(op);
-
-    // Allocate memory.
-    Location loc = op->getLoc();
-    FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
-        rewriter, loc, concatOp.getResult(), options,
-        /*copy=*/false);
-    if (failed(tensorAlloc))
-      return failure();
-    auto tensorType = cast<RankedTensorType>(tensorAlloc->getType());
-
-    // TODO: Implement memory space for this op.
-    if (options.defaultMemorySpaceFn(tensorType) != Attribute())
-      return op->emitError("memory space not implemented yet");
-
-    MemRefLayoutAttrInterface layout;
-    MemRefType memrefType =
-        MemRefType::get(concatOp.getResultType().getShape(),
-                        concatOp.getResultType().getElementType(), layout);
-    Value dstBuffer = rewriter.create<bufferization::ToMemrefOp>(
-        op->getLoc(), memrefType, *tensorAlloc);
-
-    // Extract the dimension for the concat op
-    uint64_t concatDim = concatOp.getDim();
-    bool dynamicConcatDim = false;
-
-    SmallVector<OpFoldResult> offsets(tensorType.getRank(),
-                                      rewriter.getIndexAttr(0));
-    SmallVector<OpFoldResult> strides(tensorType.getRank(),
-                                      rewriter.getIndexAttr(1));
-    SmallVector<OpFoldResult> sizes;
-
-    for (const auto &[dimIdx, dimSize] :
-         llvm::enumerate(tensorType.getShape())) {
-      if (dimSize == ShapedType::kDynamic) {
-        auto dimOp = rewriter.create<memref::DimOp>(loc, dstBuffer, dimIdx);
-        sizes.push_back(dimOp.getResult());
-        if (dimIdx == concatDim)
-          dynamicConcatDim = true;
-      } else {
-        sizes.push_back(rewriter.getIndexAttr(dimSize));
-      }
-    }
-
-    int64_t concatDimOffset = 0;
-    std::optional<Value> dynamicOffset;
-    std::optional<Value> dynamicSize;
-    if (dynamicConcatDim) {
-      // One or more operands have dynamic size, so we must accumulate the
-      // offset with arith ops.
-      dynamicOffset = rewriter.create<arith::ConstantIndexOp>(loc, 0);
-    }
-
-    for (auto operand : concatOp.getInputs()) {
-      // Get the buffer for the operand.
-      FailureOr<Value> srcBuffer = getBuffer(rewriter, operand, options);
-      if (failed(srcBuffer))
-        return failure();
-
-      // Each operand may have a different size along the concat dimension,
-      // so the offset on that axis must accumulate through the loop, and the
-      // size must change to the size of the current operand.
-      auto operandTensorType = cast<RankedTensorType>(operand.getType());
-      int64_t operandConcatDimSize = operandTensorType.getDimSize(concatDim);
-
-      if (dynamicConcatDim) {
-        offsets[concatDim] = dynamicOffset.value();
-        dynamicSize = rewriter.create<memref::DimOp>(loc, *srcBuffer, concatDim)
-                          .getResult();
-        sizes[concatDim] = dynamicSize.value();
-      } else {
-        sizes[concatDim] = rewriter.getIndexAttr(operandConcatDimSize);
-        offsets[concatDim] = rewriter.getIndexAttr(concatDimOffset);
-      }
-
-      // Create a subview of the destination buffer.
-      auto dstMemrefType = cast<MemRefType>(memrefType);
-      MemRefType subviewMemRefType =
-          memref::SubViewOp::inferRankReducedResultType(
-              operandTensorType.getShape(), dstMemrefType, offsets, sizes,
-              strides);
-      Value subview = rewriter.create<memref::SubViewOp>(
-          loc, subviewMemRefType, dstBuffer, offsets, sizes, strides);
-
-      // Copy the source buffer into the destination subview.
-      if (failed(options.createMemCpy(rewriter, loc, *srcBuffer, subview)))
-        return failure();
-
-      if (dynamicConcatDim) {
-        dynamicOffset = rewriter.create<arith::AddIOp>(
-            loc, dynamicOffset.value(), dynamicSize.value());
-      } else {
-        concatDimOffset += operandConcatDimSize;
-      }
-    }
-
-    replaceOpWithBufferizedValues(rewriter, op, dstBuffer);
-    return success();
-  }
-};
-
 } // namespace
 } // namespace tensor
 } // namespace mlir
@@ -1185,7 +1057,6 @@ void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
   registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
     CastOp::attachInterface<CastOpInterface>(*ctx);
     CollapseShapeOp::attachInterface<CollapseShapeOpInterface>(*ctx);
-    ConcatOp::attachInterface<ConcatOpInterface>(*ctx);
     DimOp::attachInterface<DimOpInterface>(*ctx);
     EmptyOp::attachInterface<EmptyOpInterface>(*ctx);
     ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
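For reference, the interface removed above bufferized tensor.concat by allocating one destination buffer and copying each operand into a memref.subview of it at an accumulating offset along the concat dimension. A minimal sketch of that lowering for the static-shape case (%a_buf and %b_buf are hypothetical names for the bufferized operands; the exact subview layouts depend on the bufferization options):

  %r = tensor.concat dim(0) %a, %b
      : (tensor<2x4xf32>, tensor<3x4xf32>) -> tensor<5x4xf32>

  // After bufferization (schematically):
  %alloc = memref.alloc() : memref<5x4xf32>
  %sv0 = memref.subview %alloc[0, 0] [2, 4] [1, 1]
      : memref<5x4xf32> to memref<2x4xf32, strided<[4, 1]>>
  memref.copy %a_buf, %sv0
      : memref<2x4xf32> to memref<2x4xf32, strided<[4, 1]>>
  %sv1 = memref.subview %alloc[2, 0] [3, 4] [1, 1]
      : memref<5x4xf32> to memref<3x4xf32, strided<[4, 1], offset: 8>>
  memref.copy %b_buf, %sv1
      : memref<3x4xf32> to memref<3x4xf32, strided<[4, 1], offset: 8>>

When the concat dimension is dynamic, the deleted bufferize() body instead materialized the running offset as SSA values, starting from arith.constant 0, reading each operand's extent with memref.dim, and advancing the offset with arith.addi between copies.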