@@ -2499,6 +2499,7 @@ static OpFoldResult foldFromElementsToConstant(FromElementsOp fromElementsOp,
24992499 return DenseElementsAttr::get (destVecType, convertedElements);
25002500}
25012501
2502+
25022503OpFoldResult FromElementsOp::fold (FoldAdaptor adaptor) {
25032504 if (auto res = foldFromElementsToElements (*this ))
25042505 return res;
@@ -6723,6 +6724,63 @@ class FoldTransposeShapeCast final : public OpRewritePattern<TransposeOp> {
67236724 }
67246725};
67256726
6727+ // / Folds transpose(from_elements(...)) into a new from_elements with permuted
6728+ // / operands matching the transposed shape.
6729+ class FoldTransposeFromElements final
6730+ : public OpRewritePattern<TransposeOp> {
6731+ public:
6732+
6733+ using Base::Base;
6734+ LogicalResult matchAndRewrite (vector::TransposeOp transposeOp,
6735+ PatternRewriter &rewriter) const override {
6736+ auto fromElementsOp =
6737+ transposeOp.getVector ().getDefiningOp <vector::FromElementsOp>();
6738+ if (!fromElementsOp)
6739+ return failure ();
6740+
6741+ VectorType srcTy = fromElementsOp.getDest ().getType ();
6742+ VectorType dstTy = transposeOp.getType ();
6743+
6744+ ArrayRef<int64_t > permutation = transposeOp.getPermutation ();
6745+ int64_t rank = srcTy.getRank ();
6746+
6747+ // Build inverse permutation to map destination indices back to source.
6748+ SmallVector<int64_t , 4 > inversePerm (rank, 0 );
6749+ for (int64_t i = 0 ; i < rank; ++i)
6750+ inversePerm[permutation[i]] = i;
6751+
6752+ ArrayRef<int64_t > srcShape = srcTy.getShape ();
6753+ ArrayRef<int64_t > dstShape = dstTy.getShape ();
6754+ SmallVector<int64_t , 4 > srcIdx (rank, 0 );
6755+ SmallVector<int64_t , 4 > dstIdx (rank, 0 );
6756+ SmallVector<int64_t , 4 > srcStrides = computeStrides (srcShape);
6757+ SmallVector<int64_t , 4 > dstStrides = computeStrides (dstShape);
6758+
6759+ auto elements = fromElementsOp.getElements ();
6760+ SmallVector<Value> newElements;
6761+ int64_t dstNumElements = dstTy.getNumElements ();
6762+ newElements.reserve (dstNumElements);
6763+
6764+ // For each element in destination row-major order, pick the corresponding
6765+ // source element.
6766+ for (int64_t lin = 0 ; lin < dstNumElements; ++lin) {
6767+ // Pick the destination element index.
6768+ dstIdx = delinearize (lin, dstStrides);
6769+ // Map the destination element index to the source element index.
6770+ for (int64_t j = 0 ; j < rank; ++j)
6771+ srcIdx[j] = dstIdx[inversePerm[j]];
6772+ // Linearize the source element index.
6773+ int64_t srcLin = linearize (srcIdx, srcStrides);
6774+ // Add the source element to the new elements.
6775+ newElements.push_back (elements[srcLin]);
6776+ }
6777+
6778+ rewriter.replaceOpWithNewOp <FromElementsOp>(transposeOp, dstTy,
6779+ newElements);
6780+ return success ();
6781+ }
6782+ };
6783+
/// Folds transpose(broadcast(x)) to broadcast(x) if the transpose is
/// 'order preserving', where 'order preserving' means the flattened
/// inputs and outputs of the transpose have identical (numerical) values.
@@ -6823,7 +6881,8 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
68236881void vector::TransposeOp::getCanonicalizationPatterns (
68246882 RewritePatternSet &results, MLIRContext *context) {
68256883 results.add <FoldTransposeCreateMask, FoldTransposeShapeCast, TransposeFolder,
6826- FoldTransposeSplat, FoldTransposeBroadcast>(context);
6884+ FoldTransposeSplat, FoldTransposeFromElements,
6885+ FoldTransposeBroadcast>(context);
68276886}
68286887
68296888// ===----------------------------------------------------------------------===//
0 commit comments