12 changes: 12 additions & 0 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2919,6 +2919,8 @@ def Vector_SplatOp : Vector_Op<"splat", [
]> {
let summary = "vector splat or broadcast operation";
let description = [{
+    Note: This operation is deprecated. Please use vector.broadcast.
+
Broadcast the operand to all elements of the result vector. The type of the
operand must match the element type of the vector type.

@@ -2928,6 +2930,13 @@ def Vector_SplatOp : Vector_Op<"splat", [
%s = arith.constant 10.1 : f32
%t = vector.splat %s : vector<8x16xf32>
```

+    This operation is deprecated; the preferred representation of the above is:
+
+    ```mlir
+    %s = arith.constant 10.1 : f32
+    %t = vector.broadcast %s : f32 to vector<8x16xf32>
+    ```
}];

let arguments = (ins AnyType:$input);
@@ -2939,6 +2948,9 @@ def Vector_SplatOp : Vector_Op<"splat", [
let assemblyFormat = "$input attr-dict `:` type($aggregate)";

let hasFolder = 1;

+  // As vector.splat is deprecated, it is canonicalized to vector.broadcast.
+  let hasCanonicalizer = 1;
}

//===----------------------------------------------------------------------===//
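Since `hasCanonicalizer = 1` makes ODS declare a `getCanonicalizationPatterns` hook on the op, here is a minimal sketch of that hook and its pattern. It assumes the canonicalization is the same one-to-one splat-to-broadcast rewrite as the `ConvertSplatToBroadcast` pattern later in this diff; the pattern name and wiring below are illustrative, not the in-tree implementation.

```cpp
// Sketch only: the real implementation lives alongside the op definitions;
// the pattern name below is illustrative.
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

namespace {
/// Rewrite `vector.splat` into the equivalent scalar-to-vector
/// `vector.broadcast`, mirroring the ConvertSplatToBroadcast pattern
/// shown later in this diff.
struct SplatToBroadcast : public OpRewritePattern<vector::SplatOp> {
  using OpRewritePattern<vector::SplatOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::SplatOp splatOp,
                                PatternRewriter &rewriter) const override {
    // Same result type, same scalar operand: a direct 1:1 replacement.
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
        splatOp, splatOp.getType(), splatOp.getInput());
    return success();
  }
};
} // namespace

// ODS declares this hook on the op when `hasCanonicalizer = 1` is set.
void vector::SplatOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<SplatToBroadcast>(context);
}
```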
25 changes: 14 additions & 11 deletions mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -153,7 +153,7 @@ ExtFOnFloat8RewritePattern::matchAndRewrite(arith::ExtFOp op,

if (inVecType.getShape().empty()) {
Value zerodSplat =
-        rewriter.createOrFold<vector::SplatOp>(loc, outType, zero);
+        rewriter.createOrFold<vector::BroadcastOp>(loc, outType, zero);
Value scalarIn =
rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{});
Value scalarExt =
@@ -166,7 +166,7 @@ ExtFOnFloat8RewritePattern::matchAndRewrite(arith::ExtFOp op,

VectorType flatTy = VectorType::get(SmallVector<int64_t>{numElements},
outType.getElementType());
-  Value result = rewriter.createOrFold<vector::SplatOp>(loc, flatTy, zero);
+  Value result = rewriter.createOrFold<vector::BroadcastOp>(loc, flatTy, zero);

if (inVecType.getRank() > 1) {
inVecType = VectorType::get(SmallVector<int64_t>{numElements},
@@ -315,7 +315,7 @@ TruncFToFloat8RewritePattern::matchAndRewrite(arith::TruncFOp op,

VectorType flatTy = VectorType::get(SmallVector<int64_t>{numElements},
outVecType.getElementType());
-  Value result = rewriter.createOrFold<vector::SplatOp>(loc, flatTy, zero);
+  Value result = rewriter.createOrFold<vector::BroadcastOp>(loc, flatTy, zero);

if (inVectorTy.getRank() > 1) {
inVectorTy = VectorType::get(SmallVector<int64_t>{numElements},
@@ -383,7 +383,8 @@ LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite(
int64_t numElements = outVecType.getNumElements();
Value zero = rewriter.createOrFold<arith::ConstantOp>(
loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
-  Value result = rewriter.createOrFold<vector::SplatOp>(loc, outVecType, zero);
+  Value result =
+      rewriter.createOrFold<vector::BroadcastOp>(loc, outVecType, zero);

if (inVectorTy.getRank() > 1) {
inVectorTy = VectorType::get(SmallVector<int64_t>{numElements},
@@ -478,8 +479,8 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
VectorType extScaleResultType = VectorType::get(opWidth, outType);

if (!outVecType) {
-    Value inCast =
-        rewriter.create<vector::SplatOp>(loc, VectorType::get(1, inType), in);
+    Value inCast = rewriter.create<vector::BroadcastOp>(
+        loc, VectorType::get(1, inType), in);
// TODO: replace this with non-packed ScaledExtOp
Value scaleExt = rewriter.create<amdgpu::ScaledExtPackedOp>(
loc, extScaleResultType, inCast, scale, 0);
@@ -509,7 +510,8 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,

Value zero = rewriter.create<arith::ConstantOp>(
loc, outType, rewriter.getFloatAttr(outType, 0.0));
-  Value result = rewriter.createOrFold<vector::SplatOp>(loc, outVecType, zero);
+  Value result =
+      rewriter.createOrFold<vector::BroadcastOp>(loc, outVecType, zero);

for (SmallVector<int64_t> offsets : StaticTileOffsetRange(inShape, ratio)) {
SmallVector<int64_t> strides(offsets.size(), 1);
@@ -523,7 +525,7 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,

VectorType blockResultType = VectorType::get(blockSize, outType);
Value blockResult =
-        rewriter.createOrFold<vector::SplatOp>(loc, blockResultType, zero);
+        rewriter.createOrFold<vector::BroadcastOp>(loc, blockResultType, zero);

for (int64_t i = 0, sliceWidth = std::min(opWidth, blockSize - i);
i < blockSize;
@@ -587,7 +589,7 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,

if (!outVecType) {
Type inVecType = VectorType::get(1, inType);
-    Value inCast = rewriter.create<vector::SplatOp>(loc, inVecType, in);
+    Value inCast = rewriter.create<vector::BroadcastOp>(loc, inVecType, in);
// TODO: replace this with non-packed ScaledTruncOp
Value scaleTrunc = rewriter.create<amdgpu::PackedScaledTruncOp>(
loc, truncScaleResultType, inCast, scale, 0, /*existing=*/nullptr);
@@ -616,7 +618,8 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,

int64_t blockSize = computeProduct(ratio);

-  Value result = rewriter.createOrFold<vector::SplatOp>(loc, outVecType, zero);
+  Value result =
+      rewriter.createOrFold<vector::BroadcastOp>(loc, outVecType, zero);

for (SmallVector<int64_t> offsets : StaticTileOffsetRange(inShape, ratio)) {
SmallVector<int64_t> strides(offsets.size(), 1);
@@ -630,7 +633,7 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,

VectorType blockResultType = VectorType::get(blockSize, outType);
Value blockResult =
-        rewriter.createOrFold<vector::SplatOp>(loc, blockResultType, zero);
+        rewriter.createOrFold<vector::BroadcastOp>(loc, blockResultType, zero);

for (int64_t i = 0, sliceWidth = std::min(opWidth, blockSize - i);
i < blockSize;
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
@@ -604,7 +604,7 @@ struct InsertTileSliceConversion
rewriter.getIntegerAttr(rewriter.getI1Type(), 1));
auto predTy = VectorType::get(tileType.getShape()[0], rewriter.getI1Type(),
/*scalableDims=*/{true});
-    auto allActiveMask = rewriter.create<vector::SplatOp>(loc, predTy, one);
+    auto allActiveMask = rewriter.create<vector::BroadcastOp>(loc, predTy, one);

// Create 'arm_sme.intr.write.(horiz|vert)' to write vector to tile slice.
switch (insertTileSliceOp.getLayout()) {
3 changes: 2 additions & 1 deletion mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -324,7 +324,8 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion

// Splat pad into 1-D vector matching type of tile slice.
VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
-    auto pad1DOp = rewriter.create<vector::SplatOp>(loc, tileSliceType, padOp);
+    auto pad1DOp =
+        rewriter.create<vector::BroadcastOp>(loc, tileSliceType, padOp);

auto loadSlice = rewriter.create<vector::MaskedLoadOp>(
loc, tileSliceType, tileLoadOp.getBase(), memrefIndices, maskOp1D,
76 changes: 15 additions & 61 deletions mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -255,66 +255,6 @@ struct BroadcastOpToArmSMELowering
}
};

-/// Conversion pattern for vector.splat.
-///
-/// Example:
-///
-///   %splat_to_tile = vector.splat %src : i32 to vector<[4]x[4]xi32>
-///
-/// is converted to:
-///
-///   %broadcast_to_1d = vector.broadcast %src : i32 to vector<[4]xi32>
-///   %broadcast_to_tile = scf.for %tile_slice_index = %c0 to %num_tile_slices
-///       step %c1 iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>)
-///   {
-///     %tile_update = arm_sme.insert_tile_slice
-///       %broadcast_to_1d, %iter_tile[%tile_slice_index] :
-///       vector<[4]xi32> into vector<[4]x[4]xi32>
-///     scf.yield %tile_update : vector<[4]x[4]xi32>
-///   }
-///
-/// This is identical to vector.broadcast of a scalar.
-struct SplatOpToArmSMELowering : public OpRewritePattern<vector::SplatOp> {
-  using OpRewritePattern<vector::SplatOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(vector::SplatOp splatOp,
-                                PatternRewriter &rewriter) const final {
-    auto tileType = splatOp.getResult().getType();
-    if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
-      return failure();
-
-    auto loc = splatOp.getLoc();
-    auto srcType = splatOp.getOperand().getType();
-
-    assert(srcType.isIntOrFloat() && "Invalid source type for vector.splat");
-    // Avoid unused-variable warning when building without assertions.
-    (void)srcType;
-
-    // First, broadcast the scalar to a 1-d vector.
-    VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
-    Value broadcastOp1D = rewriter.create<vector::BroadcastOp>(
-        loc, tileSliceType, splatOp.getInput());
-
-    auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
-
-    auto makeLoopBody = [&](OpBuilder &b, Location loc, Value tileSliceIndex,
-                            Value currentTile) {
-      auto nextTile = b.create<arm_sme::InsertTileSliceOp>(
-          loc, tileType, broadcastOp1D, currentTile, tileSliceIndex);
-      return nextTile.getResult();
-    };
-
-    // Next, create a loop over ZA tile slices and "move" the generated 1-d
-    // vector to each slice.
-    auto forOp =
-        createLoopOverTileSlices(rewriter, loc, initTile, makeLoopBody);
-
-    rewriter.replaceOp(splatOp, forOp.getResult(0));
-
-    return success();
-  }
-};

/// Conversion pattern for vector.transpose.
///
/// Stores the input tile to memory and reloads vertically.
@@ -790,11 +730,25 @@ struct ExtractFromCreateMaskToPselLowering
}
};

+// Convert all `vector.splat` to `vector.broadcast`. There is a path from
+// `vector.broadcast` to ArmSME via another pattern.
+struct ConvertSplatToBroadcast : public OpRewritePattern<vector::SplatOp> {
+  using OpRewritePattern<vector::SplatOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::SplatOp splatOp,
+                                PatternRewriter &rewriter) const final {
+
+    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splatOp, splatOp.getType(),
+                                                     splatOp.getInput());
+    return success();
+  }
+};

} // namespace

void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
MLIRContext &ctx) {
-  patterns.add<BroadcastOpToArmSMELowering, SplatOpToArmSMELowering,
+  patterns.add<BroadcastOpToArmSMELowering, ConvertSplatToBroadcast,
TransferReadToArmSMELowering, TransferWriteToArmSMELowering,
TransposeOpToArmSMELowering, VectorLoadToArmSMELowering,
VectorStoreToArmSMELowering, VectorOuterProductToArmSMELowering,
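As a usage note, here is a hedged sketch of how a pass might drive these patterns once `ConvertSplatToBroadcast` is registered. The helper name and the choice of the greedy driver are assumptions for illustration; the actual `ConvertVectorToArmSME` pass may be wired differently.

```cpp
// Illustrative only: assumes the standard greedy rewrite driver.
#include "mlir/Conversion/VectorToArmSME/VectorToArmSME.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Hypothetical helper showing one way to apply the pattern set.
static LogicalResult lowerVectorToArmSME(Operation *root) {
  RewritePatternSet patterns(root->getContext());
  populateVectorToArmSMEPatterns(patterns, *root->getContext());
  // With ConvertSplatToBroadcast registered, any vector.splat is first
  // normalized to vector.broadcast, which then lowers to ArmSME via
  // BroadcastOpToArmSMELowering.
  return applyPatternsAndFoldGreedily(root, std::move(patterns));
}
```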
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -792,7 +792,7 @@ createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op,
op.getLoc(), vectorType.getElementType(),
rewriter.getZeroAttr(vectorType.getElementType()));
Value result =
-      rewriter.create<vector::SplatOp>(op.getLoc(), fill, vectorType);
+      rewriter.create<vector::BroadcastOp>(op.getLoc(), vectorType, fill);

bool isTransposeLoad = !op.getPermutationMap().isMinorIdentity();

Expand Down