[mlir][vector] Canonicalize gathers/scatters with trivial offsets #117939
The patch adds a helper that recognizes trivial (zero-based, contiguous) index vectors:

```diff
@@ -5184,6 +5184,23 @@ std::optional<SmallVector<int64_t, 4>> GatherOp::getShapeForUnroll() {
   return llvm::to_vector<4>(getVectorType().getShape());
 }
 
+/// Check if `indexVec` is a constant 1D vector of consecutive values
+/// [0, 1, 2, ...].
+static LogicalResult isZeroBasedContiguousSeq(Value indexVec) {
+  auto vecType = dyn_cast<VectorType>(indexVec.getType());
+  if (!vecType || vecType.getRank() != 1 || vecType.isScalable())
+    return failure();
+
+  if (indexVec.getDefiningOp<StepOp>())
+    return success();
+
+  DenseIntElementsAttr elements;
+  if (!matchPattern(indexVec, m_Constant(&elements)))
+    return failure();
+
+  return success(
+      llvm::equal(elements, llvm::seq<int64_t>(0, vecType.getNumElements())));
+}
+
 namespace {
 class GatherFolder final : public OpRewritePattern<GatherOp> {
 public:
```
**Contributor:** @banach-space, is there a common utility that we can use here and for the extract op in the Linalg vectorizer?

**Contributor:** Not yet - the vectorizer looks at the scalar indices before vectorization. However, this patch makes me think that we could do better 🤔 Let me look into this!

**Contributor:** @Hardcode84 Would you have some Linalg examples that vectorize into these contiguous gathers? That would be helpful, but no worries if not.
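For reference, here is a minimal sketch (not part of the patch; the function name is illustrative) of index vectors this helper accepts and rejects, following the logic above:

```mlir
func.func @index_vec_examples() {
  // Accepted: vector.step produces [0, 1, ..., N-1] by definition.
  %a = vector.step : vector<8xindex>
  // Accepted: a dense constant enumerating consecutive values from zero.
  %b = arith.constant dense<[0, 1, 2, 3]> : vector<4xi32>
  // Rejected: contiguous, but does not start at zero.
  %c = arith.constant dense<[1, 2, 3, 4]> : vector<4xi32>
  // Rejected: not rank 1 (scalable vectors are likewise rejected).
  %d = arith.constant dense<0> : vector<4x4xi32>
  return
}
```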
The helper feeds a new canonicalization pattern that rewrites qualifying gathers into contiguous masked loads:

```diff
@@ -5202,11 +5219,28 @@ class GatherFolder final : public OpRewritePattern<GatherOp> {
     llvm_unreachable("Unexpected 1DMaskFormat on GatherFolder");
   }
 };
 
+/// Fold gathers with consecutive offsets [0, 1, 2, ...] into contiguous
+/// maskedload. Only 1D fixed vectors are supported for now.
+class FoldContiguousGather final : public OpRewritePattern<GatherOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(GatherOp op,
+                                PatternRewriter &rewriter) const override {
+    if (failed(isZeroBasedContiguousSeq(op.getIndexVec())))
+      return failure();
+
+    rewriter.replaceOpWithNewOp<MaskedLoadOp>(op, op.getType(), op.getBase(),
+                                              op.getIndices(), op.getMask(),
+                                              op.getPassThru());
+    return success();
+  }
+};
 } // namespace
 
 void GatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
-  results.add<GatherFolder>(context);
+  results.add<GatherFolder, FoldContiguousGather>(context);
 }
 
 //===----------------------------------------------------------------------===//
```
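To see the end-to-end effect, here is a sketch of what the canonicalization does to a trivial-offset gather (the function name is illustrative; this mirrors the FileCheck tests further down and can be reproduced with `mlir-opt --canonicalize`):

```mlir
func.func @fold_example(%base: memref<?xf32>, %mask: vector<16xi1>,
                        %passthru: vector<16xf32>) -> vector<16xf32> {
  %c0 = arith.constant 0 : index
  %indices = vector.step : vector<16xindex>
  // Offsets are the trivial sequence [0, 1, ..., 15], so canonicalization
  // rewrites this gather into the contiguous masked load shown below.
  %r = vector.gather %base[%c0][%indices], %mask, %passthru :
    memref<?xf32>, vector<16xindex>, vector<16xi1>, vector<16xf32> into vector<16xf32>
  // After --canonicalize:
  //   %r = vector.maskedload %base[%c0], %mask, %passthru :
  //     memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
  return %r : vector<16xf32>
}
```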
Scatters get the symmetric treatment, rewriting qualifying ops into contiguous masked stores:

```diff
@@ -5248,11 +5282,27 @@ class ScatterFolder final : public OpRewritePattern<ScatterOp> {
     llvm_unreachable("Unexpected 1DMaskFormat on ScatterFolder");
   }
 };
 
+/// Fold scatters with consecutive offsets [0, 1, 2, ...] into contiguous
+/// maskedstore. Only 1D fixed vectors are supported for now.
+class FoldContiguousScatter final : public OpRewritePattern<ScatterOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(ScatterOp op,
+                                PatternRewriter &rewriter) const override {
+    if (failed(isZeroBasedContiguousSeq(op.getIndexVec())))
+      return failure();
+
+    rewriter.replaceOpWithNewOp<MaskedStoreOp>(
+        op, op.getBase(), op.getIndices(), op.getMask(), op.getValueToStore());
+    return success();
+  }
+};
 } // namespace
 
 void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
-  results.add<ScatterFolder>(context);
+  results.add<ScatterFolder, FoldContiguousScatter>(context);
 }
 
 //===----------------------------------------------------------------------===//
```
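Worth noting, and visible in the `*_const_mask` tests below: once a gather or scatter is rewritten to `vector.maskedload`/`vector.maskedstore`, existing folds for all-true masks simplify the result further to plain `vector.load`/`vector.store`. A sketch of that chain (illustrative function name, assuming the all-true-mask fold fires):

```mlir
func.func @const_mask_chain(%base: memref<?xf32>,
                            %passthru: vector<16xf32>) -> vector<16xf32> {
  %c0 = arith.constant 0 : index
  %mask = arith.constant dense<true> : vector<16xi1>
  // Produced by FoldContiguousGather from a trivial-offset gather.
  %r = vector.maskedload %base[%c0], %mask, %passthru :
    memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
  // The all-true mask then lets this fold to:
  //   %r = vector.load %base[%c0] : memref<?xf32>, vector<16xf32>
  return %r : vector<16xf32>
}
```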
The second file in the diff adds FileCheck tests for the new patterns:

@@ -2838,3 +2838,144 @@ func.func @contiguous_extract_strided_slices_to_extract_failure_non_full_inner_s

```mlir
  %1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0, 0], sizes = [1, 1, 2, 1, 1, 1], strides = [1, 1, 1, 1, 1, 1]} : vector<8x1x2x1x1x4xi32> to vector<1x1x2x1x1x1xi32>
  return %1 : vector<1x1x2x1x1x1xi32>
}

// -----

// CHECK-LABEL: @contiguous_gather
// CHECK-SAME: (%[[BASE:.*]]: memref<?xf32>, %[[MASK:.*]]: vector<16xi1>, %[[PASSTHRU:.*]]: vector<16xf32>)
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[R:.*]] = vector.maskedload %[[BASE]][%[[C0]]], %[[MASK]], %[[PASSTHRU]] : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
// CHECK: return %[[R]]
func.func @contiguous_gather(%base: memref<?xf32>,
                             %mask: vector<16xi1>, %passthru: vector<16xf32>) -> vector<16xf32> {
  %c0 = arith.constant 0 : index
  %indices = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xi32>
  %1 = vector.gather %base[%c0][%indices], %mask, %passthru :
    memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
  return %1 : vector<16xf32>
}
```

**Contributor:** Could you add tests for:

**Contributor:** Could you add a test and a TODO? The test would be "negative" (i.e. the folder would leave the code unchanged). For scatter, we'd only need to make sure that invalid.mlir contains a relevant test. Could you check that?

Could you add a negative test to exercise this case? And a TODO to extend the pattern :)

**Contributor (Author):** Invalid scatter indices (and invalid dynamic indices in general) should not fail validation (see https://mlir.llvm.org/getting_started/DeveloperGuide/#ir-verifier), so there is nothing to add to invalid.mlir.
```mlir
// -----

// CHECK-LABEL: @contiguous_gather_non_zero_start(
// TODO: Non-zero start is not supported yet.
// CHECK: %[[R:.*]] = vector.gather
// CHECK: return %[[R]]
func.func @contiguous_gather_non_zero_start(%base: memref<?xf32>,
                                            %mask: vector<16xi1>,
                                            %passthru: vector<16xf32>) -> vector<16xf32> {
  %c0 = arith.constant 0 : index
  %indices = arith.constant dense<[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]> : vector<16xi32>
  %1 = vector.gather %base[%c0][%indices], %mask, %passthru :
    memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
  return %1 : vector<16xf32>
}

// -----

// CHECK-LABEL: @contiguous_gather_2d(
// TODO: Only 1D vectors are supported.
// CHECK: %[[R:.*]] = vector.gather
// CHECK: return %[[R]]
func.func @contiguous_gather_2d(%base: memref<?x?xf32>,
                                %mask: vector<4x4xi1>, %passthru: vector<4x4xf32>) -> vector<4x4xf32> {
  %c0 = arith.constant 0 : index
  %indices = arith.constant dense<[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]> : vector<4x4xi32>
  %1 = vector.gather %base[%c0, %c0][%indices], %mask, %passthru :
    memref<?x?xf32>, vector<4x4xi32>, vector<4x4xi1>, vector<4x4xf32> into vector<4x4xf32>
  return %1 : vector<4x4xf32>
}

// -----

// CHECK-LABEL: @contiguous_gather_const_mask
// CHECK-SAME: (%[[BASE:.*]]: memref<?xf32>, %[[PASSTHRU:.*]]: vector<16xf32>)
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[R:.*]] = vector.load %[[BASE]][%[[C0]]] : memref<?xf32>, vector<16xf32>
// CHECK: return %[[R]]
func.func @contiguous_gather_const_mask(%base: memref<?xf32>,
                                        %passthru: vector<16xf32>) -> vector<16xf32> {
  %c0 = arith.constant 0 : index
  %indices = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xi32>
  %mask = arith.constant dense<true> : vector<16xi1>
  %1 = vector.gather %base[%c0][%indices], %mask, %passthru :
    memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
  return %1 : vector<16xf32>
}

// -----

// CHECK-LABEL: @contiguous_gather_step
// CHECK-SAME: (%[[BASE:.*]]: memref<?xf32>, %[[MASK:.*]]: vector<16xi1>, %[[PASSTHRU:.*]]: vector<16xf32>)
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[R:.*]] = vector.maskedload %[[BASE]][%[[C0]]], %[[MASK]], %[[PASSTHRU]] : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
// CHECK: return %[[R]]
func.func @contiguous_gather_step(%base: memref<?xf32>,
                                  %mask: vector<16xi1>, %passthru: vector<16xf32>) -> vector<16xf32> {
  %c0 = arith.constant 0 : index
  %indices = vector.step : vector<16xindex>
  %1 = vector.gather %base[%c0][%indices], %mask, %passthru :
    memref<?xf32>, vector<16xindex>, vector<16xi1>, vector<16xf32> into vector<16xf32>
  return %1 : vector<16xf32>
}

// -----

// CHECK-LABEL: @gather_broadcast(
// TODO: Broadcast is not supported yet.
// CHECK: %[[R:.*]] = vector.gather
// CHECK: return %[[R]]
func.func @gather_broadcast(%base: memref<?xf32>,
                            %mask: vector<16xi1>, %passthru: vector<16xf32>) -> vector<16xf32> {
  %c0 = arith.constant 0 : index
  %indices = arith.constant dense<0> : vector<16xi32>
  %1 = vector.gather %base[%c0][%indices], %mask, %passthru :
    memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
  return %1 : vector<16xf32>
}

// -----

// CHECK-LABEL: @contiguous_scatter
// CHECK-SAME: (%[[BASE:.*]]: memref<?xf32>, %[[MASK:.*]]: vector<16xi1>, %[[VALUE:.*]]: vector<16xf32>)
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: vector.maskedstore %[[BASE]][%[[C0]]], %[[MASK]], %[[VALUE]] : memref<?xf32>, vector<16xi1>, vector<16xf32>
func.func @contiguous_scatter(%base: memref<?xf32>,
                              %mask: vector<16xi1>, %value: vector<16xf32>) {
  %c0 = arith.constant 0 : index
  %indices = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xi32>
  vector.scatter %base[%c0][%indices], %mask, %value :
    memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
  return
}

// -----

// CHECK-LABEL: @contiguous_scatter_const_mask
// CHECK-SAME: (%[[BASE:.*]]: memref<?xf32>, %[[VALUE:.*]]: vector<16xf32>)
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: vector.store %[[VALUE]], %[[BASE]][%[[C0]]] : memref<?xf32>, vector<16xf32>
func.func @contiguous_scatter_const_mask(%base: memref<?xf32>,
                                         %value: vector<16xf32>) {
  %c0 = arith.constant 0 : index
  %indices = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xi32>
  %mask = vector.constant_mask [16] : vector<16xi1>
  vector.scatter %base[%c0][%indices], %mask, %value :
    memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
  return
}

// -----

// CHECK-LABEL: @contiguous_scatter_step
// CHECK-SAME: (%[[BASE:.*]]: memref<?xf32>, %[[MASK:.*]]: vector<16xi1>, %[[VALUE:.*]]: vector<16xf32>)
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: vector.maskedstore %[[BASE]][%[[C0]]], %[[MASK]], %[[VALUE]] : memref<?xf32>, vector<16xi1>, vector<16xf32>
func.func @contiguous_scatter_step(%base: memref<?xf32>,
                                   %mask: vector<16xi1>, %value: vector<16xf32>) {
  %c0 = arith.constant 0 : index
  %indices = vector.step : vector<16xindex>
  vector.scatter %base[%c0][%indices], %mask, %value :
    memref<?xf32>, vector<16xindex>, vector<16xi1>, vector<16xf32>
  return
}
```

**Contributor:** What about contiguous indices with a different start number?
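The negative test `@contiguous_gather_non_zero_start` above covers this: a non-zero start is deliberately left unfolded for now, with a TODO to extend the pattern. If it were extended, the constant start offset would presumably be absorbed into the scalar base index. A purely hypothetical sketch of what such a rewrite could produce (not part of this patch; the function name is illustrative):

```mlir
// A gather over indices [3, 4, ..., 18] could, in principle, become a
// contiguous masked load whose scalar start index absorbs the offset of 3.
func.func @non_zero_start_sketch(%base: memref<?xf32>, %mask: vector<16xi1>,
                                 %passthru: vector<16xf32>) -> vector<16xf32> {
  %c3 = arith.constant 3 : index
  %r = vector.maskedload %base[%c3], %mask, %passthru :
    memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
  return %r : vector<16xf32>
}
```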