@@ -37,10 +37,37 @@ namespace {
3737// / %extracted = tensor.extract %0[%in, %17] : tensor<7x128xf16>
3838// / linalg.yield %extracted : f16
3939// / } -> tensor<1x7x128xf16>
40-
4140struct DecomposeGatherOp : public OpRewritePattern <tensor::GatherOp> {
4241 using OpRewritePattern<tensor::GatherOp>::OpRewritePattern;
4342
43+ SmallVector<OpFoldResult> getDstMixedSizes (PatternRewriter &rewriter,
44+ Location loc,
45+ tensor::GatherOp gatherOp) const {
46+ SmallVector<OpFoldResult> dstSize =
47+ tensor::getMixedSizes (rewriter, loc, gatherOp.getResult ());
48+ SmallVector<OpFoldResult> indexSize =
49+ tensor::getMixedSizes (rewriter, loc, gatherOp.getIndices ());
50+ SmallVector<OpFoldResult> srcSize =
51+ tensor::getMixedSizes (rewriter, loc, gatherOp.getSource ());
52+ SmallVector<int64_t > gatherDims (gatherOp.getGatherDims ());
53+ bool isShrinkDst = (indexSize.size () - 1 ) + srcSize.size () ==
54+ dstSize.size () + gatherDims.size ();
55+ for (size_t i = 0 ; i < indexSize.size () - 1 ; i++) {
56+ dstSize[i] = indexSize[i];
57+ }
58+ auto cnt = 0 ;
59+ for (size_t i = indexSize.size () - 1 ; i < dstSize.size (); i++) {
60+ while (isShrinkDst && llvm::find (gatherDims, cnt) != gatherDims.end ()) {
61+ cnt++;
62+ }
63+ dstSize[i] = llvm::find (gatherDims, cnt) == gatherDims.end ()
64+ ? srcSize[cnt]
65+ : getAsIndexOpFoldResult (rewriter.getContext (), 1 );
66+ cnt++;
67+ }
68+ return dstSize;
69+ }
70+
4471 LogicalResult matchAndRewrite (tensor::GatherOp gatherOp,
4572 PatternRewriter &rewriter) const override {
4673 OpBuilder::InsertionGuard g (rewriter);
@@ -51,7 +78,7 @@ struct DecomposeGatherOp : public OpRewritePattern<tensor::GatherOp> {
5178 // create destination tensor for linalg out
5279 RankedTensorType dstType = gatherOp.getResultType ();
5380 Value dstTensor = rewriter.create <tensor::EmptyOp>(
54- loc, tensor::getMixedSizes (rewriter, loc, gatherOp. getResult () ),
81+ loc, getDstMixedSizes (rewriter, loc, gatherOp),
5582 dstType.getElementType ());
5683
5784 // split index tensor to create the linalg input
@@ -113,8 +140,8 @@ struct DecomposeGatherOp : public OpRewritePattern<tensor::GatherOp> {
113140 dstRank + gatherDims.size ();
114141 int cnt = 0 ;
115142 for (auto i = indexTensorSize.size () - 1 ; i < dstRank; i++) {
116- while (llvm::find (gatherDims, cnt) != gatherDims. end () &&
117- isShrinkDst ) {
143+ while (isShrinkDst &&
144+ llvm::find (gatherDims, cnt) != gatherDims. end () ) {
118145 cnt++;
119146 }
120147 indexValues[cnt] = b.create <linalg::IndexOp>(loc, i);
0 commit comments