@@ -5176,13 +5176,14 @@ std::optional<SmallVector<int64_t, 4>> GatherOp::getShapeForUnroll() {
51765176 return llvm::to_vector<4 >(getVectorType ().getShape ());
51775177}
51785178
5179- static LogicalResult isContiguousIndices (Value val) {
5180- auto vecType = dyn_cast<VectorType>(val.getType ());
5179+ // / Cheeck if `indexVec` is constant 1D vec of consecutive values [0, 1, 2, ...]
5180+ static LogicalResult isContiguousIndices (Value indexVec) {
5181+ auto vecType = dyn_cast<VectorType>(indexVec.getType ());
51815182 if (!vecType || vecType.getRank () != 1 || vecType.isScalable ())
51825183 return failure ();
51835184
51845185 DenseIntElementsAttr elements;
5185- if (!matchPattern (val , m_Constant (&elements)))
5186+ if (!matchPattern (indexVec , m_Constant (&elements)))
51865187 return failure ();
51875188
51885189 return success (
@@ -5208,6 +5209,8 @@ class GatherFolder final : public OpRewritePattern<GatherOp> {
52085209 }
52095210};
52105211
5212+ // / Fold gathers with consecutive offsets [0, 1, 2, ...] into contiguous
5213+ // / maskedload. Only 1D non-scalable vectors are supported for now.
52115214class GatherTrivialIndices final : public OpRewritePattern<GatherOp> {
52125215public:
52135216 using OpRewritePattern::OpRewritePattern;
@@ -5269,6 +5272,8 @@ class ScatterFolder final : public OpRewritePattern<ScatterOp> {
52695272 }
52705273};
52715274
5275+ // / Fold scatters with consecutive offsets [0, 1, 2, ...] into contiguous
5276+ // / maskedstore. Only 1D non-scalable vectors are supported for now.
52725277class ScatterTrivialIndices final : public OpRewritePattern<ScatterOp> {
52735278public:
52745279 using OpRewritePattern::OpRewritePattern;
0 commit comments