1616#include " mlir/Dialect/Tensor/IR/Tensor.h"
1717#include " mlir/Dialect/Tensor/Utils/Utils.h"
1818#include " mlir/Dialect/Utils/IndexingUtils.h"
19+ #include " mlir/Interfaces/InferTypeOpInterface.h"
1920#include " mlir/Interfaces/TilingInterface.h"
2021#include " mlir/Interfaces/ValueBoundsOpInterface.h"
2122
@@ -621,6 +622,12 @@ struct UnPackOpTiling
621622 SmallVectorImpl<OpFoldResult> &resultOffsets,
622623 SmallVectorImpl<OpFoldResult> &resultSizes) const {
623624 auto unPackOp = cast<UnPackOp>(op);
625+ // If the operand tile is the dest, then no adjustment is needed.
626+ if (operandNumber == unPackOp.getDestMutable ().getOperandNumber ()) {
627+ resultOffsets = llvm::to_vector (offsets);
628+ resultSizes = llvm::to_vector (sizes);
629+ return success ();
630+ }
624631 Location loc = unPackOp.getLoc ();
625632
626633 int64_t numTiles = unPackOp.getInnerDimsPos ().size ();
@@ -629,6 +636,11 @@ struct UnPackOpTiling
629636 // The tiling is applied on interchanged dimensions. We have to undo the
630637 // interchange to map sizes and offsets to the original input.
631638 int64_t outputRank = unPackOp.getDestRank ();
639+ ReifiedRankedShapedTypeDims reifiedReturnShapes;
640+ if (failed (reifyResultShapes (b, unPackOp, reifiedReturnShapes))) {
641+ return failure ();
642+ }
643+ SmallVector<OpFoldResult> outputMixedSizes = reifiedReturnShapes.front ();
632644 SmallVector<OpFoldResult> origOffsets (destOffsets);
633645 SmallVector<OpFoldResult> origSizes (destSizes);
634646 applyPermToRange (origOffsets, origSizes,
@@ -640,18 +652,21 @@ struct UnPackOpTiling
640652 for (auto dim : llvm::seq<int64_t >(0 , outputRank)) {
641653 using AV = affine::AffineValueExpr;
642654 affine::AffineBuilder ab (b, loc);
643- AffineExpr dim0, dim1, sym ;
655+ AffineExpr dim0, dim1, sym0 ;
644656 bindDims (b.getContext (), dim0, dim1);
645- bindSymbols (b.getContext (), sym );
657+ bindSymbols (b.getContext (), sym0 );
646658 if (dimAndTileMapping.count (dim)) {
647659 // If the data dimension is tiled, the i-th index is the product of
648660 // offset_i and tile_i, and the i-th size is the product of sizes_i and
649- // tile_i.
661+ // tile_i. The sizes must be clamped to the sizes of the unpack result.
650662 auto avOffset = AV (dim0).bind (origOffsets[dim]);
651663 auto avSize = AV (dim0).bind (origSizes[dim]);
652- auto avTileSize = AV (sym).bind (dimAndTileMapping[dim]);
664+ auto avTileSize = AV (sym0).bind (dimAndTileMapping[dim]);
665+ auto avResultSize = AV (dim0).bind (outputMixedSizes[dim]);
653666 resultOffsets.push_back (ab.mul (avOffset, avTileSize));
654- resultSizes.push_back (ab.mul (avSize, avTileSize));
667+ auto avResultOffset = AV (dim1).bind (resultOffsets.back ());
668+ resultSizes.push_back (ab.min ({ab.mul (avSize, avTileSize),
669+ ab.sub (avResultSize, avResultOffset)}));
655670 } else {
656671 resultOffsets.push_back (origOffsets[dim]);
657672 resultSizes.push_back (origSizes[dim]);
0 commit comments