54 changes: 52 additions & 2 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5184,6 +5184,23 @@ std::optional<SmallVector<int64_t, 4>> GatherOp::getShapeForUnroll() {
return llvm::to_vector<4>(getVectorType().getShape());
}

/// Check whether `indexVec` is a 1D vector of consecutive values starting at
/// 0, i.e. [0, 1, 2, ...] (either a `vector.step` or a matching constant).
static LogicalResult isZeroBasedContiguousSeq(Value indexVec) {
auto vecType = dyn_cast<VectorType>(indexVec.getType());
if (!vecType || vecType.getRank() != 1 || vecType.isScalable())
return failure();

if (indexVec.getDefiningOp<StepOp>())
return success();

DenseIntElementsAttr elements;
if (!matchPattern(indexVec, m_Constant(&elements)))
return failure();

return success(
llvm::equal(elements, llvm::seq<int64_t>(0, vecType.getNumElements())));
Contributor:

What about contiguous indices with a different start number?
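To make the question concrete, here is a hypothetical sketch — not part of this patch; the function name, the start offset of 3, and the 4-element shapes are invented — of a contiguous gather with a non-zero start and the maskedload it could conceivably fold to:

```mlir
// Hypothetical example: contiguous indices starting at 3 (not folded by the
// pattern in this PR, which only matches a zero-based sequence).
func.func @gather_start_at_3(%base: memref<?xf32>, %mask: vector<4xi1>,
                             %passthru: vector<4xf32>) -> vector<4xf32> {
  %c0 = arith.constant 0 : index
  %indices = arith.constant dense<[3, 4, 5, 6]> : vector<4xi32>
  %0 = vector.gather %base[%c0][%indices], %mask, %passthru :
    memref<?xf32>, vector<4xi32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
  // A possible extension could fold the constant start into the memref index
  // and emit instead:
  //   %c3 = arith.constant 3 : index
  //   %0 = vector.maskedload %base[%c3], %mask, %passthru :
  //     memref<?xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
  return %0 : vector<4xf32>
}
```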

}
Contributor:

@banach-space, is there a common utility that we can use here and for the extract op in the Linalg vectorizer?

Contributor:

Not yet - the vectorizer looks at the scalar indices before vectorization. However, this patch makes me think that we could do better 🤔 Let me look into this!

Contributor:

@Hardcode84 Would you have some linalg examples that vectorize into these contiguous gathers? That would be helpful, but no worries if Vector is your actual starting point here.


namespace {
class GatherFolder final : public OpRewritePattern<GatherOp> {
public:
@@ -5202,11 +5219,28 @@ class GatherFolder final : public OpRewritePattern<GatherOp> {
llvm_unreachable("Unexpected 1DMaskFormat on GatherFolder");
}
};

/// Fold gathers with consecutive offsets [0, 1, 2, ...] into contiguous
/// maskedload. Only 1D fixed vectors are supported for now.
class FoldContiguousGather final : public OpRewritePattern<GatherOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(GatherOp op,
PatternRewriter &rewriter) const override {
if (failed(isZeroBasedContiguousSeq(op.getIndexVec())))
return failure();

rewriter.replaceOpWithNewOp<MaskedLoadOp>(op, op.getType(), op.getBase(),
op.getIndices(), op.getMask(),
op.getPassThru());
return success();
}
};
} // namespace

void GatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<GatherFolder>(context);
results.add<GatherFolder, FoldContiguousGather>(context);
}

//===----------------------------------------------------------------------===//
@@ -5248,11 +5282,27 @@ class ScatterFolder final : public OpRewritePattern<ScatterOp> {
llvm_unreachable("Unexpected 1DMaskFormat on ScatterFolder");
}
};

/// Fold scatters with consecutive offsets [0, 1, 2, ...] into contiguous
/// maskedstore. Only 1D fixed vectors are supported for now.
class FoldContiguousScatter final : public OpRewritePattern<ScatterOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(ScatterOp op,
PatternRewriter &rewriter) const override {
if (failed(isZeroBasedContiguousSeq(op.getIndexVec())))
return failure();

rewriter.replaceOpWithNewOp<MaskedStoreOp>(
op, op.getBase(), op.getIndices(), op.getMask(), op.getValueToStore());
return success();
}
};
} // namespace

void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<ScatterFolder>(context);
results.add<ScatterFolder, FoldContiguousScatter>(context);
}

//===----------------------------------------------------------------------===//
141 changes: 141 additions & 0 deletions mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2838,3 +2838,144 @@ func.func @contiguous_extract_strided_slices_to_extract_failure_non_full_inner_s
%1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0, 0], sizes = [1, 1, 2, 1, 1, 1], strides = [1, 1, 1, 1, 1, 1]} : vector<8x1x2x1x1x4xi32> to vector<1x1x2x1x1x1xi32>
return %1 : vector<1x1x2x1x1x1xi32>
}

// -----

// CHECK-LABEL: @contiguous_gather
// CHECK-SAME: (%[[BASE:.*]]: memref<?xf32>, %[[MASK:.*]]: vector<16xi1>, %[[PASSTHRU:.*]]: vector<16xf32>)
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[R:.*]] = vector.maskedload %[[BASE]][%[[C0]]], %[[MASK]], %[[PASSTHRU]] : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
// CHECK: return %[[R]]
func.func @contiguous_gather(%base: memref<?xf32>,
%mask: vector<16xi1>, %passthru: vector<16xf32>) -> vector<16xf32> {
%c0 = arith.constant 0 : index
%indices = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xi32>
Contributor:

Could you add tests for:

  • Start index != 0
  • ConstantMaskOp
  • constant indices that describe a broadcast (e.g., [3, 3, 3, 3, 3, 3... 3])

Contributor Author:

  • We don't need any special handling for a constant mask, as it is already handled by the existing masked -> non-masked canonicalizations; added a couple of tests.
  • I can add support for non-zero start, but broadcast is more involved:
    • For scatters, duplicated indices are undefined per the current spec.
    • For gathers, we need reduce(mask) + 1-element vector.maskedload + extract + splat (see the sketch below), and I would rather not do this as part of this PR.
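As a purely illustrative sketch of the sequence described in the last bullet — none of this is proposed in the PR; the function name and the use of arith.select for the pass-through blend are assumptions:

```mlir
// Hypothetical sketch only: folding a broadcast gather (all indices equal)
// via reduce(mask) + 1-element maskedload + extract + splat + select.
func.func @broadcast_gather_sketch(%base: memref<?xf32>, %i: index,
                                   %mask: vector<16xi1>,
                                   %passthru: vector<16xf32>) -> vector<16xf32> {
  // Only load if at least one lane is active.
  %any = vector.reduction <or>, %mask : vector<16xi1> into i1
  %m1 = vector.broadcast %any : i1 to vector<1xi1>
  %pt1 = arith.constant dense<0.0> : vector<1xf32>
  %ld = vector.maskedload %base[%i], %m1, %pt1 :
    memref<?xf32>, vector<1xi1>, vector<1xf32> into vector<1xf32>
  %elem = vector.extract %ld[0] : f32 from vector<1xf32>
  %bcast = vector.splat %elem : vector<16xf32>
  // Masked-off lanes keep the pass-through value.
  %res = arith.select %mask, %bcast, %passthru : vector<16xi1>, vector<16xf32>
  return %res : vector<16xf32>
}
```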

Contributor:

> broadcast is more involved

Could you add a test and a TODO? The test would be "negative" (i.e. the folder would leave the code unchanged). For "scatter" we'd only need to make sure that invalid.mlir contains a relevant test. Could you check that?

> I can add support for non-zero start

Could you add a negative test to exercise this case? And a TODO to extend the pattern :)

Contributor Author:

Invalid scatter indices (and invalid dynamic indices in general) should not fail validation (see https://mlir.llvm.org/getting_started/DeveloperGuide/#ir-verifier), so nothing to add to invalid.mlir

%1 = vector.gather %base[%c0][%indices], %mask, %passthru :
memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
return %1 : vector<16xf32>
}

// -----

// CHECK-LABEL: @contiguous_gather_non_zero_start(
// TODO: Non-zero start is not supported yet.
// CHECK: %[[R:.*]] = vector.gather
// CHECK: return %[[R]]
func.func @contiguous_gather_non_zero_start(%base: memref<?xf32>,
%mask: vector<16xi1>,
%passthru: vector<16xf32>) -> vector<16xf32> {
%c0 = arith.constant 0 : index
%indices = arith.constant dense<[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]> : vector<16xi32>
%1 = vector.gather %base[%c0][%indices], %mask, %passthru :
memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
return %1 : vector<16xf32>
}

// -----

// CHECK-LABEL: @contiguous_gather_2d(
// TODO: Only 1D vectors are supported.
// CHECK: %[[R:.*]] = vector.gather
// CHECK: return %[[R]]
func.func @contiguous_gather_2d(%base: memref<?x?xf32>,
%mask: vector<4x4xi1>, %passthru: vector<4x4xf32>) -> vector<4x4xf32> {
%c0 = arith.constant 0 : index
%indices = arith.constant dense<[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]> : vector<4x4xi32>
%1 = vector.gather %base[%c0, %c0][%indices], %mask, %passthru :
memref<?x?xf32>, vector<4x4xi32>, vector<4x4xi1>, vector<4x4xf32> into vector<4x4xf32>
return %1 : vector<4x4xf32>
}

// -----

// CHECK-LABEL: @contiguous_gather_const_mask
// CHECK-SAME: (%[[BASE:.*]]: memref<?xf32>, %[[PASSTHRU:.*]]: vector<16xf32>)
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[R:.*]] = vector.load %[[BASE]][%[[C0]]] : memref<?xf32>, vector<16xf32>
// CHECK: return %[[R]]
func.func @contiguous_gather_const_mask(%base: memref<?xf32>,
%passthru: vector<16xf32>) -> vector<16xf32> {
%c0 = arith.constant 0 : index
%indices = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xi32>
%mask = arith.constant dense<true> : vector<16xi1>
%1 = vector.gather %base[%c0][%indices], %mask, %passthru :
memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
return %1 : vector<16xf32>
}

// -----

// CHECK-LABEL: @contiguous_gather_step
// CHECK-SAME: (%[[BASE:.*]]: memref<?xf32>, %[[MASK:.*]]: vector<16xi1>, %[[PASSTHRU:.*]]: vector<16xf32>)
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[R:.*]] = vector.maskedload %[[BASE]][%[[C0]]], %[[MASK]], %[[PASSTHRU]] : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
// CHECK: return %[[R]]
func.func @contiguous_gather_step(%base: memref<?xf32>,
%mask: vector<16xi1>, %passthru: vector<16xf32>) -> vector<16xf32> {
%c0 = arith.constant 0 : index
%indices = vector.step : vector<16xindex>
%1 = vector.gather %base[%c0][%indices], %mask, %passthru :
memref<?xf32>, vector<16xindex>, vector<16xi1>, vector<16xf32> into vector<16xf32>
return %1 : vector<16xf32>
}

// -----

// CHECK-LABEL: @gather_broadcast(
// TODO: Broadcast is not supported yet
// CHECK: %[[R:.*]] = vector.gather
// CHECK: return %[[R]]
func.func @gather_broadcast(%base: memref<?xf32>,
%mask: vector<16xi1>, %passthru: vector<16xf32>) -> vector<16xf32> {
%c0 = arith.constant 0 : index
%indices = arith.constant dense<0> : vector<16xi32>
%1 = vector.gather %base[%c0][%indices], %mask, %passthru :
memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
return %1 : vector<16xf32>
}

// -----

// CHECK-LABEL: @contiguous_scatter
// CHECK-SAME: (%[[BASE:.*]]: memref<?xf32>, %[[MASK:.*]]: vector<16xi1>, %[[VALUE:.*]]: vector<16xf32>)
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: vector.maskedstore %[[BASE]][%[[C0]]], %[[MASK]], %[[VALUE]] : memref<?xf32>, vector<16xi1>, vector<16xf32>
func.func @contiguous_scatter(%base: memref<?xf32>,
%mask: vector<16xi1>, %value: vector<16xf32>) {
%c0 = arith.constant 0 : index
%indices = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xi32>
vector.scatter %base[%c0][%indices], %mask, %value :
memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
return
}

// -----

// CHECK-LABEL: @contiguous_scatter_const_mask
// CHECK-SAME: (%[[BASE:.*]]: memref<?xf32>, %[[VALUE:.*]]: vector<16xf32>)
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: vector.store %[[VALUE]], %[[BASE]][%[[C0]]] : memref<?xf32>, vector<16xf32>
func.func @contiguous_scatter_const_mask(%base: memref<?xf32>,
%value: vector<16xf32>) {
%c0 = arith.constant 0 : index
%indices = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xi32>
%mask = vector.constant_mask [16] : vector<16xi1>
vector.scatter %base[%c0][%indices], %mask, %value :
memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
return
}

// -----

// CHECK-LABEL: @contiguous_scatter_step
// CHECK-SAME: (%[[BASE:.*]]: memref<?xf32>, %[[MASK:.*]]: vector<16xi1>, %[[VALUE:.*]]: vector<16xf32>)
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: vector.maskedstore %[[BASE]][%[[C0]]], %[[MASK]], %[[VALUE]] : memref<?xf32>, vector<16xi1>, vector<16xf32>
func.func @contiguous_scatter_step(%base: memref<?xf32>,
%mask: vector<16xi1>, %value: vector<16xf32>) {
%c0 = arith.constant 0 : index
%indices = vector.step : vector<16xindex>
vector.scatter %base[%c0][%indices], %mask, %value :
memref<?xf32>, vector<16xindex>, vector<16xi1>, vector<16xf32>
return
}