[mlir][vector] Sink vector.extract/splat into load/store ops #134389
@@ -161,6 +161,20 @@ void populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
void populateSinkVectorOpsPatterns(RewritePatternSet &patterns,
                                   PatternBenefit benefit = 1);

/// Patterns that remove redundant Vector Ops by merging them with load/store
/// ops
/// ```
/// vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
/// vector.extract %0[1] : f32 from vector<4xf32>
/// ```
/// Gets converted to:
/// ```
/// %c1 = arith.constant 1 : index
/// %0 = arith.addi %arg1, %c1 overflow<nsw> : index
/// %1 = memref.load %arg0[%0] : memref<?xf32>
/// ```
Contributor
I think this should be a canonicalization pattern iff there's only one use, which is a vector.extract. I can't think of a reason why we would want to load the redundant elements. I would clearly document that this only applies to cases with one use/extract op.
Contributor
No objections from me, but from a purely maintenance point of view, I'd leave the implementation and most of the tests where they are. Otherwise, we risk "bloating" canonicalization.mlir and e.g. VectorOps.cpp.
Contributor (Author)
One potential use case where keeping vector.load + extract may be useful is when we are loading a vector at an aligned address for perf reasons and then using extract with an offset to get at unaligned data. I don't have such examples in practice, though.
Member
Increasing the granularity of memory accesses may prevent you from using wider load/store instructions, and undoing this later on and proving that you can use a wider memory access may be hard. We'd be losing information about how many bits are dereferenceable and potentially misaligning the access. For this reason, I don't think this should be on by default.
Contributor
Also …
Member
Say one of your memory regions is dword-sized but your memory accesses take byte offsets:

    %x = vector.load ... : vector<4xi8>
    %y = vector.extract %x [2] : i8

The original load is efficient because you are accessing a full dword. However, if you turn it into …

For example, the buffer instructions on amdgpu allow you to get a default value for any OOB accesses. Looking at the example above, it could be that only the last byte is OOB, but this alone makes the whole …
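Concretely, with a plain memref<?xi8> base, a single use, and the pattern proposed in this patch, the example above would be rewritten roughly as follows (%base and %i are placeholder names, not taken from the discussion):

```mlir
// Before: one aligned 4-byte (dword) access.
%x = vector.load %base[%i] : memref<?xi8>, vector<4xi8>
%y = vector.extract %x[2] : i8 from vector<4xi8>

// After sinking: a single-byte access at an adjusted index.
%c2 = arith.constant 2 : index
%0 = arith.addi %i, %c2 overflow<nsw> : index
%y = memref.load %base[%0] : memref<?xi8>
```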
Contributor (Author)

I don't think we actually need any special handling or tests for sub-byte types. The only ways we can have …
Contributor
Applying this pattern to a vector of bits would lead to …

Also, in cases like this:

    %x = vector.load ... : vector<8xi1>
    %y = vector.extract %x [5] : i1

the vector load is probably just a scalar load anyway. My suggestion is to restrict this pattern to multi-byte element types (*) and rely on "narrow-type-emulation" to help with sub-bytes.

(*) Multi-byte: at least one byte.
Contributor
Thanks @kuhar, those examples were helpful! I'm still kind of borderline, but let's move forward with this as an independent pattern. The proliferation of dangling "populate" methods is concerning, but this case may be worth it.

For that example, I would expect the alignment information to be explicit somewhere as …

Yes, but we can't attribute hardware-specific semantics to … A valid lowering of …
Contributor
I'd be surprised if there is no issue with the data layout, as the vector one assumes a packed layout and the scalar one would be unpacked. Looking at the generated LLVM IR for both cases would help.
void populateSinkVectorMemOpsPatterns(RewritePatternSet &patterns,
                                      PatternBenefit benefit = 1);

/// Patterns that fold chained vector reductions. These patterns assume that
/// elementwise operations (e.g., `arith.addf` with vector operands) are
/// cheaper than vector reduction.
@@ -902,6 +902,8 @@ struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp>
};

/// Reorders elementwise(broadcast/splat) to broadcast(elementwise). Ex:
///
/// Example:
/// ```
/// %a = vector.broadcast %arg1 : index to vector<1x4xindex>
/// %b = vector.broadcast %arg2 : index to vector<1x4xindex>
@@ -987,6 +989,8 @@ struct ReorderElementwiseOpsOnBroadcast final
/// This may result in cleaner code when extracting a single value
/// from multi-element vector and also to help canonicalize 1-element vectors to
/// scalars.
///
/// Example:
/// ```
/// %0 = arith.addf %arg0, %arg1 : vector<4xf32>
/// %1 = vector.extract %0[1] : f32 from vector<4xf32>
@@ -1043,6 +1047,144 @@ class ExtractOpFromElementwise final
  }
};

static bool isSupportedMemSinkElementType(Type type) {
  if (isa<IndexType>(type))
    return true;

  // Non-byte-aligned types are tricky, skip them.
  return type.isIntOrFloat() && type.getIntOrFloatBitWidth() % 8 == 0;
}

/// Pattern to rewrite vector.extract(vector.load) -> vector/memref.load.
///
/// Example:
/// ```
/// vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
/// vector.extract %0[1] : f32 from vector<4xf32>
/// ```
/// Gets converted to:
/// ```
/// %c1 = arith.constant 1 : index
/// %0 = arith.addi %arg1, %c1 overflow<nsw> : index
/// %1 = memref.load %arg0[%0] : memref<?xf32>
/// ```
class ExtractOpFromLoad final : public OpRewritePattern<vector::ExtractOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ExtractOp op,
                                PatternRewriter &rewriter) const override {
    auto loadOp = op.getVector().getDefiningOp<vector::LoadOp>();
    if (!loadOp)
      return rewriter.notifyMatchFailure(op, "expected a load op");

    // Checking for single use so we won't duplicate load ops.
    if (!loadOp->hasOneUse())
Contributor

If moving this to canonicalization, I would add a comment here stating that this condition is the one that makes this a canonicalization pattern and shouldn't be changed.
      return rewriter.notifyMatchFailure(op, "expected single op use");

    VectorType loadVecType = loadOp.getVectorType();
    if (loadVecType.isScalable())
      return rewriter.notifyMatchFailure(op,
                                         "scalable vectors are not supported");

    MemRefType memType = loadOp.getMemRefType();
    if (!isSupportedMemSinkElementType(memType.getElementType()))
      return rewriter.notifyMatchFailure(op, "unsupported memref element type");

    int64_t rankOffset = memType.getRank() - loadVecType.getRank();
    if (rankOffset < 0)
      return rewriter.notifyMatchFailure(op, "unsupported ranks combination");

    auto extractVecType = dyn_cast<VectorType>(op.getResult().getType());
    int64_t finalRank = 0;
    if (extractVecType)
      finalRank = extractVecType.getRank();

    SmallVector<Value> indices = loadOp.getIndices();
    SmallVector<OpFoldResult> extractPos = op.getMixedPosition();

    // There may be memory stores between the load and the extract op, so we
    // need to make sure that the new load op is inserted at the same place as
    // the original load op.
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(loadOp);

    Location loc = loadOp.getLoc();
    ArithIndexingBuilder idxBuilderf(rewriter, loc);
    for (auto i : llvm::seq<int64_t>(rankOffset, indices.size() - finalRank)) {
      OpFoldResult pos = extractPos[i - rankOffset];
      if (isConstantIntValue(pos, 0))
        continue;

      Value offset = getValueOrCreateConstantIndexOp(rewriter, loc, pos);
      indices[i] = idxBuilderf.add(indices[i], offset);
    }

    Value base = loadOp.getBase();
    if (extractVecType) {
      rewriter.replaceOpWithNewOp<vector::LoadOp>(op, extractVecType, base,
                                                  indices);
    } else {
      rewriter.replaceOpWithNewOp<memref::LoadOp>(op, base, indices);
    }
    // We checked for single use so we can safely erase the load op.
    rewriter.eraseOp(loadOp);

    return success();
  }
};
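For rank-mismatched cases the pattern folds the extract position into the leading load indices, and when the extracted value is itself a vector it emits a narrower vector.load rather than a memref.load. A rough sketch of that behavior, assuming a single use of the load; %mem, %i, and %j are placeholder names, not a test from this patch:

```mlir
// Before: load a 2x4 tile and extract its second row.
%0 = vector.load %mem[%i, %j] : memref<8x16xf32>, vector<2x4xf32>
%1 = vector.extract %0[1] : vector<4xf32> from vector<2x4xf32>

// After: load only the row that is actually used.
%c1 = arith.constant 1 : index
%row = arith.addi %i, %c1 overflow<nsw> : index
%1 = vector.load %mem[%row, %j] : memref<8x16xf32>, vector<4xf32>
```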
/// Pattern to rewrite vector.store(vector.splat) -> vector/memref.store.
///
/// Example:
/// ```
/// %0 = vector.splat %arg2 : vector<1xf32>
/// vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<1xf32>
/// ```
/// Gets converted to:
/// ```
/// memref.store %arg2, %arg0[%arg1] : memref<?xf32>
/// ```
class StoreOpFromSplatOrBroadcast final
    : public OpRewritePattern<vector::StoreOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::StoreOp op,
                                PatternRewriter &rewriter) const override {
    VectorType vecType = op.getVectorType();
    if (vecType.isScalable())
      return rewriter.notifyMatchFailure(op,
                                         "scalable vectors are not supported");

    if (isa<VectorType>(op.getMemRefType().getElementType()))
      return rewriter.notifyMatchFailure(
          op, "memrefs of vectors are not supported");

    if (vecType.getNumElements() != 1)
      return rewriter.notifyMatchFailure(
          op, "only 1-element vectors are supported");

    Operation *splat = op.getValueToStore().getDefiningOp();
    if (!isa_and_present<vector::BroadcastOp, vector::SplatOp>(splat))
      return rewriter.notifyMatchFailure(op, "neither a splat nor a broadcast");

    // Checking for single use so we can remove splat.
    if (!splat->hasOneUse())
      return rewriter.notifyMatchFailure(op, "expected single op use");

    Value source = splat->getOperand(0);
    Value base = op.getBase();
    ValueRange indices = op.getIndices();

    if (isa<VectorType>(source.getType())) {
      rewriter.replaceOpWithNewOp<vector::StoreOp>(op, source, base, indices);
    } else {
      rewriter.replaceOpWithNewOp<memref::StoreOp>(op, source, base, indices);
    }
    rewriter.eraseOp(splat);

    return success();
  }
};
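When the stored value comes from a vector.broadcast whose source is itself a 1-element vector, the pattern keeps a vector.store of that source instead of switching to memref.store. A hypothetical illustration of that branch (names invented, not from the patch):

```mlir
// Before: broadcast a 1-element vector up to 1x1 and store it.
%0 = vector.broadcast %src : vector<1xf32> to vector<1x1xf32>
vector.store %0, %mem[%i, %j] : memref<?x?xf32>, vector<1x1xf32>

// After: store the broadcast source directly.
vector.store %src, %mem[%i, %j] : memref<?x?xf32>, vector<1xf32>
```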
// Helper that returns a vector comparison that constructs a mask:
//     mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
//
@@ -2109,6 +2251,12 @@ void mlir::vector::populateSinkVectorOpsPatterns(RewritePatternSet &patterns,
      patterns.getContext(), benefit);
}

void mlir::vector::populateSinkVectorMemOpsPatterns(RewritePatternSet &patterns,
                                                    PatternBenefit benefit) {
  patterns.add<ExtractOpFromLoad, StoreOpFromSplatOrBroadcast>(
      patterns.getContext(), benefit);
}

void mlir::vector::populateChainedVectorReductionFoldingPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<ChainedReduction>(patterns.getContext(), benefit);