-
Notifications
You must be signed in to change notification settings - Fork 14.9k
[MLIR] [Vector] Added canonicalizer for folding from_elements + transpose #161841
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
c94bbb7
6bef6d2
70d3d8f
617267b
2889f3d
08a0802
375b711
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6723,6 +6723,61 @@ class FoldTransposeShapeCast final : public OpRewritePattern<TransposeOp> { | |
} | ||
}; | ||
|
||
/// Folds transpose(from_elements(...)) into a new from_elements with permuted | ||
/// operands matching the transposed shape. | ||
class FoldTransposeFromElements final : public OpRewritePattern<TransposeOp> { | ||
public: | ||
using Base::Base; | ||
LogicalResult matchAndRewrite(vector::TransposeOp transposeOp, | ||
PatternRewriter &rewriter) const override { | ||
auto fromElementsOp = | ||
transposeOp.getVector().getDefiningOp<vector::FromElementsOp>(); | ||
if (!fromElementsOp) | ||
return failure(); | ||
|
||
VectorType srcTy = fromElementsOp.getDest().getType(); | ||
VectorType dstTy = transposeOp.getType(); | ||
|
||
ArrayRef<int64_t> permutation = transposeOp.getPermutation(); | ||
int64_t rank = srcTy.getRank(); | ||
|
||
// Build inverse permutation to map destination indices back to source. | ||
SmallVector<int64_t, 4> inversePerm(rank, 0); | ||
keshavvinayak01 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
for (int64_t i = 0; i < rank; ++i) | ||
inversePerm[permutation[i]] = i; | ||
|
||
ArrayRef<int64_t> srcShape = srcTy.getShape(); | ||
ArrayRef<int64_t> dstShape = dstTy.getShape(); | ||
SmallVector<int64_t, 4> srcIdx(rank, 0); | ||
SmallVector<int64_t, 4> dstIdx(rank, 0); | ||
SmallVector<int64_t, 4> srcStrides = computeStrides(srcShape); | ||
SmallVector<int64_t, 4> dstStrides = computeStrides(dstShape); | ||
|
||
auto elements = fromElementsOp.getElements(); | ||
SmallVector<Value> newElements; | ||
keshavvinayak01 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
int64_t dstNumElements = dstTy.getNumElements(); | ||
newElements.reserve(dstNumElements); | ||
|
||
// For each element in destination row-major order, pick the corresponding | ||
// source element. | ||
for (int64_t lin = 0; lin < dstNumElements; ++lin) { | ||
|
||
// Pick the destination element index. | ||
dstIdx = delinearize(lin, dstStrides); | ||
// Map the destination element index to the source element index. | ||
for (int64_t j = 0; j < rank; ++j) | ||
srcIdx[j] = dstIdx[inversePerm[j]]; | ||
// Linearize the source element index. | ||
int64_t srcLin = linearize(srcIdx, srcStrides); | ||
// Add the source element to the new elements. | ||
newElements.push_back(elements[srcLin]); | ||
} | ||
|
||
rewriter.replaceOpWithNewOp<FromElementsOp>(transposeOp, dstTy, | ||
newElements); | ||
return success(); | ||
} | ||
}; | ||
|
||
/// Folds transpose(broadcast(x)) to broadcast(x) if the transpose is | ||
/// 'order preserving', where 'order preserving' means the flattened | ||
/// inputs and outputs of the transpose have identical (numerical) values. | ||
|
@@ -6823,7 +6878,8 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> { | |
void vector::TransposeOp::getCanonicalizationPatterns( | ||
RewritePatternSet &results, MLIRContext *context) { | ||
results.add<FoldTransposeCreateMask, FoldTransposeShapeCast, TransposeFolder, | ||
FoldTransposeSplat, FoldTransposeBroadcast>(context); | ||
FoldTransposeSplat, FoldTransposeFromElements, | ||
FoldTransposeBroadcast>(context); | ||
} | ||
|
||
//===----------------------------------------------------------------------===// | ||
|
keshavvinayak01 marked this conversation as resolved.
Show resolved
Hide resolved
|
Uh oh!
There was an error while loading. Please reload this page.