2727#define DEBUG_TYPE " lower-contract-to-arm-sve-i8mm"
2828
2929using namespace mlir ;
30- using namespace mlir ::arm_sve;
3130
3231namespace {
3332// Get the operand of a `vector.contract`. This function is intended to abstract
3433// away from the particular way a value is extended before feeding it into the
3534// `vector.contract` - via zero-extend or an explicit or implicit sign-extend
3635// (for implicit sign-extension see `vector.contract` documentation).
3736//
38- // The template parameter `Op` indicates the extension operation (explicir or
37+ // The template parameter `Op` indicates the extension operation (explicit or
3938// implicit) for which we are checking.
4039//
4140// Return success only for extensions from `i8` to `i32`.
@@ -59,7 +58,7 @@ std::optional<Value> getExtOperand(Value v, Type i8Ty, Type i32Ty) {
5958 }
6059
6160 // If the operand is defined by an explicit extend operation of the accepted
62- // operation type, check it's extented from `i8` to `i32`.
61+ // operation type, check it's extended from `i8` to `i32`.
6362 auto inOp = extOp.getIn ();
6463 auto inTy = dyn_cast<VectorType>(inOp.getType ());
6564 if (!inTy || inTy.getElementType () != i8Ty)
@@ -81,7 +80,7 @@ enum class MMLA {
8180 MixedSwapped // usmmla with LHS and RHS swapped
8281};
8382
84- // Create the matrix multply and accumulate operation according to `op`.
83+ // Create the matrix mulitply and accumulate operation according to `op`.
8584Value createMMLA (PatternRewriter &rewriter, MMLA op, Location loc,
8685 mlir::VectorType accType, Value acc, Value lhs, Value rhs) {
8786 switch (op) {
@@ -128,7 +127,7 @@ Value createMMLA(PatternRewriter &rewriter, MMLA op, Location loc,
128127// / The LHS operand is initially unpacked into M/2 values, each representing a
129128// / sub-tile with dimensions <2x8>, and then each such sub-tile is replicated
130129// / VSCALE times.
131- // / Multiplying thus replicated LHS sub-tile by the corresposponing RHS sub-tile
130+ // / Multiplying thus replicated LHS sub-tile by the corresponding RHS sub-tile
132131// / correctly computes an entire result sub-tile.
133132class LowerContractionToSVEI8MMPattern
134133 : public OpRewritePattern<vector::ContractionOp> {
@@ -235,7 +234,7 @@ class LowerContractionToSVEI8MMPattern
235234 // Extract LHS sub-tiles with logicall shape <2x8>.
236235 SmallVector<Value> lhsTile;
237236 for (int64_t i = 0 ; i < M; i += 2 ) {
238- // Exract two consective rows of the LHS tile.
237+ // Extract two consecutive rows of the LHS tile.
239238 auto r0 = rewriter.create <vector::ExtractOp>(loc, *maybeLhs,
240239 ArrayRef<int64_t >{i});
241240 auto r1 = rewriter.create <vector::ExtractOp>(loc, *maybeLhs,
@@ -293,17 +292,16 @@ class LowerContractionToSVEI8MMPattern
293292 } else {
294293 // Bitcast them to 64-bit elements, so subsequent
295294 // interleave/deinterleave work on pairs of 32-bit numbers.
296- auto r0_i64 = rewriter.create <vector::BitCastOp>(loc, accRow64Ty, r0);
297- auto r1_i64 = rewriter.create <vector::BitCastOp>(loc, accRow64Ty, r1);
295+ auto r0I64 = rewriter.create <vector::BitCastOp>(loc, accRow64Ty, r0);
296+ auto r1I64 = rewriter.create <vector::BitCastOp>(loc, accRow64Ty, r1);
298297
299298 // Interleave the rows, effectively flattening each 2x2 tile into 4
300299 // consecutive elements.
301- auto intr_i64 =
302- rewriter.create <vector::InterleaveOp>(loc, r0_i64, r1_i64);
300+ auto intrI64 = rewriter.create <vector::InterleaveOp>(loc, r0I64, r1I64);
303301
304302 // Bitcast back to 32-bit elements.
305303 accTileVec =
306- rewriter.create <vector::BitCastOp>(loc, accRowX2Ty, intr_i64 );
304+ rewriter.create <vector::BitCastOp>(loc, accRowX2Ty, intrI64 );
307305 }
308306 // Extract ACC sub-tiles.
309307 for (int64_t j = 0 ; j < N; j += 2 )
0 commit comments