Skip to content

Commit 1e50115

Browse files
committed
[mlir][vector] Prevent folding non memref-type gather into maskedload
This patch fixes an issue in the FoldContiguousGather pattern which was incorrectly folding vector.gather operations with contiguous indices into vector.maskedload operations regardless of the base operand type. While vector.gather operations can work on both tensor and memref types, vector.maskedload operations are only valid for memref types. The pattern was incorrectly lowering a tensor-based gather into a masked-load, which is invalid. This fix adds a type check to ensure the pattern only applies to memref-based gather operations.
1 parent 04c3898 commit 1e50115

File tree

2 files changed

+15
-0
lines changed

2 files changed

+15
-0
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5340,6 +5340,9 @@ class FoldContiguousGather final : public OpRewritePattern<GatherOp> {
53405340
using OpRewritePattern::OpRewritePattern;
53415341
LogicalResult matchAndRewrite(GatherOp op,
53425342
PatternRewriter &rewriter) const override {
5343+
if (!op.getBase().getType().isa<MemRefType>())
5344+
return failure();
5345+
53435346
if (failed(isZeroBasedContiguousSeq(op.getIndexVec())))
53445347
return failure();
53455348

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3149,6 +3149,18 @@ func.func @contiguous_gather_step(%base: memref<?xf32>,
31493149

31503150
// -----
31513151

3152+
// CHECK-LABEL: @dont_fold_tensor_type_contiguous_gather
3153+
func.func @dont_fold_tensor_type_contiguous_gather(%base: tensor<8xf32>, %mask: vector<4xi1>, %pass_thru: vector<4xf32>) -> vector<4xf32> {
3154+
%c0 = arith.constant 0 : index
3155+
%indices = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
3156+
// CHECK: vector.gather
3157+
// CHECK-NOT: vector.maskedload
3158+
%0 = vector.gather %base[%c0][%indices], %mask, %pass_thru : tensor<8xf32>, vector<4xindex>, vector<4xi1>, vector<4xf32> into vector<4xf32>
3159+
return %0 : vector<4xf32>
3160+
}
3161+
3162+
// -----
3163+
31523164
// CHECK-LABEL: @gather_broadcast(
31533165
// TODO: Broadcast is not supported yet
31543166
// CHECK: %[[R:.*]] = vector.gather

0 commit comments

Comments
 (0)