
Commit 0c44188

add bufferization for concat op
1 parent 5198205 commit 0c44188
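
What the change does, in IR terms (a minimal sketch reconstructed from the doc comment and the new tests, with illustrative SSA names and static shapes; it is not output copied from the compiler):

    %t = tensor.concat dim(0) %a, %b
        : (tensor<8xf32>, tensor<8xf32>) -> tensor<16xf32>

bufferizes to a fresh allocation plus one subview/copy pair per operand:

    %alloc = memref.alloc() : memref<16xf32>
    %sv0 = memref.subview %alloc[0] [8] [1]
        : memref<16xf32> to memref<8xf32, strided<[1]>>
    memref.copy %a_buf, %sv0 : memref<8xf32> to memref<8xf32, strided<[1]>>
    %sv1 = memref.subview %alloc[8] [8] [1]
        : memref<16xf32> to memref<8xf32, strided<[1], offset: 8>>
    memref.copy %b_buf, %sv1 : memref<8xf32> to memref<8xf32, strided<[1], offset: 8>>
    %t = bufferization.to_tensor %alloc : memref<16xf32>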

File tree

3 files changed: +142 -2 lines changed

mlir/lib/Dialect/Tensor/IR/TensorDialect.cpp (+2 -2)
mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp (+98 -0)
mlir/test/Dialect/Tensor/bufferize.mlir (+42 -0)

mlir/lib/Dialect/Tensor/IR/TensorDialect.cpp

Lines changed: 2 additions & 2 deletions
@@ -49,8 +49,8 @@ void TensorDialect::initialize() {
       >();
   addInterfaces<TensorInlinerInterface>();
   declarePromisedInterfaces<
-      bufferization::BufferizableOpInterface, CastOp, CollapseShapeOp, DimOp,
-      EmptyOp, ExpandShapeOp, ExtractSliceOp, ExtractOp, FromElementsOp,
+      bufferization::BufferizableOpInterface, CastOp, CollapseShapeOp, ConcatOp,
+      DimOp, EmptyOp, ExpandShapeOp, ExtractSliceOp, ExtractOp, FromElementsOp,
       GenerateOp, InsertOp, InsertSliceOp, PadOp, ParallelInsertSliceOp, RankOp,
       ReshapeOp, SplatOp>();
   declarePromisedInterfaces<transform::FindPayloadReplacementOpInterface,
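
(This list only declares a promise: TensorDialect announces that an external BufferizableOpInterface model exists for ConcatOp, and the ConcatOp::attachInterface call added in BufferizableOpInterfaceImpl.cpp below fulfills it. If the promise were left unfulfilled, MLIR would diagnose the missing implementation at runtime, the first time something queries ConcatOp for BufferizableOpInterface.)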

mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 98 additions & 0 deletions
@@ -1048,6 +1048,103 @@ struct SplatOpInterface
   }
 };
 
+/// Bufferization of tensor.concat. Bufferizes to a new allocation that is
+/// filled with copy ops. Similar to tensor.from_elements, but using memref.copy
+/// on subviews instead of memref.store.
+struct ConcatOpInterface
+    : public BufferizableOpInterface::ExternalModel<ConcatOpInterface,
+                                                    tensor::ConcatOp> {
+
+  bool bufferizesToAllocation(Operation *op, Value value) const { return true; }
+
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+                               const AnalysisState &state) const {
+    return true;
+  }
+
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+                              const AnalysisState &state) const {
+    return true;
+  }
+
+  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
+                                      const AnalysisState &state) const {
+    return {{op->getResult(0), BufferRelation::Equivalent}};
+  }
+
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+                          const BufferizationOptions &options) const {
+    OpBuilder::InsertionGuard g(rewriter);
+    auto concatOp = cast<tensor::ConcatOp>(op);
+
+    // Allocate memory.
+    Location loc = op->getLoc();
+    FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
+        rewriter, loc, concatOp.getResult(), options,
+        /*copy=*/false);
+    if (failed(tensorAlloc))
+      return failure();
+    auto tensorType = cast<RankedTensorType>(tensorAlloc->getType());
+
+    // TODO: Implement memory space for this op.
+    if (options.defaultMemorySpaceFn(tensorType) != Attribute())
+      return op->emitError("memory space not implemented yet");
+
+    MemRefLayoutAttrInterface layout;
+    MemRefType memrefType =
+        MemRefType::get(concatOp.getResultType().getShape(),
+                        concatOp.getResultType().getElementType(), layout);
+    Value dstBuffer = rewriter.create<bufferization::ToMemrefOp>(
+        op->getLoc(), memrefType, *tensorAlloc);
+
+    // Extract the dimension for the concat op.
+    uint64_t concatDim = concatOp.getDim();
+
+    SmallVector<OpFoldResult> offsets(tensorType.getRank(),
+                                      rewriter.getIndexAttr(0));
+    SmallVector<OpFoldResult> strides(tensorType.getRank(),
+                                      rewriter.getIndexAttr(1));
+    SmallVector<OpFoldResult> sizes;
+    for (auto dimSize : tensorType.getShape()) {
+      sizes.push_back(rewriter.getIndexAttr(dimSize));
+    }
+
+    int concatDimOffset = 0;
+    for (auto operand : concatOp.getInputs()) {
+      // Get the buffer for the operand.
+      FailureOr<Value> srcBuffer = getBuffer(rewriter, operand, options);
+      if (failed(srcBuffer))
+        return failure();
+
+      // Each operand may have a different size along the concat dimension,
+      // so the offset on that axis must accumulate through the loop, and the
+      // size must change to the size of the current operand.
+      auto operandTensorType = cast<RankedTensorType>(operand.getType());
+      int operandConcatDimSize = operandTensorType.getDimSize(concatDim);
+      sizes[concatDim] = rewriter.getIndexAttr(operandConcatDimSize);
+      offsets[concatDim] = rewriter.getIndexAttr(concatDimOffset);
+
+      // Create a subview of the destination buffer.
+      auto dstMemrefType = cast<MemRefType>(memrefType);
+      MemRefType subviewMemRefType =
+          memref::SubViewOp::inferRankReducedResultType(
+              operandTensorType.getShape(), dstMemrefType, offsets, sizes,
+              strides);
+      Value subview = rewriter.create<memref::SubViewOp>(
+          loc, subviewMemRefType, dstBuffer, offsets, sizes, strides);
+
+      // Copy the source buffer into the destination subview.
+      if (failed(options.createMemCpy(rewriter, loc, *srcBuffer, subview)))
+        return failure();
+
+      concatDimOffset += operandConcatDimSize;
+    }
+
+    replaceOpWithBufferizedValues(rewriter, op, dstBuffer);
+    return success();
+  }
+};
+
 } // namespace
 } // namespace tensor
 } // namespace mlir
@@ -1057,6 +1154,7 @@ void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
   registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
     CastOp::attachInterface<CastOpInterface>(*ctx);
     CollapseShapeOp::attachInterface<CollapseShapeOpInterface>(*ctx);
+    ConcatOp::attachInterface<ConcatOpInterface>(*ctx);
    DimOp::attachInterface<DimOpInterface>(*ctx);
     EmptyOp::attachInterface<EmptyOpInterface>(*ctx);
     ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
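
The offset bookkeeping in bufferize() is easiest to see on the two-operand dim(1) test below: offsets and sizes stay at their full-tensor defaults except along the concat dimension, where each iteration narrows the size to the operand's extent and then advances concatDimOffset by it. A sketch of the two generated subview/copy pairs for an 8x4 and an 8x5 operand (illustrative names; layouts assume the row-major 8x9 result):

    // Iteration 0: sizes[1] = 4, offsets[1] = 0.
    %sv0 = memref.subview %alloc[0, 0] [8, 4] [1, 1]
        : memref<8x9xf32> to memref<8x4xf32, strided<[9, 1]>>
    memref.copy %f_buf, %sv0 : memref<8x4xf32> to memref<8x4xf32, strided<[9, 1]>>
    // Iteration 1: concatDimOffset has advanced to 4; sizes[1] = 5, offsets[1] = 4.
    %sv1 = memref.subview %alloc[0, 4] [8, 5] [1, 1]
        : memref<8x9xf32> to memref<8x5xf32, strided<[9, 1], offset: 4>>
    memref.copy %g_buf, %sv1 : memref<8x5xf32> to memref<8x5xf32, strided<[9, 1], offset: 4>>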

mlir/test/Dialect/Tensor/bufferize.mlir

Lines changed: 42 additions & 0 deletions
@@ -615,6 +615,48 @@ func.func @tensor.splat(%f: f32) -> tensor<10x2x4xf32> {
 
 // -----
 
+// CHECK-LABEL: func @tensor.concat(
+// CHECK-SAME: %[[F:.*]]: tensor<8xf32>)
+// CHECK: %[[F_MEMREF:.*]] = bufferization.to_memref %[[F]]
+// CHECK: memref.copy %[[F_MEMREF]], %[[F_ALLOC:.*]] :
+// CHECK: memref.copy %[[F_MEMREF]], %[[F_ALLOC_2:.*]] :
+// CHECK: %[[ALLOC:.*]] = memref.alloc() {{.*}} : memref<16xf32>
+// CHECK: %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0] [8] [1]
+// CHECK: memref.copy %[[F_ALLOC]], %[[SUBVIEW1]]
+// CHECK: %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][8] [8] [1]
+// CHECK: memref.copy %[[F_ALLOC_2]], %[[SUBVIEW2]]
+// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[ALLOC]]
+// CHECK: return %[[RET]]
+// CHECK: }
+func.func @tensor.concat(%f: tensor<8xf32>) -> tensor<16xf32> {
+  %t = tensor.concat dim(0) %f, %f : (tensor<8xf32>, tensor<8xf32>) -> tensor<16xf32>
+  return %t : tensor<16xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @tensor.concat_different_shapes(
+// CHECK-SAME: %[[F:.*]]: tensor<8x4xf32>
+// CHECK-SAME: %[[G:.*]]: tensor<8x5xf32>
+// CHECK-DAG: %[[F_MEMREF:.*]] = bufferization.to_memref %[[F]]
+// CHECK-DAG: %[[G_MEMREF:.*]] = bufferization.to_memref %[[G]]
+// CHECK: memref.copy %[[F_MEMREF]], %[[F_ALLOC:.*]] :
+// CHECK: memref.copy %[[G_MEMREF]], %[[F_ALLOC_2:.*]] :
+// CHECK: %[[ALLOC:.*]] = memref.alloc() {{.*}} : memref<8x9xf32>
+// CHECK: %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0, 0] [8, 4] [1, 1]
+// CHECK: memref.copy %[[F_ALLOC]], %[[SUBVIEW1]]
+// CHECK: %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][0, 4] [8, 5] [1, 1]
+// CHECK: memref.copy %[[F_ALLOC_2]], %[[SUBVIEW2]]
+// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[ALLOC]]
+// CHECK: return %[[RET]]
+// CHECK: }
+func.func @tensor.concat_different_shapes(%f: tensor<8x4xf32>, %g: tensor<8x5xf32>) -> tensor<8x9xf32> {
+  %t = tensor.concat dim(1) %f, %g : (tensor<8x4xf32>, tensor<8x5xf32>) -> tensor<8x9xf32>
+  return %t : tensor<8x9xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @tensor.splat_dynamic(
 // CHECK-SAME: %[[F:[a-zA-Z0-9_]+]]: f32
 // CHECK-SAME: %[[M:[a-zA-Z0-9_]+]]: index
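
A note on reading these CHECK lines: the early memref.copy of each input into %[[F_ALLOC]]/%[[F_ALLOC_2]] is not emitted by ConcatOpInterface itself; it presumably comes from the bufferization options in this file's RUN line (not part of this diff), which conservatively copy buffers that are read. The ops the new interface contributes are the memref.alloc of the memref<16xf32> / memref<8x9xf32> result and the subview/copy pairs that follow it.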
