@@ -1408,7 +1408,7 @@ struct VectorScalableExtractOpLowering
14081408// / ```
14091409// / is rewritten into:
14101410// / ```
1411- // / %r = splat %f0: vector<2x4xf32>
1411+ // / %r = vector.broadcast %f0 : f32 to vector<2x4xf32>
14121412// / %va = vector.extractvalue %a[0] : vector<2x4xf32>
14131413// / %vb = vector.extractvalue %b[0] : vector<2x4xf32>
14141414// / %vc = vector.extractvalue %c[0] : vector<2x4xf32>
@@ -1441,7 +1441,7 @@ class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
14411441 auto elemType = vType.getElementType ();
14421442 Value zero = rewriter.create <arith::ConstantOp>(
14431443 loc, elemType, rewriter.getZeroAttr (elemType));
1444- Value desc = rewriter.create <vector::SplatOp >(loc, vType, zero);
1444+ Value desc = rewriter.create <vector::BroadcastOp >(loc, vType, zero);
14451445 for (int64_t i = 0 , e = vType.getShape ().front (); i != e; ++i) {
14461446 Value extrLHS = rewriter.create <ExtractOp>(loc, op.getLhs (), i);
14471447 Value extrRHS = rewriter.create <ExtractOp>(loc, op.getRhs (), i);
@@ -1583,7 +1583,7 @@ class VectorCreateMaskOpConversion
15831583 /* isScalable=*/ true ));
15841584 auto bound = getValueOrCreateCastToIndexLike (rewriter, loc, idxType,
15851585 adaptor.getOperands ()[0 ]);
1586- Value bounds = rewriter.create <SplatOp >(loc, indices.getType (), bound);
1586+ Value bounds = rewriter.create <BroadcastOp >(loc, indices.getType (), bound);
15871587 Value comp = rewriter.create <arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
15881588 indices, bounds);
15891589 rewriter.replaceOp (op, comp);
@@ -1767,63 +1767,79 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
17671767 }
17681768};
17691769
1770- // / The Splat operation is lowered to an insertelement + a shufflevector
1771- // / operation. Splat to only 0-d and 1-d vector result types are lowered.
1772- struct VectorSplatOpLowering : public ConvertOpToLLVMPattern <vector::SplatOp> {
1773- using ConvertOpToLLVMPattern<vector::SplatOp>::ConvertOpToLLVMPattern;
1770+ // / A broadcast of a scalar is lowered to an insertelement, followed by a
1771+ // / shufflevector for 1-d results (0-d results need only the insertelement).
1772+ // / Broadcasts from vectors and to 2+-d vectors are rejected by this pattern.
1773+ struct VectorBroadcastScalarToLowRankLowering
1774+ : public ConvertOpToLLVMPattern<vector::BroadcastOp> {
1775+ using ConvertOpToLLVMPattern<vector::BroadcastOp>::ConvertOpToLLVMPattern;
17741776
17751777 LogicalResult
1776- matchAndRewrite (vector::SplatOp splatOp , OpAdaptor adaptor,
1778+ matchAndRewrite (vector::BroadcastOp broadcast , OpAdaptor adaptor,
17771779 ConversionPatternRewriter &rewriter) const override {
1778- VectorType resultType = cast<VectorType>(splatOp.getType ());
1780+
1781+ if (isa<VectorType>(broadcast.getSourceType ()))
1782+ return rewriter.notifyMatchFailure (
1783+ broadcast, " broadcast from vector type not handled" );
1784+
1785+ VectorType resultType = broadcast.getType ();
17791786 if (resultType.getRank () > 1 )
1780- return failure ();
1787+ return rewriter.notifyMatchFailure (broadcast,
1788+ " broadcast to 2+-d handled elsewhere" );
17811789
17821790 // First insert it into a poison vector so we can shuffle it.
1783- auto vectorType = typeConverter->convertType (splatOp .getType ());
1791+ auto vectorType = typeConverter->convertType (broadcast .getType ());
17841792 Value poison =
1785- rewriter.create <LLVM::PoisonOp>(splatOp .getLoc (), vectorType);
1793+ rewriter.create <LLVM::PoisonOp>(broadcast .getLoc (), vectorType);
17861794 auto zero = rewriter.create <LLVM::ConstantOp>(
1787- splatOp .getLoc (),
1795+ broadcast .getLoc (),
17881796 typeConverter->convertType (rewriter.getIntegerType (32 )),
17891797 rewriter.getZeroAttr (rewriter.getIntegerType (32 )));
17901798
17911799 // For 0-d vector, we simply do `insertelement`.
17921800 if (resultType.getRank () == 0 ) {
17931801 rewriter.replaceOpWithNewOp <LLVM::InsertElementOp>(
1794- splatOp , vectorType, poison, adaptor.getInput (), zero);
1802+ broadcast , vectorType, poison, adaptor.getSource (), zero);
17951803 return success ();
17961804 }
17971805
17981806 // For 1-d vector, we additionally do a `shufflevector`.
17991807 auto v = rewriter.create <LLVM::InsertElementOp>(
1800- splatOp .getLoc (), vectorType, poison, adaptor.getInput (), zero);
1808+ broadcast .getLoc (), vectorType, poison, adaptor.getSource (), zero);
18011809
1802- int64_t width = cast<VectorType>(splatOp .getType ()).getDimSize (0 );
1810+ int64_t width = cast<VectorType>(broadcast .getType ()).getDimSize (0 );
18031811 SmallVector<int32_t > zeroValues (width, 0 );
18041812
18051813 // Shuffle the value across the desired number of elements: the all-zero mask selects lane 0 of `v` for every result lane.
1806- rewriter.replaceOpWithNewOp <LLVM::ShuffleVectorOp>(splatOp , v, poison,
1814+ rewriter.replaceOpWithNewOp <LLVM::ShuffleVectorOp>(broadcast , v, poison,
18071815 zeroValues);
18081816 return success ();
18091817 }
18101818};
18111819
1812- // / The Splat operation is lowered to an insertelement + a shufflevector
1813- // / operation. Splat to only 2+-d vector result types are lowered by the
1814- // / SplatNdOpLowering, the 1-d case is handled by SplatOpLowering.
1815- struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern <SplatOp> {
1816- using ConvertOpToLLVMPattern<SplatOp>::ConvertOpToLLVMPattern;
1820+ // / The broadcast of a scalar is lowered to an insertelement + a shufflevector
1821+ // / operation. Only broadcasts to 2+-d vector result types are lowered by this
1822+ // / pattern; the 0-d and 1-d cases are handled by another pattern. Broadcasts
1823+ // / from vectors are rejected by this pattern, only broadcasts from scalars match.
1824+ struct VectorBroadcastScalarToNdLowering
1825+ : public ConvertOpToLLVMPattern<BroadcastOp> {
1826+ using ConvertOpToLLVMPattern<BroadcastOp>::ConvertOpToLLVMPattern;
18171827
18181828 LogicalResult
1819- matchAndRewrite (SplatOp splatOp , OpAdaptor adaptor,
1829+ matchAndRewrite (BroadcastOp broadcast , OpAdaptor adaptor,
18201830 ConversionPatternRewriter &rewriter) const override {
1821- VectorType resultType = splatOp.getType ();
1831+
1832+ if (isa<VectorType>(broadcast.getSourceType ()))
1833+ return rewriter.notifyMatchFailure (
1834+ broadcast, " broadcast from vector type not handled" );
1835+
1836+ VectorType resultType = broadcast.getType ();
18221837 if (resultType.getRank () <= 1 )
1823- return failure ();
1838+ return rewriter.notifyMatchFailure (
1839+ broadcast, " broadcast to 1-d or 0-d handled elsewhere" );
18241840
18251841 // First insert it into a poison vector so we can shuffle it.
1826- auto loc = splatOp .getLoc ();
1842+ auto loc = broadcast .getLoc ();
18271843 auto vectorTypeInfo =
18281844 LLVM::detail::extractNDVectorTypeInfo (resultType, *getTypeConverter ());
18291845 auto llvmNDVectorTy = vectorTypeInfo.llvmNDVectorTy ;
@@ -1834,26 +1850,26 @@ struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
18341850 // Construct returned value.
18351851 Value desc = rewriter.create <LLVM::PoisonOp>(loc, llvmNDVectorTy);
18361852
1837- // Construct a 1-D vector with the splatted value that we insert in all the
1838- // places within the returned descriptor.
1853+ // Construct a 1-D vector with the broadcasted value that we insert in all
1854+ // the places within the returned descriptor.
18391855 Value vdesc = rewriter.create <LLVM::PoisonOp>(loc, llvm1DVectorTy);
18401856 auto zero = rewriter.create <LLVM::ConstantOp>(
18411857 loc, typeConverter->convertType (rewriter.getIntegerType (32 )),
18421858 rewriter.getZeroAttr (rewriter.getIntegerType (32 )));
18431859 Value v = rewriter.create <LLVM::InsertElementOp>(loc, llvm1DVectorTy, vdesc,
1844- adaptor.getInput (), zero);
1860+ adaptor.getSource (), zero);
18451861
18461862 // Shuffle the value across all `width` lanes, where `width` is the size of the innermost result dimension.
18471863 int64_t width = resultType.getDimSize (resultType.getRank () - 1 );
18481864 SmallVector<int32_t > zeroValues (width, 0 );
18491865 v = rewriter.create <LLVM::ShuffleVectorOp>(loc, v, v, zeroValues);
18501866
1851- // Iterate of linear index, convert to coords space and insert splatted 1-D
1852- // vector in each position.
1867+ // Iterate over linear index, convert to coords space and insert broadcasted
1868+ // 1-D vector in each position.
18531869 nDVectorIterate (vectorTypeInfo, rewriter, [&](ArrayRef<int64_t > position) {
18541870 desc = rewriter.create <LLVM::InsertValueOp>(loc, desc, v, position);
18551871 });
1856- rewriter.replaceOp (splatOp , desc);
1872+ rewriter.replaceOp (broadcast , desc);
18571873 return success ();
18581874 }
18591875};
@@ -2035,6 +2051,19 @@ struct VectorScalableStepOpLowering
20352051 }
20362052};
20372053
2054+ // / Convert `vector.splat` to `vector.broadcast`. There is a path to LLVM from
2055+ // / `vector.broadcast` through other patterns. The replacement op reuses the splat's result type and takes the adaptor's input as its source.
2056+ struct VectorSplatToBroadcast : public ConvertOpToLLVMPattern <vector::SplatOp> {
2057+ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
2058+ LogicalResult
2059+ matchAndRewrite (vector::SplatOp splat, OpAdaptor adaptor,
2060+ ConversionPatternRewriter &rewriter) const override {
2061+ rewriter.replaceOpWithNewOp <vector::BroadcastOp>(splat, splat.getType (),
2062+ adaptor.getInput ());
2063+ return success ();
2064+ }
2065+ };
2066+
20382067} // namespace
20392068
20402069void mlir::vector::populateVectorRankReducingFMAPattern (
@@ -2063,7 +2092,8 @@ void mlir::populateVectorToLLVMConversionPatterns(
20632092 VectorInsertOpConversion, VectorPrintOpConversion,
20642093 VectorTypeCastOpConversion, VectorScaleOpConversion,
20652094 VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
2066- VectorSplatOpLowering, VectorSplatNdOpLowering,
2095+ VectorSplatToBroadcast, VectorBroadcastScalarToLowRankLowering,
2096+ VectorBroadcastScalarToNdLowering,
20672097 VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
20682098 MaskedReductionOpConversion, VectorInterleaveOpLowering,
20692099 VectorDeinterleaveOpLowering, VectorFromElementsLowering,
0 commit comments