2626
2727using namespace mlir ;
2828
29+ constexpr unsigned defaultTargetVectorBitWidth =
30+ std::numeric_limits<unsigned >::max();
31+
2932static bool isLessThanTargetBitWidth (Operation *op, unsigned targetBitWidth) {
3033 auto resultTypes = op->getResultTypes ();
3134 for (auto resType : resultTypes) {
@@ -82,7 +85,7 @@ struct LinearizeConstantLike final
8285
8386 LinearizeConstantLike (
8487 const TypeConverter &typeConverter, MLIRContext *context,
85- unsigned targetVectBitWidth = std::numeric_limits< unsigned >::max() ,
88+ unsigned targetVectBitWidth = defaultTargetVectorBitWidth ,
8689 PatternBenefit benefit = 1 )
8790 : OpTraitConversionPattern(typeConverter, context, benefit),
8891 targetVectorBitWidth (targetVectBitWidth) {}
@@ -136,7 +139,7 @@ struct LinearizeVectorizable final
136139public:
137140 LinearizeVectorizable (
138141 const TypeConverter &typeConverter, MLIRContext *context,
139- unsigned targetVectBitWidth = std::numeric_limits< unsigned >::max() ,
142+ unsigned targetVectBitWidth = defaultTargetVectorBitWidth ,
140143 PatternBenefit benefit = 1 )
141144 : OpTraitConversionPattern(typeConverter, context, benefit),
142145 targetVectorBitWidth (targetVectBitWidth) {}
@@ -175,7 +178,7 @@ struct LinearizeVectorExtractStridedSlice final
175178 using OpConversionPattern::OpConversionPattern;
176179 LinearizeVectorExtractStridedSlice (
177180 const TypeConverter &typeConverter, MLIRContext *context,
178- unsigned targetVectBitWidth = std::numeric_limits< unsigned >::max() ,
181+ unsigned targetVectBitWidth = defaultTargetVectorBitWidth ,
179182 PatternBenefit benefit = 1 )
180183 : OpConversionPattern(typeConverter, context, benefit),
181184 targetVectorBitWidth (targetVectBitWidth) {}
@@ -289,7 +292,7 @@ struct LinearizeVectorShuffle final
289292 using OpConversionPattern::OpConversionPattern;
290293 LinearizeVectorShuffle (
291294 const TypeConverter &typeConverter, MLIRContext *context,
292- unsigned targetVectBitWidth = std::numeric_limits< unsigned >::max() ,
295+ unsigned targetVectBitWidth = defaultTargetVectorBitWidth ,
293296 PatternBenefit benefit = 1 )
294297 : OpConversionPattern(typeConverter, context, benefit),
295298 targetVectorBitWidth (targetVectBitWidth) {}
@@ -362,13 +365,17 @@ struct LinearizeVectorExtract final
362365 using OpConversionPattern::OpConversionPattern;
363366 LinearizeVectorExtract (
364367 const TypeConverter &typeConverter, MLIRContext *context,
365- unsigned targetVectBitWidth = std::numeric_limits< unsigned >::max() ,
368+ unsigned targetVectBitWidth = defaultTargetVectorBitWidth ,
366369 PatternBenefit benefit = 1 )
367370 : OpConversionPattern(typeConverter, context, benefit),
368371 targetVectorBitWidth (targetVectBitWidth) {}
369372 LogicalResult
370373 matchAndRewrite (vector::ExtractOp extractOp, OpAdaptor adaptor,
371374 ConversionPatternRewriter &rewriter) const override {
375+ // Skip if result is not a vector type
376+ if (!isa<VectorType>(extractOp.getType ()))
377+ return rewriter.notifyMatchFailure (extractOp,
378+ " scalar extract is not supported." );
372379 Type dstTy = getTypeConverter ()->convertType (extractOp.getType ());
373380 if (!dstTy)
374381 return rewriter.notifyMatchFailure (extractOp,
@@ -425,7 +432,7 @@ struct LinearizeVectorInsert final
425432 using OpConversionPattern::OpConversionPattern;
426433 LinearizeVectorInsert (
427434 const TypeConverter &typeConverter, MLIRContext *context,
428- unsigned targetVectBitWidth = std::numeric_limits< unsigned >::max() ,
435+ unsigned targetVectBitWidth = defaultTargetVectorBitWidth ,
429436 PatternBenefit benefit = 1 )
430437 : OpConversionPattern(typeConverter, context, benefit),
431438 targetVectorBitWidth (targetVectBitWidth) {}
@@ -506,7 +513,7 @@ struct LinearizeVectorBitCast final
506513 using OpConversionPattern::OpConversionPattern;
507514 LinearizeVectorBitCast (
508515 const TypeConverter &typeConverter, MLIRContext *context,
509- unsigned targetVectBitWidth = std::numeric_limits< unsigned >::max() ,
516+ unsigned targetVectBitWidth = defaultTargetVectorBitWidth ,
510517 PatternBenefit benefit = 1 )
511518 : OpConversionPattern(typeConverter, context, benefit),
512519 targetVectorBitWidth (targetVectBitWidth) {}
@@ -531,12 +538,48 @@ struct LinearizeVectorBitCast final
531538 unsigned targetVectorBitWidth;
532539};
533540
541+ // / This pattern converts the SplatOp to work on a linearized vector.
542+ // / Following,
543+ // / vector.splat %value : vector<4x4xf32>
544+ // / is converted to:
545+ // / %out_1d = vector.splat %value : vector<16xf32>
546+ // / %out_nd = vector.shape_cast %out_1d : vector<16xf32> to vector<4x4xf32>
547+ // / It ensures that the operation is compatible with the target vector
548+ // / bit width and replaces the original operation with a new SplatOp
549+ // / that operates on the converted type.
550+ struct LinearizeVectorSplat final
551+ : public OpConversionPattern<vector::SplatOp> {
552+ using OpConversionPattern::OpConversionPattern;
553+
554+ LinearizeVectorSplat (
555+ const TypeConverter &typeConverter, MLIRContext *context,
556+ unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
557+ PatternBenefit benefit = 1 )
558+ : OpConversionPattern(typeConverter, context, benefit),
559+ targetVectorBitWidth (targetVectBitWidth) {}
560+
561+ LogicalResult
562+ matchAndRewrite (vector::SplatOp splatOp, OpAdaptor adaptor,
563+ ConversionPatternRewriter &rewriter) const override {
564+ auto dstTy = getTypeConverter ()->convertType (splatOp.getType ());
565+ if (!dstTy)
566+ return rewriter.notifyMatchFailure (splatOp, " cannot convert type." );
567+ rewriter.replaceOpWithNewOp <vector::SplatOp>(splatOp, adaptor.getInput (),
568+ dstTy);
569+ return success ();
570+ }
571+
572+ private:
573+ unsigned targetVectorBitWidth;
574+ };
575+
534576} // namespace
535577
536578void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality (
537579 TypeConverter &typeConverter, RewritePatternSet &patterns,
538580 ConversionTarget &target, unsigned targetBitWidth) {
539581
582+ typeConverter.addConversion ([](Type type) -> Type { return type; });
540583 typeConverter.addConversion ([](VectorType type) -> std::optional<Type> {
541584 if (!isLinearizableVector (type))
542585 return type;
@@ -557,7 +600,7 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
557600 typeConverter.addTargetMaterialization (materializeCast);
558601 target.markUnknownOpDynamicallyLegal (
559602 [=](Operation *op) -> std::optional<bool > {
560- if ((isa<vector::BitCastOp>(op) ||
603+ if ((isa<vector::BitCastOp, vector::SplatOp >(op) ||
561604 op->hasTrait <OpTrait::ConstantLike>() ||
562605 op->hasTrait <OpTrait::Vectorizable>())) {
563606 return (isLessThanTargetBitWidth (op, targetBitWidth)
@@ -568,8 +611,8 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
568611 });
569612
570613 patterns.add <LinearizeConstantLike, LinearizeVectorizable,
571- LinearizeVectorBitCast>(typeConverter, patterns. getContext (),
572- targetBitWidth);
614+ LinearizeVectorBitCast, LinearizeVectorSplat>(
615+ typeConverter, patterns. getContext (), targetBitWidth);
573616}
574617
575618void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns (
0 commit comments