@@ -1373,7 +1373,7 @@ struct VectorScalableExtractOpLowering
13731373// / ```
13741374// / is rewritten into:
13751375// / ```
1376- // / %r = splat %f0: vector<2x4xf32>
1376+ // / %r = vector.broadcast %f0 : f32 to vector<2x4xf32>
13771377// / %va = vector.extractvalue %a[0] : vector<2x4xf32>
13781378// / %vb = vector.extractvalue %b[0] : vector<2x4xf32>
13791379// / %vc = vector.extractvalue %c[0] : vector<2x4xf32>
@@ -1406,7 +1406,7 @@ class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
14061406 auto elemType = vType.getElementType ();
14071407 Value zero = rewriter.create <arith::ConstantOp>(
14081408 loc, elemType, rewriter.getZeroAttr (elemType));
1409- Value desc = rewriter.create <vector::SplatOp >(loc, vType, zero);
1409+ Value desc = rewriter.create <vector::BroadcastOp >(loc, vType, zero);
14101410 for (int64_t i = 0 , e = vType.getShape ().front (); i != e; ++i) {
14111411 Value extrLHS = rewriter.create <ExtractOp>(loc, op.getLhs (), i);
14121412 Value extrRHS = rewriter.create <ExtractOp>(loc, op.getRhs (), i);
@@ -1548,7 +1548,7 @@ class VectorCreateMaskOpConversion
15481548 /* isScalable=*/ true ));
15491549 auto bound = getValueOrCreateCastToIndexLike (rewriter, loc, idxType,
15501550 adaptor.getOperands ()[0 ]);
1551- Value bounds = rewriter.create <SplatOp >(loc, indices.getType (), bound);
1551+ Value bounds = rewriter.create <BroadcastOp >(loc, indices.getType (), bound);
15521552 Value comp = rewriter.create <arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
15531553 indices, bounds);
15541554 rewriter.replaceOp (op, comp);
@@ -1732,63 +1732,77 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
17321732 }
17331733};
17341734
1735- // / The Splat operation is lowered to an insertelement + a shufflevector
1736- // / operation. Splat to only 0-d and 1-d vector result types are lowered.
1737- struct VectorSplatOpLowering : public ConvertOpToLLVMPattern <vector::SplatOp> {
1738- using ConvertOpToLLVMPattern<vector::SplatOp>::ConvertOpToLLVMPattern;
1735+ // / A broadcast of a scalar is lowered to an insertelement + a shufflevector
1736+ // / operation. Only broadcasts to 0-d and 1-d vectors are lowered by this
1737+ // / pattern, the higher rank cases are handled by another pattern.
1738+ struct VectorBroadcastScalarToLowRankLowering
1739+ : public ConvertOpToLLVMPattern<vector::BroadcastOp> {
1740+ using ConvertOpToLLVMPattern<vector::BroadcastOp>::ConvertOpToLLVMPattern;
17391741
17401742 LogicalResult
1741- matchAndRewrite (vector::SplatOp splatOp , OpAdaptor adaptor,
1743+ matchAndRewrite (vector::BroadcastOp broadcast , OpAdaptor adaptor,
17421744 ConversionPatternRewriter &rewriter) const override {
1743- VectorType resultType = cast<VectorType>(splatOp.getType ());
1745+ if (isa<VectorType>(broadcast.getSourceType ()))
1746+ return rewriter.notifyMatchFailure (
1747+ broadcast, " broadcast from vector type not handled" );
1748+
1749+ VectorType resultType = broadcast.getType ();
17441750 if (resultType.getRank () > 1 )
1745- return failure ();
1751+ return rewriter.notifyMatchFailure (broadcast,
1752+ " broadcast to 2+-d handled elsewhere" );
17461753
17471754 // First insert it into a poison vector so we can shuffle it.
1748- auto vectorType = typeConverter->convertType (splatOp .getType ());
1755+ auto vectorType = typeConverter->convertType (broadcast .getType ());
17491756 Value poison =
1750- rewriter.create <LLVM::PoisonOp>(splatOp .getLoc (), vectorType);
1757+ rewriter.create <LLVM::PoisonOp>(broadcast .getLoc (), vectorType);
17511758 auto zero = rewriter.create <LLVM::ConstantOp>(
1752- splatOp .getLoc (),
1759+ broadcast .getLoc (),
17531760 typeConverter->convertType (rewriter.getIntegerType (32 )),
17541761 rewriter.getZeroAttr (rewriter.getIntegerType (32 )));
17551762
17561763 // For 0-d vector, we simply do `insertelement`.
17571764 if (resultType.getRank () == 0 ) {
17581765 rewriter.replaceOpWithNewOp <LLVM::InsertElementOp>(
1759- splatOp , vectorType, poison, adaptor.getInput (), zero);
1766+ broadcast , vectorType, poison, adaptor.getSource (), zero);
17601767 return success ();
17611768 }
17621769
17631770 // For 1-d vector, we additionally do a `vectorshuffle`.
17641771 auto v = rewriter.create <LLVM::InsertElementOp>(
1765- splatOp .getLoc (), vectorType, poison, adaptor.getInput (), zero);
1772+ broadcast .getLoc (), vectorType, poison, adaptor.getSource (), zero);
17661773
1767- int64_t width = cast<VectorType>(splatOp .getType ()).getDimSize (0 );
1774+ int64_t width = cast<VectorType>(broadcast .getType ()).getDimSize (0 );
17681775 SmallVector<int32_t > zeroValues (width, 0 );
17691776
17701777 // Shuffle the value across the desired number of elements.
1771- rewriter.replaceOpWithNewOp <LLVM::ShuffleVectorOp>(splatOp , v, poison,
1778+ rewriter.replaceOpWithNewOp <LLVM::ShuffleVectorOp>(broadcast , v, poison,
17721779 zeroValues);
17731780 return success ();
17741781 }
17751782};
17761783
1777- // / The Splat operation is lowered to an insertelement + a shufflevector
1778- // / operation. Splat to only 2+-d vector result types are lowered by the
1779- // / SplatNdOpLowering, the 1-d case is handled by SplatOpLowering.
1780- struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern <SplatOp> {
1781- using ConvertOpToLLVMPattern<SplatOp>::ConvertOpToLLVMPattern;
1784+ // / The broadcast of a scalar is lowered to an insertelement + a shufflevector
1785+ // / operation. Only broadcasts to 2+-d vector result types are lowered by this
1786+ // / pattern, the 1-d case is handled by another pattern. Broadcasts from vectors
1787+ // / are not converted to LLVM, only broadcasts from scalars are.
1788+ struct VectorBroadcastScalarToNdLowering
1789+ : public ConvertOpToLLVMPattern<BroadcastOp> {
1790+ using ConvertOpToLLVMPattern<BroadcastOp>::ConvertOpToLLVMPattern;
17821791
17831792 LogicalResult
1784- matchAndRewrite (SplatOp splatOp , OpAdaptor adaptor,
1793+ matchAndRewrite (BroadcastOp broadcast , OpAdaptor adaptor,
17851794 ConversionPatternRewriter &rewriter) const override {
1786- VectorType resultType = splatOp.getType ();
1795+ if (isa<VectorType>(broadcast.getSourceType ()))
1796+ return rewriter.notifyMatchFailure (
1797+ broadcast, " broadcast from vector type not handled" );
1798+
1799+ VectorType resultType = broadcast.getType ();
17871800 if (resultType.getRank () <= 1 )
1788- return failure ();
1801+ return rewriter.notifyMatchFailure (
1802+ broadcast, " broadcast to 1-d or 0-d handled elsewhere" );
17891803
17901804 // First insert it into an undef vector so we can shuffle it.
1791- auto loc = splatOp .getLoc ();
1805+ auto loc = broadcast .getLoc ();
17921806 auto vectorTypeInfo =
17931807 LLVM::detail::extractNDVectorTypeInfo (resultType, *getTypeConverter ());
17941808 auto llvmNDVectorTy = vectorTypeInfo.llvmNDVectorTy ;
@@ -1799,26 +1813,26 @@ struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
17991813 // Construct returned value.
18001814 Value desc = rewriter.create <LLVM::PoisonOp>(loc, llvmNDVectorTy);
18011815
1802- // Construct a 1-D vector with the splatted value that we insert in all the
1803- // places within the returned descriptor.
1816+ // Construct a 1-D vector with the broadcasted value that we insert in all
1817+ // the places within the returned descriptor.
18041818 Value vdesc = rewriter.create <LLVM::PoisonOp>(loc, llvm1DVectorTy);
18051819 auto zero = rewriter.create <LLVM::ConstantOp>(
18061820 loc, typeConverter->convertType (rewriter.getIntegerType (32 )),
18071821 rewriter.getZeroAttr (rewriter.getIntegerType (32 )));
18081822 Value v = rewriter.create <LLVM::InsertElementOp>(loc, llvm1DVectorTy, vdesc,
1809- adaptor.getInput (), zero);
1823+ adaptor.getSource (), zero);
18101824
18111825 // Shuffle the value across the desired number of elements.
18121826 int64_t width = resultType.getDimSize (resultType.getRank () - 1 );
18131827 SmallVector<int32_t > zeroValues (width, 0 );
18141828 v = rewriter.create <LLVM::ShuffleVectorOp>(loc, v, v, zeroValues);
18151829
1816- // Iterate of linear index, convert to coords space and insert splatted 1-D
1817- // vector in each position.
1830+ // Iterate of linear index, convert to coords space and insert broadcasted
1831+ // 1-D vector in each position.
18181832 nDVectorIterate (vectorTypeInfo, rewriter, [&](ArrayRef<int64_t > position) {
18191833 desc = rewriter.create <LLVM::InsertValueOp>(loc, desc, v, position);
18201834 });
1821- rewriter.replaceOp (splatOp , desc);
1835+ rewriter.replaceOp (broadcast , desc);
18221836 return success ();
18231837 }
18241838};
@@ -2177,6 +2191,19 @@ class TransposeOpToMatrixTransposeOpLowering
21772191 }
21782192};
21792193
2194+ // / Convert `vector.splat` to `vector.broadcast`. There is a path to LLVM from
2195+ // / `vector.broadcast` through other patterns.
2196+ struct VectorSplatToBroadcast : public ConvertOpToLLVMPattern <vector::SplatOp> {
2197+ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
2198+ LogicalResult
2199+ matchAndRewrite (vector::SplatOp splat, OpAdaptor adaptor,
2200+ ConversionPatternRewriter &rewriter) const override {
2201+ rewriter.replaceOpWithNewOp <vector::BroadcastOp>(splat, splat.getType (),
2202+ adaptor.getInput ());
2203+ return success ();
2204+ }
2205+ };
2206+
21802207} // namespace
21812208
21822209void mlir::vector::populateVectorRankReducingFMAPattern (
@@ -2216,7 +2243,8 @@ void mlir::populateVectorToLLVMConversionPatterns(
22162243 VectorInsertOpConversion, VectorPrintOpConversion,
22172244 VectorTypeCastOpConversion, VectorScaleOpConversion,
22182245 VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
2219- VectorSplatOpLowering, VectorSplatNdOpLowering,
2246+ VectorSplatToBroadcast, VectorBroadcastScalarToLowRankLowering,
2247+ VectorBroadcastScalarToNdLowering,
22202248 VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
22212249 MaskedReductionOpConversion, VectorInterleaveOpLowering,
22222250 VectorDeinterleaveOpLowering, VectorFromElementsLowering,
0 commit comments