Skip to content

Commit 97a6c57

Browse files
committed
Add linearization pattern for vector.splat
1 parent 8832a59 commit 97a6c57

File tree

2 files changed

+70
-10
lines changed

2 files changed

+70
-10
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp

Lines changed: 53 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@
2626

2727
using namespace mlir;
2828

29+
constexpr unsigned defaultTargetVectorBitWidth =
30+
std::numeric_limits<unsigned>::max();
31+
2932
static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) {
3033
auto resultTypes = op->getResultTypes();
3134
for (auto resType : resultTypes) {
@@ -82,7 +85,7 @@ struct LinearizeConstantLike final
8285

8386
LinearizeConstantLike(
8487
const TypeConverter &typeConverter, MLIRContext *context,
85-
unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
88+
unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
8689
PatternBenefit benefit = 1)
8790
: OpTraitConversionPattern(typeConverter, context, benefit),
8891
targetVectorBitWidth(targetVectBitWidth) {}
@@ -136,7 +139,7 @@ struct LinearizeVectorizable final
136139
public:
137140
LinearizeVectorizable(
138141
const TypeConverter &typeConverter, MLIRContext *context,
139-
unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
142+
unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
140143
PatternBenefit benefit = 1)
141144
: OpTraitConversionPattern(typeConverter, context, benefit),
142145
targetVectorBitWidth(targetVectBitWidth) {}
@@ -175,7 +178,7 @@ struct LinearizeVectorExtractStridedSlice final
175178
using OpConversionPattern::OpConversionPattern;
176179
LinearizeVectorExtractStridedSlice(
177180
const TypeConverter &typeConverter, MLIRContext *context,
178-
unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
181+
unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
179182
PatternBenefit benefit = 1)
180183
: OpConversionPattern(typeConverter, context, benefit),
181184
targetVectorBitWidth(targetVectBitWidth) {}
@@ -289,7 +292,7 @@ struct LinearizeVectorShuffle final
289292
using OpConversionPattern::OpConversionPattern;
290293
LinearizeVectorShuffle(
291294
const TypeConverter &typeConverter, MLIRContext *context,
292-
unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
295+
unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
293296
PatternBenefit benefit = 1)
294297
: OpConversionPattern(typeConverter, context, benefit),
295298
targetVectorBitWidth(targetVectBitWidth) {}
@@ -362,13 +365,17 @@ struct LinearizeVectorExtract final
362365
using OpConversionPattern::OpConversionPattern;
363366
LinearizeVectorExtract(
364367
const TypeConverter &typeConverter, MLIRContext *context,
365-
unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
368+
unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
366369
PatternBenefit benefit = 1)
367370
: OpConversionPattern(typeConverter, context, benefit),
368371
targetVectorBitWidth(targetVectBitWidth) {}
369372
LogicalResult
370373
matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
371374
ConversionPatternRewriter &rewriter) const override {
375+
// Skip if result is not a vector type
376+
if (!isa<VectorType>(extractOp.getType()))
377+
return rewriter.notifyMatchFailure(extractOp,
378+
"scalar extract is not supported.");
372379
Type dstTy = getTypeConverter()->convertType(extractOp.getType());
373380
if (!dstTy)
374381
return rewriter.notifyMatchFailure(extractOp,
@@ -425,7 +432,7 @@ struct LinearizeVectorInsert final
425432
using OpConversionPattern::OpConversionPattern;
426433
LinearizeVectorInsert(
427434
const TypeConverter &typeConverter, MLIRContext *context,
428-
unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
435+
unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
429436
PatternBenefit benefit = 1)
430437
: OpConversionPattern(typeConverter, context, benefit),
431438
targetVectorBitWidth(targetVectBitWidth) {}
@@ -506,7 +513,7 @@ struct LinearizeVectorBitCast final
506513
using OpConversionPattern::OpConversionPattern;
507514
LinearizeVectorBitCast(
508515
const TypeConverter &typeConverter, MLIRContext *context,
509-
unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
516+
unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
510517
PatternBenefit benefit = 1)
511518
: OpConversionPattern(typeConverter, context, benefit),
512519
targetVectorBitWidth(targetVectBitWidth) {}
@@ -531,12 +538,48 @@ struct LinearizeVectorBitCast final
531538
unsigned targetVectorBitWidth;
532539
};
533540

541+
/// This pattern converts the SplatOp to work on a linearized vector.
542+
/// Following,
543+
/// vector.splat %value : vector<4x4xf32>
544+
/// is converted to:
545+
/// %out_1d = vector.splat %value : vector<16xf32>
546+
/// %out_nd = vector.shape_cast %out_1d : vector<16xf32> to vector<4x4xf32>
547+
/// It ensures that the operation is compatible with the target vector
548+
/// bit width and replaces the original operation with a new SplatOp
549+
/// that operates on the converted type.
550+
struct LinearizeVectorSplat final
551+
: public OpConversionPattern<vector::SplatOp> {
552+
using OpConversionPattern::OpConversionPattern;
553+
554+
LinearizeVectorSplat(
555+
const TypeConverter &typeConverter, MLIRContext *context,
556+
unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
557+
PatternBenefit benefit = 1)
558+
: OpConversionPattern(typeConverter, context, benefit),
559+
targetVectorBitWidth(targetVectBitWidth) {}
560+
561+
LogicalResult
562+
matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
563+
ConversionPatternRewriter &rewriter) const override {
564+
auto dstTy = getTypeConverter()->convertType(splatOp.getType());
565+
if (!dstTy)
566+
return rewriter.notifyMatchFailure(splatOp, "cannot convert type.");
567+
rewriter.replaceOpWithNewOp<vector::SplatOp>(splatOp, adaptor.getInput(),
568+
dstTy);
569+
return success();
570+
}
571+
572+
private:
573+
unsigned targetVectorBitWidth;
574+
};
575+
534576
} // namespace
535577

