@@ -54,43 +54,26 @@ struct UnrollGather : OpRewritePattern<vector::GatherOp> {
5454
5555 LogicalResult matchAndRewrite (vector::GatherOp op,
5656 PatternRewriter &rewriter) const override {
57- VectorType resultTy = op.getType ();
58- if (resultTy.getRank () < 2 )
59- return rewriter.notifyMatchFailure (op, " already 1-D" );
60-
61- // Unrolling doesn't take vscale into account. Pattern is disabled for
62- // vectors with leading scalable dim(s).
63- if (resultTy.getScalableDims ().front ())
64- return rewriter.notifyMatchFailure (op, " cannot unroll scalable dim" );
65-
66- Location loc = op.getLoc ();
6757 Value indexVec = op.getIndexVec ();
6858 Value maskVec = op.getMask ();
6959 Value passThruVec = op.getPassThru ();
7060
71- Value result = arith::ConstantOp::create (rewriter, loc, resultTy,
72- rewriter.getZeroAttr (resultTy));
73-
74- VectorType subTy = VectorType::Builder (resultTy).dropDim (0 );
75-
76- for (int64_t i = 0 , e = resultTy.getShape ().front (); i < e; ++i) {
77- int64_t thisIdx[1 ] = {i};
61+ auto unrollGatherFn = [&](PatternRewriter &rewriter, Location loc,
62+ VectorType subTy, int64_t index) {
63+ int64_t thisIdx[1 ] = {index};
7864
7965 Value indexSubVec =
8066 vector::ExtractOp::create (rewriter, loc, indexVec, thisIdx);
8167 Value maskSubVec =
8268 vector::ExtractOp::create (rewriter, loc, maskVec, thisIdx);
8369 Value passThruSubVec =
8470 vector::ExtractOp::create (rewriter, loc, passThruVec, thisIdx);
85- Value subGather = vector::GatherOp::create (
86- rewriter, loc, subTy, op.getBase (), op.getIndices (), indexSubVec,
87- maskSubVec, passThruSubVec);
88- result =
89- vector::InsertOp::create (rewriter, loc, subGather, result, thisIdx);
90- }
71+ return vector::GatherOp::create (rewriter, loc, subTy, op.getBase (),
72+ op.getIndices (), indexSubVec, maskSubVec,
73+ passThruSubVec);
74+ };
9175
92- rewriter.replaceOp (op, result);
93- return success ();
76+ return unrollVectorOp (op, rewriter, unrollGatherFn);
9477 }
9578};
9679
0 commit comments