Skip to content

Commit 231371c

Browse files
committed
Address comments
1 parent 03789ec commit 231371c

File tree

1 file changed

+23
-26
lines changed

1 file changed

+23
-26
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp

Lines changed: 23 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@
2626

2727
using namespace mlir;
2828

29+
constexpr unsigned defaultTargetVectorBitWidth =
30+
std::numeric_limits<unsigned>::max();
31+
2932
static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) {
3033
// For BW-0, all operations are legal
3134
if (targetBitWidth == 0)
@@ -86,7 +89,7 @@ struct LinearizeConstantLike final
8689

8790
LinearizeConstantLike(
8891
const TypeConverter &typeConverter, MLIRContext *context,
89-
unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
92+
unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
9093
PatternBenefit benefit = 1)
9194
: OpTraitConversionPattern(typeConverter, context, benefit),
9295
targetVectorBitWidth(targetVectBitWidth) {}
@@ -140,7 +143,7 @@ struct LinearizeVectorizable final
140143
public:
141144
LinearizeVectorizable(
142145
const TypeConverter &typeConverter, MLIRContext *context,
143-
unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
146+
unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
144147
PatternBenefit benefit = 1)
145148
: OpTraitConversionPattern(typeConverter, context, benefit),
146149
targetVectorBitWidth(targetVectBitWidth) {}
@@ -179,7 +182,7 @@ struct LinearizeVectorExtractStridedSlice final
179182
using OpConversionPattern::OpConversionPattern;
180183
LinearizeVectorExtractStridedSlice(
181184
const TypeConverter &typeConverter, MLIRContext *context,
182-
unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
185+
unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
183186
PatternBenefit benefit = 1)
184187
: OpConversionPattern(typeConverter, context, benefit),
185188
targetVectorBitWidth(targetVectBitWidth) {}
@@ -295,7 +298,7 @@ struct LinearizeVectorInsertStridedSlice final
295298
using OpConversionPattern<vector::InsertStridedSliceOp>::OpConversionPattern;
296299
LinearizeVectorInsertStridedSlice(
297300
const TypeConverter &typeConverter, MLIRContext *context,
298-
unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
301+
unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
299302
PatternBenefit benefit = 1)
300303
: OpConversionPattern(typeConverter, context, benefit),
301304
targetVectorBitWidth(targetVectBitWidth) {}
@@ -317,11 +320,6 @@ struct LinearizeVectorInsertStridedSlice final
317320
insertOp,
318321
"InsertStridedSliceOp linearization only supports 2D source.");
319322

320-
if (!srcTy.hasStaticShape() || !dstTy.hasStaticShape())
321-
return rewriter.notifyMatchFailure(
322-
insertOp,
323-
"InsertStridedSliceOp linerization only supports static shapes.");
324-
325323
if (srcTy.isScalable() || dstTy.isScalable())
326324
return rewriter.notifyMatchFailure(insertOp,
327325
"scalable vectors are not supported.");
@@ -372,7 +370,7 @@ struct LinearizeVectorShuffle final
372370
using OpConversionPattern::OpConversionPattern;
373371
LinearizeVectorShuffle(
374372
const TypeConverter &typeConverter, MLIRContext *context,
375-
unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
373+
unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
376374
PatternBenefit benefit = 1)
377375
: OpConversionPattern(typeConverter, context, benefit),
378376
targetVectorBitWidth(targetVectBitWidth) {}
@@ -445,7 +443,7 @@ struct LinearizeVectorExtract final
445443
using OpConversionPattern::OpConversionPattern;
446444
LinearizeVectorExtract(
447445
const TypeConverter &typeConverter, MLIRContext *context,
448-
unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
446+
unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
449447
PatternBenefit benefit = 1)
450448
: OpConversionPattern(typeConverter, context, benefit),
451449
targetVectorBitWidth(targetVectBitWidth) {}
@@ -513,7 +511,7 @@ struct LinearizeVectorInsert final
513511
using OpConversionPattern::OpConversionPattern;
514512
LinearizeVectorInsert(
515513
const TypeConverter &typeConverter, MLIRContext *context,
516-
unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
514+
unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
517515
PatternBenefit benefit = 1)
518516
: OpConversionPattern(typeConverter, context, benefit),
519517
targetVectorBitWidth(targetVectBitWidth) {}
@@ -593,7 +591,7 @@ struct LinearizeVectorBitCast final
593591
using OpConversionPattern::OpConversionPattern;
594592
LinearizeVectorBitCast(
595593
const TypeConverter &typeConverter, MLIRContext *context,
596-
unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
594+
unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
597595
PatternBenefit benefit = 1)
598596
: OpConversionPattern(typeConverter, context, benefit),
599597
targetVectorBitWidth(targetVectBitWidth) {}
@@ -618,27 +616,27 @@ struct LinearizeVectorBitCast final
618616
unsigned targetVectorBitWidth;
619617
};
620618

