[MLIR] [Vector] Added canonicalizer for folding from_elements + transpose #161841
…ctor.transpose) Signed-off-by: Keshav Vinayak Jha <[email protected]>
Signed-off-by: Keshav Vinayak Jha <[email protected]>
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-vector
Author: Keshav Vinayak Jha (keshavvinayak01)
Changes
Description
Adds a new canonicalizer that folds vector.transpose(vector.from_elements(...)) into a single vector.from_elements.
Testing
Added a 2D vector lit test that verifies the rewrite.
Full diff: https://github.com/llvm/llvm-project/pull/161841.diff
2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index b0132e889302f..7f6313c11ea18 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6723,6 +6723,61 @@ class FoldTransposeShapeCast final : public OpRewritePattern<TransposeOp> {
}
};
+/// Folds transpose(from_elements(...)) into a new from_elements with permuted
+/// operands matching the transposed shape.
+class FoldTransposeFromElements final : public OpRewritePattern<TransposeOp> {
+public:
+  using Base::Base;
+  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
+                                PatternRewriter &rewriter) const override {
+    auto fromElementsOp =
+        transposeOp.getVector().getDefiningOp<vector::FromElementsOp>();
+    if (!fromElementsOp)
+      return failure();
+
+    VectorType srcTy = fromElementsOp.getDest().getType();
+    VectorType dstTy = transposeOp.getType();
+
+    ArrayRef<int64_t> permutation = transposeOp.getPermutation();
+    int64_t rank = srcTy.getRank();
+
+    // Build inverse permutation to map destination indices back to source.
+    SmallVector<int64_t, 4> inversePerm(rank, 0);
+    for (int64_t i = 0; i < rank; ++i)
+      inversePerm[permutation[i]] = i;
+
+    ArrayRef<int64_t> srcShape = srcTy.getShape();
+    ArrayRef<int64_t> dstShape = dstTy.getShape();
+    SmallVector<int64_t, 4> srcIdx(rank, 0);
+    SmallVector<int64_t, 4> dstIdx(rank, 0);
+    SmallVector<int64_t, 4> srcStrides = computeStrides(srcShape);
+    SmallVector<int64_t, 4> dstStrides = computeStrides(dstShape);
+
+    auto elements = fromElementsOp.getElements();
+    SmallVector<Value> newElements;
+    int64_t dstNumElements = dstTy.getNumElements();
+    newElements.reserve(dstNumElements);
+
+    // For each element in destination row-major order, pick the corresponding
+    // source element.
+    for (int64_t lin = 0; lin < dstNumElements; ++lin) {
+      // Pick the destination element index.
+      dstIdx = delinearize(lin, dstStrides);
+      // Map the destination element index to the source element index.
+      for (int64_t j = 0; j < rank; ++j)
+        srcIdx[j] = dstIdx[inversePerm[j]];
+      // Linearize the source element index.
+      int64_t srcLin = linearize(srcIdx, srcStrides);
+      // Add the source element to the new elements.
+      newElements.push_back(elements[srcLin]);
+    }
+
+    rewriter.replaceOpWithNewOp<FromElementsOp>(transposeOp, dstTy,
+                                                newElements);
+    return success();
+  }
+};
+
/// Folds transpose(broadcast(x)) to broadcast(x) if the transpose is
/// 'order preserving', where 'order preserving' means the flattened
/// inputs and outputs of the transpose have identical (numerical) values.
@@ -6823,7 +6878,8 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
void vector::TransposeOp::getCanonicalizationPatterns(
RewritePatternSet &results, MLIRContext *context) {
results.add<FoldTransposeCreateMask, FoldTransposeShapeCast, TransposeFolder,
- FoldTransposeSplat, FoldTransposeBroadcast>(context);
+ FoldTransposeSplat, FoldTransposeFromElements,
+ FoldTransposeBroadcast>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 5448976f84760..5f34d144cd472 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -308,6 +308,18 @@ func.func @constant_mask_transpose_to_transposed_constant_mask() -> (vector<2x3x
// -----
+// CHECK-LABEL: transpose_from_elements_2d
+func.func @transpose_from_elements_2d(%a0: i32, %a1: i32, %a2: i32,
+                                      %a3: i32, %a4: i32, %a5: i32) -> vector<3x2xi32> {
+  %v = vector.from_elements %a0, %a1, %a2, %a3, %a4, %a5 : vector<2x3xi32>
+  %t = vector.transpose %v, [1, 0] : vector<2x3xi32> to vector<3x2xi32>
+  return %t : vector<3x2xi32>
+  // CHECK: %[[R:.*]] = vector.from_elements %arg0, %arg3, %arg1, %arg4, %arg2, %arg5 : vector<3x2xi32>
+  // CHECK-NOT: vector.transpose
+}
+
+// -----
+
func.func @extract_strided_slice_of_constant_mask() -> (vector<2x2xi1>) {
%0 = vector.constant_mask [2, 2] : vector<4x3xi1>
%1 = vector.extract_strided_slice %0
Thanks!
// For each element in destination row-major order, pick the corresponding
// source element.
for (int64_t lin = 0; lin < dstNumElements; ++lin) {
What does lin represent?
So "lin" is short for "linear index" - it's a 1D index that represents the position of an element when the multi-dimensional vector is laid out in row-major order in memory. I felt it was a good iter name.
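To make the linear-index idea concrete, here is a minimal standalone C++ sketch. It assumes plain std::vector shapes, and rowMajorStrides, toMultiIndex and toLinearIndex are hypothetical stand-ins for the computeStrides, delinearize and linearize utilities the patch calls, not the MLIR implementations themselves.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for computeStrides: row-major strides of a shape,
// i.e. the innermost dimension has stride 1.
std::vector<int64_t> rowMajorStrides(const std::vector<int64_t> &shape) {
  std::vector<int64_t> strides(shape.size(), 1);
  for (int64_t i = static_cast<int64_t>(shape.size()) - 2; i >= 0; --i)
    strides[i] = strides[i + 1] * shape[i + 1];
  return strides;
}

// Hypothetical stand-in for delinearize: split a linear index into one
// coordinate per dimension, outermost dimension first.
std::vector<int64_t> toMultiIndex(int64_t linearIdx,
                                  const std::vector<int64_t> &strides) {
  std::vector<int64_t> idx(strides.size());
  for (std::size_t i = 0; i < strides.size(); ++i) {
    idx[i] = linearIdx / strides[i];
    linearIdx %= strides[i];
  }
  return idx;
}

// Hypothetical stand-in for linearize: the inverse of toMultiIndex.
int64_t toLinearIndex(const std::vector<int64_t> &idx,
                      const std::vector<int64_t> &strides) {
  assert(idx.size() == strides.size() && "rank mismatch");
  int64_t linearIdx = 0;
  for (std::size_t i = 0; i < idx.size(); ++i)
    linearIdx += idx[i] * strides[i];
  return linearIdx;
}
```

For a 2x3 shape the strides are {3, 1}, so linear index 4 corresponds to position [1, 1], and linearizing [1, 1] gives 4 again.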
I find lin a bit too enigmatic. Why not linearIdx?
1. Minor nitpicks in code formatting.
2. More lit tests, covering 1D, 2D, 3D cases.
Signed-off-by: Keshav Vinayak Jha <[email protected]>
Signed-off-by: Keshav Vinayak Jha <[email protected]>
1. Renamed the loop iterator to linearIdx.
2. Moved canonicalizer lit tests to other vector.from_elements tests.
3. Added block comments signaling beginning, end, and name of the pattern.
Signed-off-by: Keshav Vinayak Jha <[email protected]>
LGTM, thanks! (modulo pending comments)
Signed-off-by: Keshav Vinayak Jha <[email protected]>
Description
Adds a new canonicalizer that folds vector.transpose(vector.from_elements(...)) => vector.from_elements. This canonicalization reorders the input elements of vector.from_elements and adjusts the output shape to match the effect of the transpose op, eliminating the need for it.
Testing
Added a 2D vector lit test that verifies the rewrite.
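The operand reordering described above can be sketched outside MLIR. The following is a minimal standalone C++ illustration under stated assumptions: transposedOperandOrder is a hypothetical helper (not part of the patch or of MLIR), shapes are plain std::vector values, and the permutation follows the vector.transpose convention that destination dimension i comes from source dimension perm[i].

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: for each destination position in row-major order,
// return the row-major position of the source operand that a transpose with
// `perm` places there.
std::vector<int64_t>
transposedOperandOrder(const std::vector<int64_t> &srcShape,
                       const std::vector<int64_t> &perm) {
  assert(srcShape.size() == perm.size() && "rank mismatch");
  const std::size_t rank = srcShape.size();

  // Destination dimension i takes its extent from source dimension perm[i].
  std::vector<int64_t> dstShape(rank);
  for (std::size_t i = 0; i < rank; ++i)
    dstShape[i] = srcShape[perm[i]];

  // Row-major strides of the source shape.
  std::vector<int64_t> srcStrides(rank, 1);
  for (std::size_t i = rank; i-- > 1;)
    srcStrides[i - 1] = srcStrides[i] * srcShape[i];

  int64_t numElements = 1;
  for (int64_t extent : srcShape)
    numElements *= extent;

  std::vector<int64_t> order;
  order.reserve(static_cast<std::size_t>(numElements));
  for (int64_t linearIdx = 0; linearIdx < numElements; ++linearIdx) {
    // Peel destination coordinates off from the innermost dimension outward.
    // Destination coordinate i is the source coordinate along perm[i].
    int64_t remaining = linearIdx;
    int64_t srcLinearIdx = 0;
    for (std::size_t i = rank; i-- > 0;) {
      int64_t coord = remaining % dstShape[i];
      remaining /= dstShape[i];
      srcLinearIdx += coord * srcStrides[perm[i]];
    }
    order.push_back(srcLinearIdx);
  }
  return order;
}
```

For the 2x3 case in the lit test with permutation [1, 0], this yields the order 0, 3, 1, 4, 2, 5, which matches the operand order in the CHECK line. Unlike the patch, the sketch applies the forward permutation directly to the source strides instead of building an inverse permutation; both formulations select the same source element for every destination position.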