Skip to content

Commit 64d24ef

Browse files
committed
ArmSME lowering and final removals
1 parent ff52343 commit 64d24ef

File tree

3 files changed

+17
-63
lines changed

3 files changed

+17
-63
lines changed

mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp

Lines changed: 15 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -255,66 +255,6 @@ struct BroadcastOpToArmSMELowering
255255
}
256256
};
257257

258-
/// Conversion pattern for vector.splat.
259-
///
260-
/// Example:
261-
///
262-
/// %splat_to_tile = vector.splat %src : i32 to vector<[4]x[4]xi32>
263-
///
264-
/// is converted to:
265-
///
266-
/// %broadcast_to_1d = vector.broadcast %src : i32 to vector<[4]xi32>
267-
/// %broadcast_to_tile = scf.for %tile_slice_index = %c0 to %num_tile_slices
268-
/// step %c1 iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>)
269-
/// {
270-
/// %tile_update = arm_sme.insert_tile_slice
271-
/// %broadcast_to_1d, %iter_tile[%tile_slice_index] :
272-
/// vector<[4]xi32> into vector<[4]x[4]xi32>
273-
/// scf.yield %tile_update : vector<[4]x[4]xi32>
274-
/// }
275-
///
276-
/// This is identical to vector.broadcast of a scalar.
277-
struct SplatOpToArmSMELowering : public OpRewritePattern<vector::SplatOp> {
278-
using OpRewritePattern<vector::SplatOp>::OpRewritePattern;
279-
280-
LogicalResult matchAndRewrite(vector::SplatOp splatOp,
281-
PatternRewriter &rewriter) const final {
282-
auto tileType = splatOp.getResult().getType();
283-
if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
284-
return failure();
285-
286-
auto loc = splatOp.getLoc();
287-
auto srcType = splatOp.getOperand().getType();
288-
289-
assert(srcType.isIntOrFloat() && "Invalid source type for vector.splat");
290-
// Avoid unused-variable warning when building without assertions.
291-
(void)srcType;
292-
293-
// First, broadcast the scalar to a 1-d vector.
294-
VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
295-
Value broadcastOp1D = rewriter.create<vector::BroadcastOp>(
296-
loc, tileSliceType, splatOp.getInput());
297-
298-
auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
299-
300-
auto makeLoopBody = [&](OpBuilder &b, Location loc, Value tileSliceIndex,
301-
Value currentTile) {
302-
auto nextTile = b.create<arm_sme::InsertTileSliceOp>(
303-
loc, tileType, broadcastOp1D, currentTile, tileSliceIndex);
304-
return nextTile.getResult();
305-
};
306-
307-
// Next, create a loop over ZA tile slices and "move" the generated 1-d
308-
// vector to each slice.
309-
auto forOp =
310-
createLoopOverTileSlices(rewriter, loc, initTile, makeLoopBody);
311-
312-
rewriter.replaceOp(splatOp, forOp.getResult(0));
313-
314-
return success();
315-
}
316-
};
317-
318258
/// Conversion pattern for vector.transpose.
319259
///
320260
/// Stores the input tile to memory and reloads vertically.
@@ -790,11 +730,25 @@ struct ExtractFromCreateMaskToPselLowering
790730
}
791731
};
792732

733+
// Convert all `vector.splat` to `vector.broadcast`. There is a path from
734+
// `vector.broadcast` to ArmSME via another pattern.
735+
struct ConvertSplatToBroadcast : public OpRewritePattern<vector::SplatOp> {
736+
using OpRewritePattern<vector::SplatOp>::OpRewritePattern;
737+
738+
LogicalResult matchAndRewrite(vector::SplatOp splatOp,
739+
PatternRewriter &rewriter) const final {
740+
741+
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splatOp, splatOp.getType(),
742+
splatOp.getInput());
743+
return success();
744+
}
745+
};
746+
793747
} // namespace
794748

795749
void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
796750
MLIRContext &ctx) {
797-
patterns.add<BroadcastOpToArmSMELowering, SplatOpToArmSMELowering,
751+
patterns.add<BroadcastOpToArmSMELowering, ConvertSplatToBroadcast,
798752
TransferReadToArmSMELowering, TransferWriteToArmSMELowering,
799753
TransposeOpToArmSMELowering, VectorLoadToArmSMELowering,
800754
VectorStoreToArmSMELowering, VectorOuterProductToArmSMELowering,

mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ void mlir::arith::populateEmulateUnsupportedFloatsLegality(
123123
vector::FMAOp, vector::OuterProductOp, vector::MatmulOp, vector::ScanOp>(
124124
[&](Operation *op) { return converter.isLegal(op); });
125125
target.addLegalOp<arith::BitcastOp, arith::ExtFOp, arith::TruncFOp,
126-
arith::ConstantOp, vector::SplatOp>();
126+
arith::ConstantOp, vector::SplatOp, vector::BroadcastOp>();
127127
}
128128

129129
void EmulateUnsupportedFloatsPass::runOnOperation() {

mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -303,7 +303,7 @@ class DecomposeNDExtractStridedSlice
303303
// Extract/insert on a lower ranked extract strided slice op.
304304
Value zero = rewriter.create<arith::ConstantOp>(
305305
loc, elemType, rewriter.getZeroAttr(elemType));
306-
Value res = rewriter.create<SplatOp>(loc, dstType, zero);
306+
Value res = rewriter.create<BroadcastOp>(loc, dstType, zero);
307307
for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
308308
off += stride, ++idx) {
309309
Value one = rewriter.create<ExtractOp>(loc, op.getVector(), off);

0 commit comments

Comments
 (0)