619+
// clang-format off
621620
/// This pattern converts the LoadOp to a series of LoadOp & InsertOp
622621
/// that works on a linearized vector.
623622
/// Following,
624623
/// vector.load %base[%indices] : vector<4x4xf32>
625624
/// is converted to :
626625
/// %result = arith.constant dense<0.0> : vector<4x4xf32>
627626
/// %slice_0 = vector.load %base[%indices] : vector<4xf32>
628-
/// %result_0 = vector.insert %slice_0, %result[0] : vector<4xf32> into
629-
/// vector<4x4xf32> %slice_1 = vector.load %base[%indices + 1] : vector<4xf32>
630-
/// %result_1 = vector.insert %slice_1, %result_0[1] : vector<4xf32> into
631-
/// vector<4x4xf32>
627+
/// %result_0 = vector.insert %slice_0, %result[0] : vector<4xf32> into vector<4x4xf32>
628+
/// %slice_1 = vector.load %base[%indices + 1] : vector<4xf32>
629+
/// %result_1 = vector.insert %slice_1, %result_0[1] : vector<4xf32> into vector<4x4xf32>
632630
/// ...
633631
/// This unrolls the 2D vector load into multiple 1D vector loads and inserts
634632
/// them into the result vector. The pattern currently supports only 2D vectors
633+
// clang-format on
635634
struct LinearizeVectorLoad final : public OpConversionPattern<vector::LoadOp> {
636635
using OpConversionPattern<vector::LoadOp>::OpConversionPattern;
637636

638-
LinearizeVectorLoad(
639-
const TypeConverter &typeConverter, MLIRContext *context,
640-
unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
641-
PatternBenefit benefit = 1)
637+
LinearizeVectorLoad(const TypeConverter &typeConverter, MLIRContext *context,
638+
unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
639+
PatternBenefit benefit = 1)
642640
: OpConversionPattern(typeConverter, context, benefit),
643641
targetVectorBitWidth(targetVectBitWidth) {}
644642

@@ -702,7 +700,7 @@ struct LinearizeVectorStore final
702700

703701
LinearizeVectorStore(
704702
const TypeConverter &typeConverter, MLIRContext *context,
705-
unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
703+
unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
706704
PatternBenefit benefit = 1)
707705
: OpConversionPattern(typeConverter, context, benefit),
708706
targetVectorBitWidth(targetVectBitWidth) {}
@@ -758,7 +756,7 @@ struct LinearizeVectorSplat final
758756

759757
LinearizeVectorSplat(
760758
const TypeConverter &typeConverter, MLIRContext *context,
761-
unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
759+
unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
762760
PatternBenefit benefit = 1)
763761
: OpConversionPattern(typeConverter, context, benefit),
764762
targetVectorBitWidth(targetVectBitWidth) {}
@@ -794,7 +792,7 @@ struct LinearizeVectorCreateMask final
794792

795793
LinearizeVectorCreateMask(
796794
const TypeConverter &typeConverter, MLIRContext *context,
797-
unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
795+
unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
798796
PatternBenefit benefit = 1)
799797
: OpConversionPattern(typeConverter, context, benefit),
800798
targetVectorBitWidth(targetVectBitWidth) {}
@@ -907,8 +905,7 @@ void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
907905
if (isLessThanTargetBitWidth(op, targetBitWidth)) {
908906
auto srcTy = op.getSourceVectorType();
909907
auto dstTy = op.getDestVectorType();
910-
if (!op.hasNonUnitStrides() && srcTy.getRank() == 2 &&
911-
srcTy.hasStaticShape() && dstTy.hasStaticShape())
908+
if (!op.hasNonUnitStrides() && srcTy.getRank() == 2)
912909
return false;
913910
}
914911
return true;

0 commit comments

Comments
 (0)