@@ -817,6 +817,50 @@ struct LinearizeVectorToElements final
817817 }
818818};
819819
820+ // / Convert broadcasts from scalars or 1-element vectors, such as
821+ // /
822+ // / ```mlir
823+ // / vector.broadcast %value : f32 to vector<4x4xf32>
824+ // / ```
825+ // /
826+ // / to broadcasts to rank-1 vectors, with shape_casts before/after as needed.
827+ // / The above becomes,
828+ // /
829+ // / ```mlir
830+ // / %out_1d = vector.broadcast %value : f32 to vector<16xf32>
831+ // / %out_nd = vector.shape_cast %out_1d : vector<16xf32> to vector<4x4xf32>
832+ // / ```
833+ struct LinearizeVectorBroadcast final
834+ : public OpConversionPattern<vector::BroadcastOp> {
835+ using Base::Base;
836+
837+ LinearizeVectorBroadcast (const TypeConverter &typeConverter,
838+ MLIRContext *context, PatternBenefit benefit = 1 )
839+ : OpConversionPattern(typeConverter, context, benefit) {}
840+
841+ LogicalResult
842+ matchAndRewrite (vector::BroadcastOp broadcastOp, OpAdaptor adaptor,
843+ ConversionPatternRewriter &rewriter) const override {
844+
845+ int numElements = 1 ;
846+ Type sourceType = broadcastOp.getSourceType ();
847+ if (auto vecType = dyn_cast<VectorType>(sourceType)) {
848+ numElements = vecType.getNumElements ();
849+ }
850+
851+ if (numElements != 1 ) {
852+ return rewriter.notifyMatchFailure (
853+ broadcastOp, " only broadcasts of single elements can be linearized." );
854+ }
855+
856+ auto dstTy = getTypeConverter ()->convertType (broadcastOp.getType ());
857+ rewriter.replaceOpWithNewOp <vector::BroadcastOp>(broadcastOp, dstTy,
858+ adaptor.getSource ());
859+
860+ return success ();
861+ }
862+ };
863+
820864} // namespace
821865
822866// / This method defines the set of operations that are linearizable, and hence
@@ -909,8 +953,8 @@ void mlir::vector::populateVectorLinearizeBasePatterns(
909953 patterns
910954 .add <LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast,
911955 LinearizeVectorCreateMask, LinearizeVectorLoad, LinearizeVectorStore,
912- LinearizeVectorFromElements, LinearizeVectorToElements>(
913- typeConverter, patterns.getContext ());
956+ LinearizeVectorBroadcast, LinearizeVectorFromElements,
957+ LinearizeVectorToElements>( typeConverter, patterns.getContext ());
914958}
915959
916960void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns (
0 commit comments