diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 5324e38fa7d25..fdbdc72c057af 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -5348,6 +5348,9 @@ class FoldContiguousGather final : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(GatherOp op, PatternRewriter &rewriter) const override { + if (!op.getBase().getType().isa()) + return rewriter.notifyMatchFailure(op, "base must be of memref type"); + if (failed(isZeroBasedContiguousSeq(op.getIndexVec()))) return failure(); diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index a6d82b85777b0..78b0ea78849e8 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -3198,6 +3198,19 @@ func.func @contiguous_gather_step(%base: memref, // ----- +// CHECK-LABEL: @no_fold_contiguous_gather_tensor +func.func @no_fold_contiguous_gather_tensor(%base: tensor<8xf32>, %mask: vector<4xi1>, %pass_thru: vector<4xf32>) -> vector<4xf32> { + %c0 = arith.constant 0 : index + %indices = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex> + // CHECK: vector.gather + // CHECK-NOT: vector.maskedload + %0 = vector.gather %base[%c0][%indices], %mask, %pass_thru : + tensor<8xf32>, vector<4xindex>, vector<4xi1>, vector<4xf32> into vector<4xf32> + return %0 : vector<4xf32> +} + +// ----- + // CHECK-LABEL: @gather_broadcast( // TODO: Broadcast is not supported yet // CHECK: %[[R:.*]] = vector.gather