Skip to content

Commit c94bbb7

Browse files
Added canonicalization (vector.from_elements + vector.transpose -> vector.transpose)
Signed-off-by: Keshav Vinayak Jha <[email protected]>
1 parent ff4aec5 commit c94bbb7

File tree

2 files changed

+72
-1
lines changed

2 files changed

+72
-1
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2499,6 +2499,7 @@ static OpFoldResult foldFromElementsToConstant(FromElementsOp fromElementsOp,
24992499
return DenseElementsAttr::get(destVecType, convertedElements);
25002500
}
25012501

2502+
25022503
OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
25032504
if (auto res = foldFromElementsToElements(*this))
25042505
return res;
@@ -6723,6 +6724,63 @@ class FoldTransposeShapeCast final : public OpRewritePattern<TransposeOp> {
67236724
}
67246725
};
67256726

6727+
/// Folds transpose(from_elements(...)) into a new from_elements with permuted
6728+
/// operands matching the transposed shape.
6729+
class FoldTransposeFromElements final
6730+
: public OpRewritePattern<TransposeOp> {
6731+
public:
6732+
6733+
using Base::Base;
6734+
LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
6735+
PatternRewriter &rewriter) const override {
6736+
auto fromElementsOp =
6737+
transposeOp.getVector().getDefiningOp<vector::FromElementsOp>();
6738+
if (!fromElementsOp)
6739+
return failure();
6740+
6741+
VectorType srcTy = fromElementsOp.getDest().getType();
6742+
VectorType dstTy = transposeOp.getType();
6743+
6744+
ArrayRef<int64_t> permutation = transposeOp.getPermutation();
6745+
int64_t rank = srcTy.getRank();
6746+
6747+
// Build inverse permutation to map destination indices back to source.
6748+
SmallVector<int64_t, 4> inversePerm(rank, 0);
6749+
for (int64_t i = 0; i < rank; ++i)
6750+
inversePerm[permutation[i]] = i;
6751+
6752+
ArrayRef<int64_t> srcShape = srcTy.getShape();
6753+
ArrayRef<int64_t> dstShape = dstTy.getShape();
6754+
SmallVector<int64_t, 4> srcIdx(rank, 0);
6755+
SmallVector<int64_t, 4> dstIdx(rank, 0);
6756+
SmallVector<int64_t, 4> srcStrides = computeStrides(srcShape);
6757+
SmallVector<int64_t, 4> dstStrides = computeStrides(dstShape);
6758+
6759+
auto elements = fromElementsOp.getElements();
6760+
SmallVector<Value> newElements;
6761+
int64_t dstNumElements = dstTy.getNumElements();
6762+
newElements.reserve(dstNumElements);
6763+
6764+
// For each element in destination row-major order, pick the corresponding
6765+
// source element.
6766+
for (int64_t lin = 0; lin < dstNumElements; ++lin) {
6767+
// Pick the destination element index.
6768+
dstIdx = delinearize(lin, dstStrides);
6769+
// Map the destination element index to the source element index.
6770+
for (int64_t j = 0; j < rank; ++j)
6771+
srcIdx[j] = dstIdx[inversePerm[j]];
6772+
// Linearize the source element index.
6773+
int64_t srcLin = linearize(srcIdx, srcStrides);
6774+
// Add the source element to the new elements.
6775+
newElements.push_back(elements[srcLin]);
6776+
}
6777+
6778+
rewriter.replaceOpWithNewOp<FromElementsOp>(transposeOp, dstTy,
6779+
newElements);
6780+
return success();
6781+
}
6782+
};
6783+
67266784
/// Folds transpose(broadcast(x)) to broadcast(x) if the transpose is
67276785
/// 'order preserving', where 'order preserving' means the flattened
67286786
/// inputs and outputs of the transpose have identical (numerical) values.
@@ -6823,7 +6881,8 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
68236881
void vector::TransposeOp::getCanonicalizationPatterns(
68246882
RewritePatternSet &results, MLIRContext *context) {
68256883
results.add<FoldTransposeCreateMask, FoldTransposeShapeCast, TransposeFolder,
6826-
FoldTransposeSplat, FoldTransposeBroadcast>(context);
6884+
FoldTransposeSplat, FoldTransposeFromElements,
6885+
FoldTransposeBroadcast>(context);
68276886
}
68286887

68296888
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,18 @@ func.func @constant_mask_transpose_to_transposed_constant_mask() -> (vector<2x3x
308308

309309
// -----
310310

311+
// CHECK-LABEL: transpose_from_elements_2d
312+
func.func @transpose_from_elements_2d(%a0: i32, %a1: i32, %a2: i32,
313+
%a3: i32, %a4: i32, %a5: i32) -> vector<3x2xi32> {
314+
%v = vector.from_elements %a0, %a1, %a2, %a3, %a4, %a5 : vector<2x3xi32>
315+
%t = vector.transpose %v, [1, 0] : vector<2x3xi32> to vector<3x2xi32>
316+
return %t : vector<3x2xi32>
317+
// CHECK: %[[R:.*]] = vector.from_elements %arg0, %arg3, %arg1, %arg4, %arg2, %arg5 : vector<3x2xi32>
318+
// CHECK-NOT: vector.transpose
319+
}
320+
321+
// -----
322+
311323
func.func @extract_strided_slice_of_constant_mask() -> (vector<2x2xi1>) {
312324
%0 = vector.constant_mask [2, 2] : vector<4x3xi1>
313325
%1 = vector.extract_strided_slice %0

0 commit comments

Comments
 (0)