2626
2727using namespace mlir ;
2828
/// Default value of the `targetVectBitWidth` constructor parameter of the
/// linearization patterns in this file: the largest value an `unsigned` can
/// hold.
constexpr unsigned defaultTargetVectorBitWidth =
    std::numeric_limits<unsigned>::max();
31+
2932static bool isLessThanTargetBitWidth (Operation *op, unsigned targetBitWidth) {
33+ if (targetBitWidth == 0 )
34+ return false ;
3035 auto resultTypes = op->getResultTypes ();
3136 for (auto resType : resultTypes) {
3237 VectorType vecType = dyn_cast<VectorType>(resType);
@@ -82,7 +87,7 @@ struct LinearizeConstantLike final
8287
8388 LinearizeConstantLike (
8489 const TypeConverter &typeConverter, MLIRContext *context,
85- unsigned targetVectBitWidth = std::numeric_limits< unsigned >::max() ,
90+ unsigned targetVectBitWidth = defaultTargetVectorBitWidth ,
8691 PatternBenefit benefit = 1 )
8792 : OpTraitConversionPattern(typeConverter, context, benefit),
8893 targetVectorBitWidth (targetVectBitWidth) {}
@@ -136,7 +141,7 @@ struct LinearizeVectorizable final
136141public:
137142 LinearizeVectorizable (
138143 const TypeConverter &typeConverter, MLIRContext *context,
139- unsigned targetVectBitWidth = std::numeric_limits< unsigned >::max() ,
144+ unsigned targetVectBitWidth = defaultTargetVectorBitWidth ,
140145 PatternBenefit benefit = 1 )
141146 : OpTraitConversionPattern(typeConverter, context, benefit),
142147 targetVectorBitWidth (targetVectBitWidth) {}
@@ -175,7 +180,7 @@ struct LinearizeVectorExtractStridedSlice final
175180 using OpConversionPattern::OpConversionPattern;
176181 LinearizeVectorExtractStridedSlice (
177182 const TypeConverter &typeConverter, MLIRContext *context,
178- unsigned targetVectBitWidth = std::numeric_limits< unsigned >::max() ,
183+ unsigned targetVectBitWidth = defaultTargetVectorBitWidth ,
179184 PatternBenefit benefit = 1 )
180185 : OpConversionPattern(typeConverter, context, benefit),
181186 targetVectorBitWidth (targetVectBitWidth) {}
@@ -289,7 +294,7 @@ struct LinearizeVectorShuffle final
289294 using OpConversionPattern::OpConversionPattern;
290295 LinearizeVectorShuffle (
291296 const TypeConverter &typeConverter, MLIRContext *context,
292- unsigned targetVectBitWidth = std::numeric_limits< unsigned >::max() ,
297+ unsigned targetVectBitWidth = defaultTargetVectorBitWidth ,
293298 PatternBenefit benefit = 1 )
294299 : OpConversionPattern(typeConverter, context, benefit),
295300 targetVectorBitWidth (targetVectBitWidth) {}
@@ -362,13 +367,17 @@ struct LinearizeVectorExtract final
362367 using OpConversionPattern::OpConversionPattern;
363368 LinearizeVectorExtract (
364369 const TypeConverter &typeConverter, MLIRContext *context,
365- unsigned targetVectBitWidth = std::numeric_limits< unsigned >::max() ,
370+ unsigned targetVectBitWidth = defaultTargetVectorBitWidth ,
366371 PatternBenefit benefit = 1 )
367372 : OpConversionPattern(typeConverter, context, benefit),
368373 targetVectorBitWidth (targetVectBitWidth) {}
369374 LogicalResult
370375 matchAndRewrite (vector::ExtractOp extractOp, OpAdaptor adaptor,
371376 ConversionPatternRewriter &rewriter) const override {
377+ // Skip if result is not a vector type
378+ if (!isa<VectorType>(extractOp.getType ()))
379+ return rewriter.notifyMatchFailure (extractOp,
380+ " scalar extract is not supported." );
372381 Type dstTy = getTypeConverter ()->convertType (extractOp.getType ());
373382 if (!dstTy)
374383 return rewriter.notifyMatchFailure (extractOp,
@@ -425,7 +434,7 @@ struct LinearizeVectorInsert final
425434 using OpConversionPattern::OpConversionPattern;
426435 LinearizeVectorInsert (
427436 const TypeConverter &typeConverter, MLIRContext *context,
428- unsigned targetVectBitWidth = std::numeric_limits< unsigned >::max() ,
437+ unsigned targetVectBitWidth = defaultTargetVectorBitWidth ,
429438 PatternBenefit benefit = 1 )
430439 : OpConversionPattern(typeConverter, context, benefit),
431440 targetVectorBitWidth (targetVectBitWidth) {}
@@ -506,7 +515,7 @@ struct LinearizeVectorBitCast final
506515 using OpConversionPattern::OpConversionPattern;
507516 LinearizeVectorBitCast (
508517 const TypeConverter &typeConverter, MLIRContext *context,
509- unsigned targetVectBitWidth = std::numeric_limits< unsigned >::max() ,
518+ unsigned targetVectBitWidth = defaultTargetVectorBitWidth ,
510519 PatternBenefit benefit = 1 )
511520 : OpConversionPattern(typeConverter, context, benefit),
512521 targetVectorBitWidth (targetVectBitWidth) {}
@@ -531,12 +540,139 @@ struct LinearizeVectorBitCast final
531540 unsigned targetVectorBitWidth;
532541};
533542
543+ // clang-format off
544+ // / This pattern converts the LoadOp to a series of LoadOp & InsertOp
545+ // / that works on a linearized vector.
546+ // / Following,
547+ // / vector.load %base[%indices] : vector<4x4xf32>
548+ // / is converted to :
549+ // / %result = arith.constant dense<0.0> : vector<4x4xf32>
550+ // / %slice_0 = vector.load %base[%indices] : vector<4xf32>
551+ // / %result_0 = vector.insert %slice_0, %result[0] : vector<4xf32> into vector<4x4xf32>
552+ // / %slice_1 = vector.load %base[%indices + 1] : vector<4xf32>
553+ // / %result_1 = vector.insert %slice_1, %result_0[1] : vector<4xf32> into vector<4x4xf32>
554+ // / ...
555+ // / This unrolls the 2D vector load into multiple 1D vector loads and inserts
556+ // / them into the result vector. The pattern currently supports only 2D vectors
557+ // clang-format on
558+ struct LinearizeVectorLoad final : public OpConversionPattern<vector::LoadOp> {
559+ using OpConversionPattern<vector::LoadOp>::OpConversionPattern;
560+
561+ LinearizeVectorLoad (const TypeConverter &typeConverter, MLIRContext *context,
562+ unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
563+ PatternBenefit benefit = 1 )
564+ : OpConversionPattern(typeConverter, context, benefit),
565+ targetVectorBitWidth (targetVectBitWidth) {}
566+
567+ LogicalResult
568+ matchAndRewrite (vector::LoadOp loadOp, OpAdaptor adaptor,
569+ ConversionPatternRewriter &rewriter) const override {
570+ auto loc = loadOp->getLoc ();
571+ VectorType vecType = loadOp.getVectorType ();
572+ auto shape = vecType.getShape ();
573+
574+ if (shape.size () != 2 )
575+ return rewriter.notifyMatchFailure (loc, " Can only linearize 2D vectors." );
576+
577+ auto unrollCount = shape[0 ];
578+ auto vecSize = shape[1 ];
579+ VectorType newVecType =
580+ VectorType::get ({vecSize}, vecType.getElementType ());
581+
582+ llvm::SmallVector<Value, 4 > indices = adaptor.getIndices ();
583+ Value xBaseIndex = indices[0 ];
584+
585+ // Construct the 2D vector.
586+ Value resultVec =
587+ rewriter.create <arith::ConstantOp>(loc, rewriter.getZeroAttr (vecType));
588+ // Emit unrolled loads for each 1D vector slice.
589+ for (auto i = 0 ; i < unrollCount; i++) {
590+ Value xIndex = xBaseIndex;
591+ if (i) {
592+ auto increment = rewriter.create <arith::ConstantIndexOp>(loc, i);
593+ xIndex = rewriter.create <arith::AddIOp>(loc, xBaseIndex, increment);
594+ }
595+ indices[0 ] = xIndex;
596+ auto vec = rewriter.create <vector::LoadOp>(loc, newVecType,
597+ adaptor.getBase (), indices);
598+ resultVec = rewriter.create <vector::InsertOp>(loc, vec, resultVec, i);
599+ }
600+
601+ rewriter.replaceOp (loadOp, resultVec);
602+ return success ();
603+ }
604+
605+ private:
606+ unsigned targetVectorBitWidth;
607+ };
608+
609+ // / This pattern converts the StoreOp to a series of StoreOp & ExtractOp
610+ // / that works on a linearized vector.
611+ // / Following,
612+ // / vector.store %source, %base[%indices] : vector<4x4xf32>
613+ // / is converted to :
614+ // / %slice_0 = vector.extract %source[0] : vector<4xf32>
615+ // / vector.store %slice_0, %base[%indices] : vector<4xf32>
616+ // / %slice_1 = vector.extract %source[1] : vector<4xf32>
617+ // / vector.store %slice_1, %base[%indices + 1] : vector<4xf32>
618+ // / ...
619+ // / This unrolls the 2D vector store into multiple 1D vector stores by
620+ // / extracting slices from the source vector and storing them into the
621+ // / destination. The pattern currently supports only 2D vectors
622+ struct LinearizeVectorStore final
623+ : public OpConversionPattern<vector::StoreOp> {
624+ using OpConversionPattern<vector::StoreOp>::OpConversionPattern;
625+
626+ LinearizeVectorStore (
627+ const TypeConverter &typeConverter, MLIRContext *context,
628+ unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
629+ PatternBenefit benefit = 1 )
630+ : OpConversionPattern(typeConverter, context, benefit),
631+ targetVectorBitWidth (targetVectBitWidth) {}
632+
633+ LogicalResult
634+ matchAndRewrite (vector::StoreOp storeOp, OpAdaptor adaptor,
635+ ConversionPatternRewriter &rewriter) const override {
636+ auto loc = storeOp->getLoc ();
637+ VectorType vecType = storeOp.getVectorType ();
638+ auto shape = vecType.getShape ();
639+
640+ if (shape.size () != 2 )
641+ return rewriter.notifyMatchFailure (loc, " Can only linearize 2D vectors." );
642+
643+ auto unrollCount = shape[0 ];
644+ llvm::SmallVector<Value, 4 > indices = adaptor.getIndices ();
645+ Value xBaseIndex = indices[0 ];
646+
647+ auto vec = rewriter.create <vector::ShapeCastOp>(loc, vecType,
648+ adaptor.getValueToStore ());
649+
650+ for (auto i = 0 ; i < unrollCount; i++) {
651+ auto vecSlice = rewriter.create <vector::ExtractOp>(loc, vec, i);
652+ Value xIndex = xBaseIndex;
653+ if (i) {
654+ auto increment = rewriter.create <arith::ConstantIndexOp>(loc, i);
655+ xIndex = rewriter.create <arith::AddIOp>(loc, xBaseIndex, increment);
656+ }
657+ indices[0 ] = xIndex;
658+ rewriter.create <vector::StoreOp>(loc, vecSlice, adaptor.getBase (),
659+ indices);
660+ }
661+ rewriter.eraseOp (storeOp);
662+ return success ();
663+ }
664+
665+ private:
666+ unsigned targetVectorBitWidth;
667+ };
668+
534669} // namespace
535670
536671void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality (
537672 TypeConverter &typeConverter, RewritePatternSet &patterns,
538673 ConversionTarget &target, unsigned targetBitWidth) {
539674
675+ typeConverter.addConversion ([](Type type) -> Type { return type; });
540676 typeConverter.addConversion ([](VectorType type) -> std::optional<Type> {
541677 if (!isLinearizableVector (type))
542678 return type;
@@ -555,9 +691,10 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
555691 };
556692 typeConverter.addSourceMaterialization (materializeCast);
557693 typeConverter.addTargetMaterialization (materializeCast);
694+ target.addLegalOp <vector::ShapeCastOp>();
558695 target.markUnknownOpDynamicallyLegal (
559696 [=](Operation *op) -> std::optional<bool > {
560- if ((isa<vector::BitCastOp>(op) ||
697+ if ((isa<vector::BitCastOp, vector::LoadOp, vector::StoreOp >(op) ||
561698 op->hasTrait <OpTrait::ConstantLike>() ||
562699 op->hasTrait <OpTrait::Vectorizable>())) {
563700 return (isLessThanTargetBitWidth (op, targetBitWidth)
@@ -567,9 +704,10 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
567704 return std::nullopt ;
568705 });
569706
570- patterns.add <LinearizeConstantLike, LinearizeVectorizable,
571- LinearizeVectorBitCast>(typeConverter, patterns.getContext (),
572- targetBitWidth);
707+ patterns
708+ .add <LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast,
709+ LinearizeVectorLoad, LinearizeVectorStore>(
710+ typeConverter, patterns.getContext (), targetBitWidth);
573711}
574712
575713void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns (
0 commit comments