@@ -1099,17 +1099,35 @@ struct ConcatOpInterface
10991099
11001100 // Extract the dimension for the concat op
11011101 uint64_t concatDim = concatOp.getDim ();
1102+ bool dynamicConcatDim = false ;
11021103
11031104 SmallVector<OpFoldResult> offsets (tensorType.getRank (),
11041105 rewriter.getIndexAttr (0 ));
11051106 SmallVector<OpFoldResult> strides (tensorType.getRank (),
11061107 rewriter.getIndexAttr (1 ));
11071108 SmallVector<OpFoldResult> sizes;
1108- for (auto dimSize : tensorType.getShape ()) {
1109- sizes.push_back (rewriter.getIndexAttr (dimSize));
1109+
1110+ for (const auto &[dimIdx, dimSize] :
1111+ llvm::enumerate (tensorType.getShape ())) {
1112+ if (dimSize == ShapedType::kDynamic ) {
1113+ auto dimOp = rewriter.create <memref::DimOp>(loc, dstBuffer, dimSize);
1114+ sizes.push_back (dimOp.getResult ());
1115+ if (dimIdx == concatDim)
1116+ dynamicConcatDim = true ;
1117+ } else {
1118+ sizes.push_back (rewriter.getIndexAttr (dimSize));
1119+ }
11101120 }
11111121
11121122 int64_t concatDimOffset = 0 ;
1123+ std::optional<Value> dynamicOffset;
1124+ std::optional<Value> dynamicSize;
1125+ if (dynamicConcatDim) {
1126+ // One or more operands have dynamic size, so we must accumulate the
1127+ // offset with arith ops.
1128+ dynamicOffset = rewriter.create <arith::ConstantIndexOp>(loc, 0 );
1129+ }
1130+
11131131 for (auto operand : concatOp.getInputs ()) {
11141132 // Get the buffer for the operand.
11151133 FailureOr<Value> srcBuffer = getBuffer (rewriter, operand, options);
@@ -1120,9 +1138,17 @@ struct ConcatOpInterface
11201138 // so the offset on that axis must accumulate through the loop, and the
11211139 // size must change to the size of the current operand.
11221140 auto operandTensorType = cast<RankedTensorType>(operand.getType ());
1123- int operandConcatDimSize = operandTensorType.getDimSize (concatDim);
1124- sizes[concatDim] = rewriter.getIndexAttr (operandConcatDimSize);
1125- offsets[concatDim] = rewriter.getIndexAttr (concatDimOffset);
1141+ int64_t operandConcatDimSize = operandTensorType.getDimSize (concatDim);
1142+
1143+ if (dynamicConcatDim) {
1144+ offsets[concatDim] = dynamicOffset.value ();
1145+ dynamicSize = rewriter.create <memref::DimOp>(loc, *srcBuffer, concatDim)
1146+ .getResult ();
1147+ sizes[concatDim] = dynamicSize.value ();
1148+ } else {
1149+ sizes[concatDim] = rewriter.getIndexAttr (operandConcatDimSize);
1150+ offsets[concatDim] = rewriter.getIndexAttr (concatDimOffset);
1151+ }
11261152
11271153 // Create a subview of the destination buffer.
11281154 auto dstMemrefType = cast<MemRefType>(memrefType);
@@ -1137,7 +1163,12 @@ struct ConcatOpInterface
11371163 if (failed (options.createMemCpy (rewriter, loc, *srcBuffer, subview)))
11381164 return failure ();
11391165
1140- concatDimOffset += operandConcatDimSize;
1166+ if (dynamicConcatDim) {
1167+ dynamicOffset = rewriter.create <arith::AddIOp>(
1168+ loc, dynamicOffset.value (), dynamicSize.value ());
1169+ } else {
1170+ concatDimOffset += operandConcatDimSize;
1171+ }
11411172 }
11421173
11431174 replaceOpWithBufferizedValues (rewriter, op, dstBuffer);
0 commit comments