@@ -54,43 +54,26 @@ struct UnrollGather : OpRewritePattern<vector::GatherOp> {
5454
5555  LogicalResult matchAndRewrite (vector::GatherOp op,
5656                                PatternRewriter &rewriter) const  override  {
57-     VectorType resultTy = op.getType ();
58-     if  (resultTy.getRank () < 2 )
59-       return  rewriter.notifyMatchFailure (op, " already 1-D"  );
60- 
61-     //  Unrolling doesn't take vscale into account. Pattern is disabled for
62-     //  vectors with leading scalable dim(s).
63-     if  (resultTy.getScalableDims ().front ())
64-       return  rewriter.notifyMatchFailure (op, " cannot unroll scalable dim"  );
65- 
66-     Location loc = op.getLoc ();
6757    Value indexVec = op.getIndexVec ();
6858    Value maskVec = op.getMask ();
6959    Value passThruVec = op.getPassThru ();
7060
71-     Value result = arith::ConstantOp::create (rewriter, loc, resultTy,
72-                                              rewriter.getZeroAttr (resultTy));
73- 
74-     VectorType subTy = VectorType::Builder (resultTy).dropDim (0 );
75- 
76-     for  (int64_t  i = 0 , e = resultTy.getShape ().front (); i < e; ++i) {
77-       int64_t  thisIdx[1 ] = {i};
61+     auto  unrollGatherFn = [&](PatternRewriter &rewriter, Location loc,
62+                               VectorType subTy, int64_t  index) {
63+       int64_t  thisIdx[1 ] = {index};
7864
7965      Value indexSubVec =
8066          vector::ExtractOp::create (rewriter, loc, indexVec, thisIdx);
8167      Value maskSubVec =
8268          vector::ExtractOp::create (rewriter, loc, maskVec, thisIdx);
8369      Value passThruSubVec =
8470          vector::ExtractOp::create (rewriter, loc, passThruVec, thisIdx);
85-       Value subGather = vector::GatherOp::create (
86-           rewriter, loc, subTy, op.getBase (), op.getIndices (), indexSubVec,
87-           maskSubVec, passThruSubVec);
88-       result =
89-           vector::InsertOp::create (rewriter, loc, subGather, result, thisIdx);
90-     }
71+       return  vector::GatherOp::create (rewriter, loc, subTy, op.getBase (),
72+                                       op.getIndices (), indexSubVec, maskSubVec,
73+                                       passThruSubVec);
74+     };
9175
92-     rewriter.replaceOp (op, result);
93-     return  success ();
76+     return  unrollVectorOp (op, rewriter, unrollGatherFn);
9477  }
9578};
9679
0 commit comments