@@ -5185,7 +5185,7 @@ std::optional<SmallVector<int64_t, 4>> GatherOp::getShapeForUnroll() {
51855185}
51865186
51875187// / Cheeck if `indexVec` is constant 1D vec of consecutive values [0, 1, 2, ...]
5188- static LogicalResult isContiguousIndices (Value indexVec) {
5188+ static LogicalResult isZeroBasedContiguousSeq (Value indexVec) {
51895189 auto vecType = dyn_cast<VectorType>(indexVec.getType ());
51905190 if (!vecType || vecType.getRank () != 1 || vecType.isScalable ())
51915191 return failure ();
@@ -5222,12 +5222,12 @@ class GatherFolder final : public OpRewritePattern<GatherOp> {
52225222
52235223// / Fold gathers with consecutive offsets [0, 1, 2, ...] into contiguous
52245224// / maskedload. Only 1D non-scalable vectors are supported for now.
5225- class GatherTrivialIndices final : public OpRewritePattern<GatherOp> {
5225+ class FoldContiguousGather final : public OpRewritePattern<GatherOp> {
52265226public:
52275227 using OpRewritePattern::OpRewritePattern;
52285228 LogicalResult matchAndRewrite (GatherOp op,
52295229 PatternRewriter &rewriter) const override {
5230- if (failed (isContiguousIndices (op.getIndexVec ())))
5230+ if (failed (isZeroBasedContiguousSeq (op.getIndexVec ())))
52315231 return failure ();
52325232
52335233 rewriter.replaceOpWithNewOp <MaskedLoadOp>(op, op.getType (), op.getBase (),
@@ -5240,7 +5240,7 @@ class GatherTrivialIndices final : public OpRewritePattern<GatherOp> {
52405240
52415241void GatherOp::getCanonicalizationPatterns (RewritePatternSet &results,
52425242 MLIRContext *context) {
5243- results.add <GatherFolder, GatherTrivialIndices >(context);
5243+ results.add <GatherFolder, FoldContiguousGather >(context);
52445244}
52455245
52465246// ===----------------------------------------------------------------------===//
@@ -5284,13 +5284,13 @@ class ScatterFolder final : public OpRewritePattern<ScatterOp> {
52845284};
52855285
52865286// / Fold scatters with consecutive offsets [0, 1, 2, ...] into contiguous
5287- // / maskedstore. Only 1D non-scalable vectors are supported for now.
5288- class ScatterTrivialIndices final : public OpRewritePattern<ScatterOp> {
5287+ // / maskedstore. Only 1D fixed vectors are supported for now.
5288+ class FoldContiguousScatter final : public OpRewritePattern<ScatterOp> {
52895289public:
52905290 using OpRewritePattern::OpRewritePattern;
52915291 LogicalResult matchAndRewrite (ScatterOp op,
52925292 PatternRewriter &rewriter) const override {
5293- if (failed (isContiguousIndices (op.getIndexVec ())))
5293+ if (failed (isZeroBasedContiguousSeq (op.getIndexVec ())))
52945294 return failure ();
52955295
52965296 rewriter.replaceOpWithNewOp <MaskedStoreOp>(
@@ -5302,7 +5302,7 @@ class ScatterTrivialIndices final : public OpRewritePattern<ScatterOp> {
53025302
53035303void ScatterOp::getCanonicalizationPatterns (RewritePatternSet &results,
53045304 MLIRContext *context) {
5305- results.add <ScatterFolder, ScatterTrivialIndices >(context);
5305+ results.add <ScatterFolder, FoldContiguousScatter >(context);
53065306}
53075307
53085308// ===----------------------------------------------------------------------===//
0 commit comments