@@ -5184,13 +5184,14 @@ std::optional<SmallVector<int64_t, 4>> GatherOp::getShapeForUnroll() {
51845184 return llvm::to_vector<4 >(getVectorType ().getShape ());
51855185}
51865186
5187- static LogicalResult isContiguousIndices (Value val) {
5188- auto vecType = dyn_cast<VectorType>(val.getType ());
5187+ // / Cheeck if `indexVec` is constant 1D vec of consecutive values [0, 1, 2, ...]
5188+ static LogicalResult isContiguousIndices (Value indexVec) {
5189+ auto vecType = dyn_cast<VectorType>(indexVec.getType ());
51895190 if (!vecType || vecType.getRank () != 1 || vecType.isScalable ())
51905191 return failure ();
51915192
51925193 DenseIntElementsAttr elements;
5193- if (!matchPattern (val , m_Constant (&elements)))
5194+ if (!matchPattern (indexVec , m_Constant (&elements)))
51945195 return failure ();
51955196
51965197 return success (
@@ -5216,6 +5217,8 @@ class GatherFolder final : public OpRewritePattern<GatherOp> {
52165217 }
52175218};
52185219
5220+ // / Fold gathers with consecutive offsets [0, 1, 2, ...] into contiguous
5221+ // / maskedload. Only 1D non-scalable vectors are supported for now.
52195222class GatherTrivialIndices final : public OpRewritePattern<GatherOp> {
52205223public:
52215224 using OpRewritePattern::OpRewritePattern;
@@ -5277,6 +5280,8 @@ class ScatterFolder final : public OpRewritePattern<ScatterOp> {
52775280 }
52785281};
52795282
5283+ // / Fold scatters with consecutive offsets [0, 1, 2, ...] into contiguous
5284+ // / maskedstore. Only 1D non-scalable vectors are supported for now.
52805285class ScatterTrivialIndices final : public OpRewritePattern<ScatterOp> {
52815286public:
52825287 using OpRewritePattern::OpRewritePattern;
0 commit comments