Skip to content

Commit f65021a

Browse files
pashu123ita9naiwa
authored andcommitted
[mlir][Vector] Replace vector.transpose with vector.shape_cast
Define the permutation width as one past the last index of the permutation array whose entry is not equal to its index. This pattern applies to transpose operations whose input vector shape has at most one non-unit dimension among the dimensions below the permutation width. The pattern replaces the transpose operation with a shape cast operation. For example, `%0 = vector.transpose %1, [1, 0, 2] : vector<4x1x1xi32> to vector<1x4x1xi32>` is replaced by `%0 = vector.shape_cast %1 : vector<4x1x1xi32> to vector<1x4x1xi32>`, given that the permutation width is 2.
1 parent f9af5c1 commit f65021a

File tree

2 files changed

+63
-2
lines changed

2 files changed

+63
-2
lines changed

mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp

Lines changed: 56 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -450,6 +450,59 @@ class Transpose2DWithUnitDimToShapeCast
450450
}
451451
};
452452

453+
// Suppose the permutation width is defined as the last index in the permutation
454+
// array that is not equal to its index. This pattern is applied to transpose
455+
// operations where the input vector has a shape with at most one non-unit
456+
// dimension up to the permutation width. The pattern replaces the transpose
457+
// operation with a shape cast operation.
458+
// For example:
459+
// %0 = vector.transpose %1, [1, 0, 2] : vector<4x1x1xi32> to vector<1x4x1xi32>
460+
// is replaced by
461+
// %0 = vector.shape_cast %1 : vector<4x1x1xi32> to vector<1x4x1xi32>
462+
// given the permutation width is 2.
463+
class TransposeWithUnitDimToShapeCast
464+
: public OpRewritePattern<vector::TransposeOp> {
465+
public:
466+
using OpRewritePattern::OpRewritePattern;
467+
468+
TransposeWithUnitDimToShapeCast(MLIRContext *context,
469+
PatternBenefit benefit = 1)
470+
: OpRewritePattern<vector::TransposeOp>(context, benefit) {}
471+
472+
LogicalResult matchAndRewrite(vector::TransposeOp op,
473+
PatternRewriter &rewriter) const override {
474+
Value input = op.getVector();
475+
VectorType inputType = op.getSourceVectorType();
476+
if (inputType.isScalable())
477+
return rewriter.notifyMatchFailure(
478+
op, "This lowering does not support scalable vectors");
479+
VectorType resType = op.getResultVectorType();
480+
481+
ArrayRef<int64_t> transp = op.getPermutation();
482+
483+
// Get the permutation width.
484+
int64_t permWidth = 1;
485+
for (auto &&[idx, val] : llvm::enumerate(transp)) {
486+
if (static_cast<int64_t>(idx) != val)
487+
permWidth = idx + 1;
488+
}
489+
490+
// Check the no. of non unit dim in the input shape upto permutation width
491+
// is not greater than one.
492+
auto inputShape = inputType.getShape();
493+
494+
int64_t countNonUnitDims = 0;
495+
for (int i = 0; i < permWidth; i++) {
496+
if (inputShape[i] != 1)
497+
countNonUnitDims++;
498+
if (countNonUnitDims > 1)
499+
return failure();
500+
}
501+
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, input);
502+
return success();
503+
}
504+
};
505+
453506
/// Rewrite a 2-D vector.transpose as a sequence of shuffle ops.
454507
/// If the strategy is Shuffle1D, it will be lowered to:
455508
/// vector.shape_cast 2D -> 1D
@@ -522,8 +575,9 @@ class TransposeOp2DToShuffleLowering
522575
// Adds the vector.transpose lowering patterns to `patterns`, every one with
// the same `benefit`. The lines marked `-` below show the previous
// registration of Transpose2DWithUnitDimToShapeCast alone; the lines marked
// `+` register it together with TransposeWithUnitDimToShapeCast. Both are
// constructed from the context only, whereas TransposeOpLowering and
// TransposeOp2DToShuffleLowering additionally receive `options`.
void mlir::vector::populateVectorTransposeLoweringPatterns(
523576
RewritePatternSet &patterns, VectorTransformsOptions options,
524577
PatternBenefit benefit) {
525-
patterns.add<Transpose2DWithUnitDimToShapeCast>(patterns.getContext(),
526-
benefit);
578+
patterns
579+
.add<Transpose2DWithUnitDimToShapeCast, TransposeWithUnitDimToShapeCast>(
580+
patterns.getContext(), benefit);
527581
patterns.add<TransposeOpLowering, TransposeOp2DToShuffleLowering>(
528582
options, patterns.getContext(), benefit);
529583
}

mlir/test/Dialect/Vector/vector-transpose-lowering.mlir

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -386,6 +386,13 @@ func.func @transpose10_4x1xf32_scalable(%arg0: vector<4x[1]xf32>) -> vector<[1]x
386386
return %0 : vector<[1]x4xf32>
387387
}
388388

389+
// N-D transpose that swaps dims 1 and 2 (sizes 2 and 1) of a
// vector<1x2x1x16xf32>; the lowering is expected to rewrite it as a
// vector.shape_cast to vector<1x1x2x16xf32>.
// CHECK-LABEL: func @transpose_nd
390+
func.func @transpose_nd(%arg0: vector<1x2x1x16xf32>) -> vector<1x1x2x16xf32> {
391+
// CHECK-NEXT: vector.shape_cast %arg0 : vector<1x2x1x16xf32> to vector<1x1x2x16xf32>
392+
%0 = vector.transpose %arg0, [0, 2, 1, 3] : vector<1x2x1x16xf32> to vector<1x1x2x16xf32>
393+
return %0 : vector<1x1x2x16xf32>
394+
}
395+
389396
module attributes {transform.with_named_sequence} {
390397
transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
391398
%func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">

0 commit comments

Comments
 (0)