Skip to content

Commit 24c9a9f

Browse files
[fixup] Minor style/spelling fixes
1 parent 0ddabca commit 24c9a9f

File tree

2 files changed

+10
-12
lines changed

2 files changed

+10
-12
lines changed

mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ class LowerContractionToSMMLAPattern
5757
if (!lhsType.hasRank() || !rhsType.hasRank() || rhsType.getRank() < 2)
5858
return failure();
5959
// This codegen does not work for scalable vectors. Return failure so this
60-
// pattern not accidentally chosen over patterns that lower to ArmSVE.
60+
// pattern is not accidentally chosen over patterns that lower to ArmSVE.
6161
if (lhsType.isScalable() || rhsType.isScalable())
6262
return failure();
6363
auto dimM = lhsType.getRank() == 1 ? 1 : lhsType.getDimSize(0);

mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,14 @@
2727
#define DEBUG_TYPE "lower-contract-to-arm-sve-i8mm"
2828

2929
using namespace mlir;
30-
using namespace mlir::arm_sve;
3130

3231
namespace {
3332
// Get the operand of a `vector.contract`. This function is intended to abstract
3433
// away from the particular way a value is extended before feeding it into the
3534
// `vector.contract` - via zero-extend or an explicit or implicit sign-extend
3635
// (for implicit sign-extension see `vector.contract` documentation).
3736
//
38-
// The template parameter `Op` indicates the extension operation (explicir or
37+
// The template parameter `Op` indicates the extension operation (explicit or
3938
// implicit) for which we are checking.
4039
//
4140
// Return success only for extensions from `i8` to `i32`.
@@ -59,7 +58,7 @@ std::optional<Value> getExtOperand(Value v, Type i8Ty, Type i32Ty) {
5958
}
6059

6160
// If the operand is defined by an explicit extend operation of the accepted
62-
// operation type, check it's extented from `i8` to `i32`.
61+
// operation type, check it's extended from `i8` to `i32`.
6362
auto inOp = extOp.getIn();
6463
auto inTy = dyn_cast<VectorType>(inOp.getType());
6564
if (!inTy || inTy.getElementType() != i8Ty)
@@ -81,7 +80,7 @@ enum class MMLA {
8180
MixedSwapped // usmmla with LHS and RHS swapped
8281
};
8382

84-
// Create the matrix multply and accumulate operation according to `op`.
83+
// Create the matrix mulitply and accumulate operation according to `op`.
8584
Value createMMLA(PatternRewriter &rewriter, MMLA op, Location loc,
8685
mlir::VectorType accType, Value acc, Value lhs, Value rhs) {
8786
switch (op) {
@@ -128,7 +127,7 @@ Value createMMLA(PatternRewriter &rewriter, MMLA op, Location loc,
128127
/// The LHS operand is initially unpacked into M/2 values, each representing a
129128
/// sub-tile with dimensions <2x8>, and then each such sub-tile is replicated
130129
/// VSCALE times.
131-
/// Multiplying thus replicated LHS sub-tile by the corresposponing RHS sub-tile
130+
/// Multiplying thus replicated LHS sub-tile by the corresponding RHS sub-tile
132131
/// correctly computes an entire result sub-tile.
133132
class LowerContractionToSVEI8MMPattern
134133
: public OpRewritePattern<vector::ContractionOp> {
@@ -235,7 +234,7 @@ class LowerContractionToSVEI8MMPattern
235234
// Extract LHS sub-tiles with logicall shape <2x8>.
236235
SmallVector<Value> lhsTile;
237236
for (int64_t i = 0; i < M; i += 2) {
238-
// Exract two consective rows of the LHS tile.
237+
// Extract two consecutive rows of the LHS tile.
239238
auto r0 = rewriter.create<vector::ExtractOp>(loc, *maybeLhs,
240239
ArrayRef<int64_t>{i});
241240
auto r1 = rewriter.create<vector::ExtractOp>(loc, *maybeLhs,
@@ -293,17 +292,16 @@ class LowerContractionToSVEI8MMPattern
293292
} else {
294293
// Bitcast them to 64-bit elements, so subsequent
295294
// interleave/deinterleave work on pairs of 32-bit numbers.
296-
auto r0_i64 = rewriter.create<vector::BitCastOp>(loc, accRow64Ty, r0);
297-
auto r1_i64 = rewriter.create<vector::BitCastOp>(loc, accRow64Ty, r1);
295+
auto r0I64 = rewriter.create<vector::BitCastOp>(loc, accRow64Ty, r0);
296+
auto r1I64 = rewriter.create<vector::BitCastOp>(loc, accRow64Ty, r1);
298297

299298
// Interleave the rows, effectively flattening each 2x2 tile into 4
300299
// consecutive elements.
301-
auto intr_i64 =
302-
rewriter.create<vector::InterleaveOp>(loc, r0_i64, r1_i64);
300+
auto intrI64 = rewriter.create<vector::InterleaveOp>(loc, r0I64, r1I64);
303301

304302
// Bitcast back to 32-bit elements.
305303
accTileVec =
306-
rewriter.create<vector::BitCastOp>(loc, accRowX2Ty, intr_i64);
304+
rewriter.create<vector::BitCastOp>(loc, accRowX2Ty, intrI64);
307305
}
308306
// Extract ACC sub-tiles.
309307
for (int64_t j = 0; j < N; j += 2)

0 commit comments

Comments
 (0)