@@ -326,6 +326,10 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
326326 VectorType inputType = op.getSourceVectorType ();
327327 VectorType resType = op.getResultVectorType ();
328328
329+ if (inputType.isScalable ())
330+ return rewriter.notifyMatchFailure (
331+ op, " This lowering does not support scalable vectors" );
332+
329333 // Set up convenience transposition table.
330334 ArrayRef<int64_t > transp = op.getPermutation ();
331335
@@ -334,28 +338,6 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
334338 return rewriter.notifyMatchFailure (
335339 op, " Options specifies lowering to shuffle" );
336340
337- // Replace:
338- // vector.transpose %0, [1, 0] : vector<nx1x<eltty>> to
339- // vector<1xnxelty>
340- // with:
341- // vector.shape_cast %0 : vector<nx1x<eltty>> to vector<1xnxelty>
342- //
343- // Source with leading unit dim (inverse) is also replaced. Unit dim must
344- // be fixed. Non-unit can be scalable.
345- if (resType.getRank () == 2 &&
346- ((resType.getShape ().front () == 1 &&
347- !resType.getScalableDims ().front ()) ||
348- (resType.getShape ().back () == 1 &&
349- !resType.getScalableDims ().back ())) &&
350- transp == ArrayRef<int64_t >({1 , 0 })) {
351- rewriter.replaceOpWithNewOp <vector::ShapeCastOp>(op, resType, input);
352- return success ();
353- }
354-
355- // TODO: Add support for scalable vectors
356- if (inputType.isScalable ())
357- return failure ();
358-
359341 // Handle a true 2-D matrix transpose differently when requested.
360342 if (vectorTransformOptions.vectorTransposeLowering ==
361343 vector::VectorTransposeLowering::Flat &&
@@ -411,6 +393,64 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
411393 vector::VectorTransformsOptions vectorTransformOptions;
412394};
413395
396+ // / Rewrites vector.transpose as vector.shape_cast. This pattern is only applied
397+ // / to 2D vectors with at least one unit dim. For example:
398+ // /
399+ // / Replace:
400+ // / vector.transpose %0, [1, 0] : vector<4x1xi32>> to
401+ // / vector<1x4xi32>
402+ // / with:
403+ // / vector.shape_cast %0 : vector<4x1xi32> to vector<1x4xi32>
404+ // /
405+ // / Source with leading unit dim (inverse) is also replaced. Unit dim must
406+ // / be fixed. Non-unit dim can be scalable.
407+ // /
408+ // / TODO: This pattern was introduced specifically to help lower scalable
409+ // / vectors. In hindsight, a more specialised canonicalization (for shape_cast's
410+ // / to cancel out) would be preferable:
411+ // /
412+ // / BEFORE:
413+ // / %0 = some_op
414+ // / %1 = vector.shape_cast %0 : vector<[4]xf32> to vector<[4]x1xf32>
415+ // / %2 = vector.transpose %1 [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
416+ // / AFTER:
417+ // / %0 = some_op
418+ // / %1 = vector.shape_cast %0 : vector<[4]xf32> to vector<1x[4]xf32>
419+ // /
420+ // / Given the context above, we may want to consider (re-)moving this pattern
421+ // / at some later time. I am leaving it for now in case there are other users
422+ // / that I am not aware of.
423+ class Transpose2DWithUnitDimToShapeCast
424+ : public OpRewritePattern<vector::TransposeOp> {
425+ public:
426+ using OpRewritePattern::OpRewritePattern;
427+
428+ Transpose2DWithUnitDimToShapeCast (MLIRContext *context,
429+ PatternBenefit benefit = 1 )
430+ : OpRewritePattern<vector::TransposeOp>(context, benefit) {}
431+
432+ LogicalResult matchAndRewrite (vector::TransposeOp op,
433+ PatternRewriter &rewriter) const override {
434+ Value input = op.getVector ();
435+ VectorType resType = op.getResultVectorType ();
436+
437+ // Set up convenience transposition table.
438+ ArrayRef<int64_t > transp = op.getPermutation ();
439+
440+ if (resType.getRank () == 2 &&
441+ ((resType.getShape ().front () == 1 &&
442+ !resType.getScalableDims ().front ()) ||
443+ (resType.getShape ().back () == 1 &&
444+ !resType.getScalableDims ().back ())) &&
445+ transp == ArrayRef<int64_t >({1 , 0 })) {
446+ rewriter.replaceOpWithNewOp <vector::ShapeCastOp>(op, resType, input);
447+ return success ();
448+ }
449+
450+ return failure ();
451+ }
452+ };
453+
414454// / Rewrite a 2-D vector.transpose as a sequence of shuffle ops.
415455// / If the strategy is Shuffle1D, it will be lowered to:
416456// / vector.shape_cast 2D -> 1D
@@ -483,6 +523,8 @@ class TransposeOp2DToShuffleLowering
483523void mlir::vector::populateVectorTransposeLoweringPatterns (
484524 RewritePatternSet &patterns, VectorTransformsOptions options,
485525 PatternBenefit benefit) {
526+ patterns.add <Transpose2DWithUnitDimToShapeCast>(patterns.getContext (),
527+ benefit);
486528 patterns.add <TransposeOpLowering, TransposeOp2DToShuffleLowering>(
487529 options, patterns.getContext (), benefit);
488530}
0 commit comments