Skip to content

Commit ff52343

Browse files
committed
lowering from vector -- support broadcast natively
1 parent 47bbdf1 commit ff52343

File tree

21 files changed

+247
-163
lines changed

21 files changed

+247
-163
lines changed

mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ ExtFOnFloat8RewritePattern::matchAndRewrite(arith::ExtFOp op,
153153

154154
if (inVecType.getShape().empty()) {
155155
Value zerodSplat =
156-
rewriter.createOrFold<vector::SplatOp>(loc, outType, zero);
156+
rewriter.createOrFold<vector::BroadcastOp>(loc, outType, zero);
157157
Value scalarIn =
158158
rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{});
159159
Value scalarExt =
@@ -166,7 +166,7 @@ ExtFOnFloat8RewritePattern::matchAndRewrite(arith::ExtFOp op,
166166

167167
VectorType flatTy = VectorType::get(SmallVector<int64_t>{numElements},
168168
outType.getElementType());
169-
Value result = rewriter.createOrFold<vector::SplatOp>(loc, flatTy, zero);
169+
Value result = rewriter.createOrFold<vector::BroadcastOp>(loc, flatTy, zero);
170170

171171
if (inVecType.getRank() > 1) {
172172
inVecType = VectorType::get(SmallVector<int64_t>{numElements},
@@ -315,7 +315,7 @@ TruncFToFloat8RewritePattern::matchAndRewrite(arith::TruncFOp op,
315315

316316
VectorType flatTy = VectorType::get(SmallVector<int64_t>{numElements},
317317
outVecType.getElementType());
318-
Value result = rewriter.createOrFold<vector::SplatOp>(loc, flatTy, zero);
318+
Value result = rewriter.createOrFold<vector::BroadcastOp>(loc, flatTy, zero);
319319

320320
if (inVectorTy.getRank() > 1) {
321321
inVectorTy = VectorType::get(SmallVector<int64_t>{numElements},
@@ -383,7 +383,7 @@ LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite(
383383
int64_t numElements = outVecType.getNumElements();
384384
Value zero = rewriter.createOrFold<arith::ConstantOp>(
385385
loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
386-
Value result = rewriter.createOrFold<vector::SplatOp>(loc, outVecType, zero);
386+
Value result = rewriter.createOrFold<vector::BroadcastOp>(loc, outVecType, zero);
387387

388388
if (inVectorTy.getRank() > 1) {
389389
inVectorTy = VectorType::get(SmallVector<int64_t>{numElements},
@@ -479,7 +479,7 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
479479

480480
if (!outVecType) {
481481
Value inCast =
482-
rewriter.create<vector::SplatOp>(loc, VectorType::get(1, inType), in);
482+
rewriter.create<vector::BroadcastOp>(loc, VectorType::get(1, inType), in);
483483
// TODO: replace this with non-packed ScaledExtOp
484484
Value scaleExt = rewriter.create<amdgpu::ScaledExtPackedOp>(
485485
loc, extScaleResultType, inCast, scale, 0);
@@ -509,7 +509,7 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
509509

510510
Value zero = rewriter.create<arith::ConstantOp>(
511511
loc, outType, rewriter.getFloatAttr(outType, 0.0));
512-
Value result = rewriter.createOrFold<vector::SplatOp>(loc, outVecType, zero);
512+
Value result = rewriter.createOrFold<vector::BroadcastOp>(loc, outVecType, zero);
513513

514514
for (SmallVector<int64_t> offsets : StaticTileOffsetRange(inShape, ratio)) {
515515
SmallVector<int64_t> strides(offsets.size(), 1);
@@ -523,7 +523,7 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
523523

524524
VectorType blockResultType = VectorType::get(blockSize, outType);
525525
Value blockResult =
526-
rewriter.createOrFold<vector::SplatOp>(loc, blockResultType, zero);
526+
rewriter.createOrFold<vector::BroadcastOp>(loc, blockResultType, zero);
527527

528528
for (int64_t i = 0, sliceWidth = std::min(opWidth, blockSize - i);
529529
i < blockSize;
@@ -587,7 +587,7 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
587587

588588
if (!outVecType) {
589589
Type inVecType = VectorType::get(1, inType);
590-
Value inCast = rewriter.create<vector::SplatOp>(loc, inVecType, in);
590+
Value inCast = rewriter.create<vector::BroadcastOp>(loc, inVecType, in);
591591
// TODO: replace this with non-packed ScaledTruncOp
592592
Value scaleTrunc = rewriter.create<amdgpu::PackedScaledTruncOp>(
593593
loc, truncScaleResultType, inCast, scale, 0, /*existing=*/nullptr);
@@ -616,7 +616,7 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
616616

617617
int64_t blockSize = computeProduct(ratio);
618618

619-
Value result = rewriter.createOrFold<vector::SplatOp>(loc, outVecType, zero);
619+
Value result = rewriter.createOrFold<vector::BroadcastOp>(loc, outVecType, zero);
620620

621621
for (SmallVector<int64_t> offsets : StaticTileOffsetRange(inShape, ratio)) {
622622
SmallVector<int64_t> strides(offsets.size(), 1);
@@ -630,7 +630,7 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
630630

631631
VectorType blockResultType = VectorType::get(blockSize, outType);
632632
Value blockResult =
633-
rewriter.createOrFold<vector::SplatOp>(loc, blockResultType, zero);
633+
rewriter.createOrFold<vector::BroadcastOp>(loc, blockResultType, zero);
634634

635635
for (int64_t i = 0, sliceWidth = std::min(opWidth, blockSize - i);
636636
i < blockSize;

mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -604,7 +604,7 @@ struct InsertTileSliceConversion
604604
rewriter.getIntegerAttr(rewriter.getI1Type(), 1));
605605
auto predTy = VectorType::get(tileType.getShape()[0], rewriter.getI1Type(),
606606
/*scalableDims=*/{true});
607-
auto allActiveMask = rewriter.create<vector::SplatOp>(loc, predTy, one);
607+
auto allActiveMask = rewriter.create<vector::BroadcastOp>(loc, predTy, one);
608608

609609
// Create 'arm_sme.intr.write.(horiz|vert)' to write vector to tile slice.
610610
switch (insertTileSliceOp.getLayout()) {

mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,7 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion
324324

325325
// Splat pad into 1-D vector matching type of tile slice.
326326
VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
327-
auto pad1DOp = rewriter.create<vector::SplatOp>(loc, tileSliceType, padOp);
327+
auto pad1DOp = rewriter.create<vector::BroadcastOp>(loc, tileSliceType, padOp);
328328

329329
auto loadSlice = rewriter.create<vector::MaskedLoadOp>(
330330
loc, tileSliceType, tileLoadOp.getBase(), memrefIndices, maskOp1D,

mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -792,7 +792,7 @@ createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op,
792792
op.getLoc(), vectorType.getElementType(),
793793
rewriter.getZeroAttr(vectorType.getElementType()));
794794
Value result =
795-
rewriter.create<vector::SplatOp>(op.getLoc(), fill, vectorType);
795+
rewriter.create<vector::BroadcastOp>(op.getLoc(), vectorType, fill);
796796

797797
bool isTransposeLoad = !op.getPermutationMap().isMinorIdentity();
798798

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

Lines changed: 63 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1408,7 +1408,7 @@ struct VectorScalableExtractOpLowering
14081408
/// ```
14091409
/// is rewritten into:
14101410
/// ```
1411-
/// %r = splat %f0: vector<2x4xf32>
1411+
/// %r = vector.broadcast %f0 : f32 to vector<2x4xf32>
14121412
/// %va = vector.extractvalue %a[0] : vector<2x4xf32>
14131413
/// %vb = vector.extractvalue %b[0] : vector<2x4xf32>
14141414
/// %vc = vector.extractvalue %c[0] : vector<2x4xf32>
@@ -1441,7 +1441,7 @@ class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
14411441
auto elemType = vType.getElementType();
14421442
Value zero = rewriter.create<arith::ConstantOp>(
14431443
loc, elemType, rewriter.getZeroAttr(elemType));
1444-
Value desc = rewriter.create<vector::SplatOp>(loc, vType, zero);
1444+
Value desc = rewriter.create<vector::BroadcastOp>(loc, vType, zero);
14451445
for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
14461446
Value extrLHS = rewriter.create<ExtractOp>(loc, op.getLhs(), i);
14471447
Value extrRHS = rewriter.create<ExtractOp>(loc, op.getRhs(), i);
@@ -1583,7 +1583,7 @@ class VectorCreateMaskOpConversion
15831583
/*isScalable=*/true));
15841584
auto bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType,
15851585
adaptor.getOperands()[0]);
1586-
Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
1586+
Value bounds = rewriter.create<BroadcastOp>(loc, indices.getType(), bound);
15871587
Value comp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
15881588
indices, bounds);
15891589
rewriter.replaceOp(op, comp);
@@ -1767,63 +1767,79 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
17671767
}
17681768
};
17691769

1770-
/// The Splat operation is lowered to an insertelement + a shufflevector
1771-
/// operation. Splat to only 0-d and 1-d vector result types are lowered.
1772-
struct VectorSplatOpLowering : public ConvertOpToLLVMPattern<vector::SplatOp> {
1773-
using ConvertOpToLLVMPattern<vector::SplatOp>::ConvertOpToLLVMPattern;
1770+
/// A broadcast of a scalar is lowered to an insertelement + a shufflevector
1771+
/// operation. Only broadcasts to 0-d and 1-d vectors are lowered by this
1772+
/// pattern, the higher rank cases are handled by another pattern.
1773+
struct VectorBroadcastScalarToLowRankLowering
1774+
: public ConvertOpToLLVMPattern<vector::BroadcastOp> {
1775+
using ConvertOpToLLVMPattern<vector::BroadcastOp>::ConvertOpToLLVMPattern;
17741776

17751777
LogicalResult
1776-
matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
1778+
matchAndRewrite(vector::BroadcastOp broadcast, OpAdaptor adaptor,
17771779
ConversionPatternRewriter &rewriter) const override {
1778-
VectorType resultType = cast<VectorType>(splatOp.getType());
1780+
1781+
if (isa<VectorType>(broadcast.getSourceType()))
1782+
return rewriter.notifyMatchFailure(
1783+
broadcast, "broadcast from vector type not handled");
1784+
1785+
VectorType resultType = broadcast.getType();
17791786
if (resultType.getRank() > 1)
1780-
return failure();
1787+
return rewriter.notifyMatchFailure(broadcast,
1788+
"broadcast to 2+-d handled elsewhere");
17811789

17821790
// First insert it into a poison vector so we can shuffle it.
1783-
auto vectorType = typeConverter->convertType(splatOp.getType());
1791+
auto vectorType = typeConverter->convertType(broadcast.getType());
17841792
Value poison =
1785-
rewriter.create<LLVM::PoisonOp>(splatOp.getLoc(), vectorType);
1793+
rewriter.create<LLVM::PoisonOp>(broadcast.getLoc(), vectorType);
17861794
auto zero = rewriter.create<LLVM::ConstantOp>(
1787-
splatOp.getLoc(),
1795+
broadcast.getLoc(),
17881796
typeConverter->convertType(rewriter.getIntegerType(32)),
17891797
rewriter.getZeroAttr(rewriter.getIntegerType(32)));
17901798

17911799
// For 0-d vector, we simply do `insertelement`.
17921800
if (resultType.getRank() == 0) {
17931801
rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
1794-
splatOp, vectorType, poison, adaptor.getInput(), zero);
1802+
broadcast, vectorType, poison, adaptor.getSource(), zero);
17951803
return success();
17961804
}
17971805

17981806
// For 1-d vector, we additionally do a `shufflevector`.
17991807
auto v = rewriter.create<LLVM::InsertElementOp>(
1800-
splatOp.getLoc(), vectorType, poison, adaptor.getInput(), zero);
1808+
broadcast.getLoc(), vectorType, poison, adaptor.getSource(), zero);
18011809

1802-
int64_t width = cast<VectorType>(splatOp.getType()).getDimSize(0);
1810+
int64_t width = cast<VectorType>(broadcast.getType()).getDimSize(0);
18031811
SmallVector<int32_t> zeroValues(width, 0);
18041812

18051813
// Shuffle the value across the desired number of elements.
1806-
rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(splatOp, v, poison,
1814+
rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(broadcast, v, poison,
18071815
zeroValues);
18081816
return success();
18091817
}
18101818
};
18111819

1812-
/// The Splat operation is lowered to an insertelement + a shufflevector
1813-
/// operation. Splat to only 2+-d vector result types are lowered by the
1814-
/// SplatNdOpLowering, the 1-d case is handled by SplatOpLowering.
1815-
struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
1816-
using ConvertOpToLLVMPattern<SplatOp>::ConvertOpToLLVMPattern;
1820+
/// The broadcast of a scalar is lowered to an insertelement + a shufflevector
1821+
/// operation. Only broadcasts to 2+-d vector result types are lowered by this
1822+
/// pattern, the 0-d and 1-d cases are handled by another pattern. Broadcasts from vectors
1823+
/// are not converted to LLVM, only broadcasts from scalars are.
1824+
struct VectorBroadcastScalarToNdLowering
1825+
: public ConvertOpToLLVMPattern<BroadcastOp> {
1826+
using ConvertOpToLLVMPattern<BroadcastOp>::ConvertOpToLLVMPattern;
18171827

18181828
LogicalResult
1819-
matchAndRewrite(SplatOp splatOp, OpAdaptor adaptor,
1829+
matchAndRewrite(BroadcastOp broadcast, OpAdaptor adaptor,
18201830
ConversionPatternRewriter &rewriter) const override {
1821-
VectorType resultType = splatOp.getType();
1831+
1832+
if (isa<VectorType>(broadcast.getSourceType()))
1833+
return rewriter.notifyMatchFailure(
1834+
broadcast, "broadcast from vector type not handled");
1835+
1836+
VectorType resultType = broadcast.getType();
18221837
if (resultType.getRank() <= 1)
1823-
return failure();
1838+
return rewriter.notifyMatchFailure(
1839+
broadcast, "broadcast to 1-d or 0-d handled elsewhere");
18241840

18251841
// First insert it into a poison vector so we can shuffle it.
1826-
auto loc = splatOp.getLoc();
1842+
auto loc = broadcast.getLoc();
18271843
auto vectorTypeInfo =
18281844
LLVM::detail::extractNDVectorTypeInfo(resultType, *getTypeConverter());
18291845
auto llvmNDVectorTy = vectorTypeInfo.llvmNDVectorTy;
@@ -1834,26 +1850,26 @@ struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
18341850
// Construct returned value.
18351851
Value desc = rewriter.create<LLVM::PoisonOp>(loc, llvmNDVectorTy);
18361852

1837-
// Construct a 1-D vector with the splatted value that we insert in all the
1838-
// places within the returned descriptor.
1853+
// Construct a 1-D vector with the broadcasted value that we insert in all
1854+
// the places within the returned descriptor.
18391855
Value vdesc = rewriter.create<LLVM::PoisonOp>(loc, llvm1DVectorTy);
18401856
auto zero = rewriter.create<LLVM::ConstantOp>(
18411857
loc, typeConverter->convertType(rewriter.getIntegerType(32)),
18421858
rewriter.getZeroAttr(rewriter.getIntegerType(32)));
18431859
Value v = rewriter.create<LLVM::InsertElementOp>(loc, llvm1DVectorTy, vdesc,
1844-
adaptor.getInput(), zero);
1860+
adaptor.getSource(), zero);
18451861

18461862
// Shuffle the value across the desired number of elements.
18471863
int64_t width = resultType.getDimSize(resultType.getRank() - 1);
18481864
SmallVector<int32_t> zeroValues(width, 0);
18491865
v = rewriter.create<LLVM::ShuffleVectorOp>(loc, v, v, zeroValues);
18501866

1851-
// Iterate of linear index, convert to coords space and insert splatted 1-D
1852-
// vector in each position.
1867+
// Iterate over the linear index, convert to coords space and insert broadcasted
1868+
// 1-D vector in each position.
18531869
nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayRef<int64_t> position) {
18541870
desc = rewriter.create<LLVM::InsertValueOp>(loc, desc, v, position);
18551871
});
1856-
rewriter.replaceOp(splatOp, desc);
1872+
rewriter.replaceOp(broadcast, desc);
18571873
return success();
18581874
}
18591875
};
@@ -2035,6 +2051,19 @@ struct VectorScalableStepOpLowering
20352051
}
20362052
};
20372053

2054+
/// Convert `vector.splat` to `vector.broadcast`. There is a path to LLVM from
2055+
/// `vector.broadcast` through other patterns.
2056+
struct VectorSplatToBroadcast : public ConvertOpToLLVMPattern<vector::SplatOp> {
2057+
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
2058+
LogicalResult
2059+
matchAndRewrite(vector::SplatOp splat, OpAdaptor adaptor,
2060+
ConversionPatternRewriter &rewriter) const override {
2061+
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splat, splat.getType(),
2062+
adaptor.getInput());
2063+
return success();
2064+
}
2065+
};
2066+
20382067
} // namespace
20392068

20402069
void mlir::vector::populateVectorRankReducingFMAPattern(
@@ -2063,7 +2092,8 @@ void mlir::populateVectorToLLVMConversionPatterns(
20632092
VectorInsertOpConversion, VectorPrintOpConversion,
20642093
VectorTypeCastOpConversion, VectorScaleOpConversion,
20652094
VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
2066-
VectorSplatOpLowering, VectorSplatNdOpLowering,
2095+
VectorSplatToBroadcast, VectorBroadcastScalarToLowRankLowering,
2096+
VectorBroadcastScalarToNdLowering,
20672097
VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
20682098
MaskedReductionOpConversion, VectorInterleaveOpLowering,
20692099
VectorDeinterleaveOpLowering, VectorFromElementsLowering,

mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -444,7 +444,7 @@ struct Strategy<TransferReadOp> {
444444
Location loc = xferOp.getLoc();
445445
auto bufferType = dyn_cast<ShapedType>(buffer.getType());
446446
auto vecType = dyn_cast<VectorType>(bufferType.getElementType());
447-
auto vec = b.create<vector::SplatOp>(loc, vecType, xferOp.getPadding());
447+
auto vec = b.create<vector::BroadcastOp>(loc, vecType, xferOp.getPadding());
448448
b.create<memref::StoreOp>(loc, vec, buffer, storeIndices);
449449

450450
return Value();
@@ -1261,7 +1261,7 @@ struct UnrollTransferReadConversion
12611261
if (auto insertOp = getInsertOp(xferOp))
12621262
return insertOp.getDest();
12631263
Location loc = xferOp.getLoc();
1264-
return rewriter.create<vector::SplatOp>(loc, xferOp.getVectorType(),
1264+
return rewriter.create<vector::BroadcastOp>(loc, xferOp.getVectorType(),
12651265
xferOp.getPadding());
12661266
}
12671267

@@ -1583,7 +1583,7 @@ struct Strategy1d<TransferReadOp> {
15831583
static Value initialLoopState(OpBuilder &b, TransferReadOp xferOp) {
15841584
// Inititalize vector with padding value.
15851585
Location loc = xferOp.getLoc();
1586-
return b.create<vector::SplatOp>(loc, xferOp.getVectorType(),
1586+
return b.create<vector::BroadcastOp>(loc, xferOp.getVectorType(),
15871587
xferOp.getPadding());
15881588
}
15891589
};

0 commit comments

Comments
 (0)