536578
void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
537579
TypeConverter &typeConverter, RewritePatternSet &patterns,
538580
ConversionTarget &target, unsigned targetBitWidth) {
539581

582+
typeConverter.addConversion([](Type type) -> Type { return type; });
540583
typeConverter.addConversion([](VectorType type) -> std::optional<Type> {
541584
if (!isLinearizableVector(type))
542585
return type;
@@ -557,7 +600,7 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
557600
typeConverter.addTargetMaterialization(materializeCast);
558601
target.markUnknownOpDynamicallyLegal(
559602
[=](Operation *op) -> std::optional<bool> {
560-
if ((isa<vector::BitCastOp>(op) ||
603+
if ((isa<vector::BitCastOp, vector::SplatOp>(op) ||
561604
op->hasTrait<OpTrait::ConstantLike>() ||
562605
op->hasTrait<OpTrait::Vectorizable>())) {
563606
return (isLessThanTargetBitWidth(op, targetBitWidth)
@@ -568,8 +611,8 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
568611
});
569612

570613
patterns.add<LinearizeConstantLike, LinearizeVectorizable,
571-
LinearizeVectorBitCast>(typeConverter, patterns.getContext(),
572-
targetBitWidth);
614+
LinearizeVectorBitCast, LinearizeVectorSplat>(
615+
typeConverter, patterns.getContext(), targetBitWidth);
573616
}
574617

575618
void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(

mlir/test/Dialect/Vector/linearize.mlir

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -399,3 +399,20 @@ func.func @test_vector_bitcast(%arg0: vector<[4]x2xf32>) -> vector<[4]x4xf16> {
399399
%1 = vector.bitcast %arg0 : vector<[4]x2xf32> to vector<[4]x4xf16>
400400
return %1 : vector<[4]x4xf16>
401401
}
402+
403+
// -----
404+
// ALL-LABEL: linearize_vector_splat
405+
// ALL-SAME: (%[[ARG:.*]]: i32) -> vector<4x2xi32>
406+
func.func @linearize_vector_splat(%arg0: i32) -> vector<4x2xi32> {
407+
// DEFAULT: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<8xi32>
408+
// DEFAULT: %[[CAST:.*]] = vector.shape_cast %[[SPLAT]] : vector<8xi32> to vector<4x2xi32>
409+
// DEFAULT: return %[[CAST]] : vector<4x2xi32>
410+
// BW-128: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<8xi32>
411+
// BW-128: %[[CAST:.*]] = vector.shape_cast %[[SPLAT]] : vector<8xi32> to vector<4x2xi32>
412+
// BW-128: return %[[CAST]] : vector<4x2xi32>
413+
414+
// BW-0: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<4x2xi32>
415+
// BW-0: return %[[SPLAT]] : vector<4x2xi32>
416+
%0 = vector.splat %arg0 : vector<4x2xi32>
417+
return %0 : vector<4x2xi32>
418+
}

0 commit comments

Comments
 (0)