Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 56 additions & 2 deletions mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,59 @@ class Transpose2DWithUnitDimToShapeCast
}
};

/// Rewrites a vector.transpose that only moves unit dimensions around as a
/// vector.shape_cast.
///
/// Define the permutation width as one past the last position `i` for which
/// `permutation[i] != i` (0 for the identity permutation); every dimension at
/// or beyond the permutation width is left in place by the transpose. The
/// pattern applies when the source shape has at most one non-unit dimension
/// within the permutation width, i.e. the transpose only shuffles unit
/// dimensions around a single non-unit one.
///
/// For example, with a permutation width of 2:
///   %0 = vector.transpose %1, [1, 0, 2] : vector<4x1x1xi32> to vector<1x4x1xi32>
/// is replaced by
///   %0 = vector.shape_cast %1 : vector<4x1x1xi32> to vector<1x4x1xi32>
///
/// Scalable source vectors are rejected with a match failure.
class TransposeWithUnitDimToShapeCast
    : public OpRewritePattern<vector::TransposeOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  TransposeWithUnitDimToShapeCast(MLIRContext *context,
                                  PatternBenefit benefit = 1)
      : OpRewritePattern<vector::TransposeOp>(context, benefit) {}

  /// Replaces `op` with a vector.shape_cast from the source to the result
  /// vector type when the conditions described on the class hold; otherwise
  /// reports a match failure and leaves `op` untouched.
  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    Value input = op.getVector();
    VectorType inputType = op.getSourceVectorType();
    if (inputType.isScalable())
      return rewriter.notifyMatchFailure(
          op, "This lowering does not support scalable vectors");
    VectorType resType = op.getResultVectorType();

    ArrayRef<int64_t> transp = op.getPermutation();

    // Compute the permutation width: one past the last position that is
    // actually permuted. It starts at 0 so that the shape lookup below never
    // indexes past the end of the shape, even for an empty permutation.
    int64_t permWidth = 0;
    for (auto &&[idx, val] : llvm::enumerate(transp)) {
      if (static_cast<int64_t>(idx) != val)
        permWidth = static_cast<int64_t>(idx) + 1;
    }

    // Bail out as soon as a second non-unit dimension shows up within the
    // permutation width.
    ArrayRef<int64_t> inputShape = inputType.getShape();
    int64_t countNonUnitDims = 0;
    for (int64_t i = 0; i < permWidth; ++i) {
      if (inputShape[i] != 1 && ++countNonUnitDims > 1)
        return rewriter.notifyMatchFailure(
            op, "more than one non-unit dim within the permutation width");
    }

    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, input);
    return success();
  }
};

/// Rewrite a 2-D vector.transpose as a sequence of shuffle ops.
/// If the strategy is Shuffle1D, it will be lowered to:
/// vector.shape_cast 2D -> 1D
Expand Down Expand Up @@ -522,8 +575,9 @@ class TransposeOp2DToShuffleLowering
/// Adds the vector.transpose lowering patterns to `patterns`, all with the
/// given `benefit`:
///  - the two unit-dim patterns that turn a transpose into a shape_cast
///    (Transpose2DWithUnitDimToShapeCast, TransposeWithUnitDimToShapeCast);
///  - TransposeOpLowering and TransposeOp2DToShuffleLowering, which also
///    receive `options`.
/// Each pattern is registered exactly once.
void mlir::vector::populateVectorTransposeLoweringPatterns(
    RewritePatternSet &patterns, VectorTransformsOptions options,
    PatternBenefit benefit) {
  patterns
      .add<Transpose2DWithUnitDimToShapeCast, TransposeWithUnitDimToShapeCast>(
          patterns.getContext(), benefit);
  patterns.add<TransposeOpLowering, TransposeOp2DToShuffleLowering>(
      options, patterns.getContext(), benefit);
}
52 changes: 17 additions & 35 deletions mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -23,53 +23,21 @@ func.func @transpose23(%arg0: vector<2x3xf32>) -> vector<3x2xf32> {

// CHECK-LABEL: func @transpose102_1x8x8xf32
func.func @transpose102_1x8x8xf32(%arg0: vector<1x8x8xf32>) -> vector<8x1x8xf32> {
  // The permutation swaps a unit dim with a single non-unit dim, so the
  // expected lowering is one shape_cast rather than extract/insert pairs.
  // CHECK: %0 = vector.shape_cast %arg0 : vector<1x8x8xf32> to vector<8x1x8xf32>
  %0 = vector.transpose %arg0, [1, 0, 2] : vector<1x8x8xf32> to vector<8x1x8xf32>
  return %0 : vector<8x1x8xf32>
}

// CHECK-LABEL: func @transpose102_8x1x8xf32
func.func @transpose102_8x1x8xf32(%arg0: vector<8x1x8xf32>) -> vector<1x8x8xf32> {
  // The permutation swaps a single non-unit dim with a unit dim, so the
  // expected lowering is one shape_cast rather than extract/insert pairs.
  // CHECK: %0 = vector.shape_cast %arg0 : vector<8x1x8xf32> to vector<1x8x8xf32>
  %0 = vector.transpose %arg0, [1, 0, 2] : vector<8x1x8xf32> to vector<1x8x8xf32>
  return %0 : vector<1x8x8xf32>
}

// CHECK-LABEL: func @transpose1023_1x1x8x8xf32(
func.func @transpose1023_1x1x8x8xf32(%arg0: vector<1x1x8x8xf32>) -> vector<1x1x8x8xf32> {
  // Only the two leading unit dims are swapped and the result type equals the
  // source type; the expected output returns the argument directly.
  // CHECK: return %arg0 : vector<1x1x8x8xf32>
  %0 = vector.transpose %arg0, [1, 0, 2, 3] : vector<1x1x8x8xf32> to vector<1x1x8x8xf32>
  return %0 : vector<1x1x8x8xf32>
}
Expand Down Expand Up @@ -386,6 +354,20 @@ func.func @transpose10_4x1xf32_scalable(%arg0: vector<4x[1]xf32>) -> vector<[1]x
return %0 : vector<[1]x4xf32>
}

// 4-D case: the permutation [0, 2, 1, 3] swaps dims 1 and 2 of a 1x2x1x16
// vector (sizes 2 and 1); dims 0 and 3 stay in place. The expected output is a
// single shape_cast directly after the function header.
// CHECK-LABEL: func @transpose_nd1
func.func @transpose_nd1(%arg0: vector<1x2x1x16xf32>) -> vector<1x1x2x16xf32> {
// CHECK-NEXT: vector.shape_cast %arg0 : vector<1x2x1x16xf32> to vector<1x1x2x16xf32>
%0 = vector.transpose %arg0, [0, 2, 1, 3] : vector<1x2x1x16xf32> to vector<1x1x2x16xf32>
return %0 : vector<1x1x2x16xf32>
}

// 4-D case, mirror of the 1x2x1x16 test: the permutation [0, 2, 1, 3] swaps
// dims 1 and 2 of a 1x1x2x16 vector (sizes 1 and 2); dims 0 and 3 stay in
// place. The expected output is a single shape_cast directly after the
// function header.
// CHECK-LABEL: func @transpose_nd2
func.func @transpose_nd2(%arg0: vector<1x1x2x16xf32>) -> vector<1x2x1x16xf32> {
// CHECK-NEXT: vector.shape_cast %arg0 : vector<1x1x2x16xf32> to vector<1x2x1x16xf32>
%0 = vector.transpose %arg0, [0, 2, 1, 3] : vector<1x1x2x16xf32> to vector<1x2x1x16xf32>
return %0 : vector<1x2x1x16xf32>
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
%func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
Expand Down