2626
2727using namespace mlir ;
2828
29+ constexpr unsigned defaultTargetVectorBitWidth =
30+ std::numeric_limits<unsigned >::max();
31+
2932static bool isLessThanTargetBitWidth (Operation *op, unsigned targetBitWidth) {
3033 // For BW-0, all operations are legal
3134 if (targetBitWidth == 0 )
@@ -86,7 +89,7 @@ struct LinearizeConstantLike final
8689
8790 LinearizeConstantLike (
8891 const TypeConverter &typeConverter, MLIRContext *context,
89- unsigned targetVectBitWidth = std::numeric_limits< unsigned >::max() ,
92+ unsigned targetVectBitWidth = defaultTargetVectorBitWidth ,
9093 PatternBenefit benefit = 1 )
9194 : OpTraitConversionPattern(typeConverter, context, benefit),
9295 targetVectorBitWidth (targetVectBitWidth) {}
@@ -140,7 +143,7 @@ struct LinearizeVectorizable final
140143public:
141144 LinearizeVectorizable (
142145 const TypeConverter &typeConverter, MLIRContext *context,
143- unsigned targetVectBitWidth = std::numeric_limits< unsigned >::max() ,
146+ unsigned targetVectBitWidth = defaultTargetVectorBitWidth ,
144147 PatternBenefit benefit = 1 )
145148 : OpTraitConversionPattern(typeConverter, context, benefit),
146149 targetVectorBitWidth (targetVectBitWidth) {}
@@ -179,7 +182,7 @@ struct LinearizeVectorExtractStridedSlice final
179182 using OpConversionPattern::OpConversionPattern;
180183 LinearizeVectorExtractStridedSlice (
181184 const TypeConverter &typeConverter, MLIRContext *context,
182- unsigned targetVectBitWidth = std::numeric_limits< unsigned >::max() ,
185+ unsigned targetVectBitWidth = defaultTargetVectorBitWidth ,
183186 PatternBenefit benefit = 1 )
184187 : OpConversionPattern(typeConverter, context, benefit),
185188 targetVectorBitWidth (targetVectBitWidth) {}
@@ -295,7 +298,7 @@ struct LinearizeVectorInsertStridedSlice final
295298 using OpConversionPattern<vector::InsertStridedSliceOp>::OpConversionPattern;
296299 LinearizeVectorInsertStridedSlice (
297300 const TypeConverter &typeConverter, MLIRContext *context,
298- unsigned targetVectBitWidth = std::numeric_limits< unsigned >::max() ,
301+ unsigned targetVectBitWidth = defaultTargetVectorBitWidth ,
299302 PatternBenefit benefit = 1 )
300303 : OpConversionPattern(typeConverter, context, benefit),
301304 targetVectorBitWidth (targetVectBitWidth) {}
@@ -317,11 +320,6 @@ struct LinearizeVectorInsertStridedSlice final
317320 insertOp,
318321 " InsertStridedSliceOp linearization only supports 2D source." );
319322
320- if (!srcTy.hasStaticShape () || !dstTy.hasStaticShape ())
321- return rewriter.notifyMatchFailure (
322- insertOp,
323- " InsertStridedSliceOp linerization only supports static shapes." );
324-
325323 if (srcTy.isScalable () || dstTy.isScalable ())
326324 return rewriter.notifyMatchFailure (insertOp,
327325 " scalable vectors are not supported." );
@@ -372,7 +370,7 @@ struct LinearizeVectorShuffle final
372370 using OpConversionPattern::OpConversionPattern;
373371 LinearizeVectorShuffle (
374372 const TypeConverter &typeConverter, MLIRContext *context,
375- unsigned targetVectBitWidth = std::numeric_limits< unsigned >::max() ,
373+ unsigned targetVectBitWidth = defaultTargetVectorBitWidth ,
376374 PatternBenefit benefit = 1 )
377375 : OpConversionPattern(typeConverter, context, benefit),
378376 targetVectorBitWidth (targetVectBitWidth) {}
@@ -445,7 +443,7 @@ struct LinearizeVectorExtract final
445443 using OpConversionPattern::OpConversionPattern;
446444 LinearizeVectorExtract (
447445 const TypeConverter &typeConverter, MLIRContext *context,
448- unsigned targetVectBitWidth = std::numeric_limits< unsigned >::max() ,
446+ unsigned targetVectBitWidth = defaultTargetVectorBitWidth ,
449447 PatternBenefit benefit = 1 )
450448 : OpConversionPattern(typeConverter, context, benefit),
451449 targetVectorBitWidth (targetVectBitWidth) {}
@@ -513,7 +511,7 @@ struct LinearizeVectorInsert final
513511 using OpConversionPattern::OpConversionPattern;
514512 LinearizeVectorInsert (
515513 const TypeConverter &typeConverter, MLIRContext *context,
516- unsigned targetVectBitWidth = std::numeric_limits< unsigned >::max() ,
514+ unsigned targetVectBitWidth = defaultTargetVectorBitWidth ,
517515 PatternBenefit benefit = 1 )
518516 : OpConversionPattern(typeConverter, context, benefit),
519517 targetVectorBitWidth (targetVectBitWidth) {}
@@ -593,7 +591,7 @@ struct LinearizeVectorBitCast final
593591 using OpConversionPattern::OpConversionPattern;
594592 LinearizeVectorBitCast (
595593 const TypeConverter &typeConverter, MLIRContext *context,
596- unsigned targetVectBitWidth = std::numeric_limits< unsigned >::max() ,
594+ unsigned targetVectBitWidth = defaultTargetVectorBitWidth ,
597595 PatternBenefit benefit = 1 )
598596 : OpConversionPattern(typeConverter, context, benefit),
599597 targetVectorBitWidth (targetVectBitWidth) {}
@@ -618,27 +616,27 @@ struct LinearizeVectorBitCast final
618616 unsigned targetVectorBitWidth;
619617};
620618
619+ // clang-format off
621620// / This pattern converts the LoadOp to a series of LoadOp & InsertOp
622621// / that works on a linearized vector.
623622// / Following,
624623// / vector.load %base[%indices] : vector<4x4xf32>
625624// / is converted to :
626625// / %result = arith.constant dense<0.0> : vector<4x4xf32>
627626// / %slice_0 = vector.load %base[%indices] : vector<4xf32>
628- // / %result_0 = vector.insert %slice_0, %result[0] : vector<4xf32> into
629- // / vector<4x4xf32> %slice_1 = vector.load %base[%indices + 1] : vector<4xf32>
630- // / %result_1 = vector.insert %slice_1, %result_0[1] : vector<4xf32> into
631- // / vector<4x4xf32>
627+ // / %result_0 = vector.insert %slice_0, %result[0] : vector<4xf32> into vector<4x4xf32>
628+ // / %slice_1 = vector.load %base[%indices + 1] : vector<4xf32>
629+ // / %result_1 = vector.insert %slice_1, %result_0[1] : vector<4xf32> into vector<4x4xf32>
632630// / ...
633631// / This unrolls the 2D vector load into multiple 1D vector loads and inserts
634632// / them into the result vector. The pattern currently supports only 2D vectors
633+ // clang-format on
635634struct LinearizeVectorLoad final : public OpConversionPattern<vector::LoadOp> {
636635 using OpConversionPattern<vector::LoadOp>::OpConversionPattern;
637636
638- LinearizeVectorLoad (
639- const TypeConverter &typeConverter, MLIRContext *context,
640- unsigned targetVectBitWidth = std::numeric_limits<unsigned >::max(),
641- PatternBenefit benefit = 1 )
637+ LinearizeVectorLoad (const TypeConverter &typeConverter, MLIRContext *context,
638+ unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
639+ PatternBenefit benefit = 1 )
642640 : OpConversionPattern(typeConverter, context, benefit),
643641 targetVectorBitWidth (targetVectBitWidth) {}
644642
@@ -702,7 +700,7 @@ struct LinearizeVectorStore final
702700
703701 LinearizeVectorStore (
704702 const TypeConverter &typeConverter, MLIRContext *context,
705- unsigned targetVectBitWidth = std::numeric_limits< unsigned >::max() ,
703+ unsigned targetVectBitWidth = defaultTargetVectorBitWidth ,
706704 PatternBenefit benefit = 1 )
707705 : OpConversionPattern(typeConverter, context, benefit),
708706 targetVectorBitWidth (targetVectBitWidth) {}
@@ -758,7 +756,7 @@ struct LinearizeVectorSplat final
758756
759757 LinearizeVectorSplat (
760758 const TypeConverter &typeConverter, MLIRContext *context,
761- unsigned targetVectBitWidth = std::numeric_limits< unsigned >::max() ,
759+ unsigned targetVectBitWidth = defaultTargetVectorBitWidth ,
762760 PatternBenefit benefit = 1 )
763761 : OpConversionPattern(typeConverter, context, benefit),
764762 targetVectorBitWidth (targetVectBitWidth) {}
@@ -794,7 +792,7 @@ struct LinearizeVectorCreateMask final
794792
795793 LinearizeVectorCreateMask (
796794 const TypeConverter &typeConverter, MLIRContext *context,
797- unsigned targetVectBitWidth = std::numeric_limits< unsigned >::max() ,
795+ unsigned targetVectBitWidth = defaultTargetVectorBitWidth ,
798796 PatternBenefit benefit = 1 )
799797 : OpConversionPattern(typeConverter, context, benefit),
800798 targetVectorBitWidth (targetVectBitWidth) {}
@@ -907,8 +905,7 @@ void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
907905 if (isLessThanTargetBitWidth (op, targetBitWidth)) {
908906 auto srcTy = op.getSourceVectorType ();
909907 auto dstTy = op.getDestVectorType ();
910- if (!op.hasNonUnitStrides () && srcTy.getRank () == 2 &&
911- srcTy.hasStaticShape () && dstTy.hasStaticShape ())
908+ if (!op.hasNonUnitStrides () && srcTy.getRank () == 2 )
912909 return false ;
913910 }
914911 return true ;
0 commit comments