@@ -183,9 +183,9 @@ class VectorContractRewriter {
183183 Value acc;
184184
185185 // Conventional names for matrix dimensions.
186- int64_t M = 0 ;
187- int64_t N = 0 ;
188- int64_t K = 0 ;
186+ int64_t m = 0 ;
187+ int64_t n = 0 ;
188+ int64_t k = 0 ;
189189
190190 // Create the matrix mulitply and accumulate operation according to
191191 // `mmlaOp`.
@@ -286,41 +286,41 @@ Value VectorContractRewriter::lower(vector::ContractionOp op,
286286
287287 // Single-dimension vector type for the entire RHS tile.
288288
289- auto flatRhsTileType = VectorType::get (/* shape=*/ K * N , operandEltType,
289+ auto flatRhsTileType = VectorType::get (/* shape=*/ k * n , operandEltType,
290290 /* scalableDims=*/ {true });
291291
292292 // Vector type having the same number of elements as a row in the
293293 // accumulator/output tile and the same element type.
294- auto accRowTy = VectorType::get (/* shape=*/ N , resultEltType,
294+ auto accRowTy = VectorType::get (/* shape=*/ n , resultEltType,
295295 /* scalableDims=*/ {true });
296296
297297 // Vector type having twice the number of elements as a row in the
298298 // accumulator/output tile the same element type.
299- auto accRowX2Ty = VectorType::get (/* shape=*/ 2 * N , resultEltType,
299+ auto accRowX2Ty = VectorType::get (/* shape=*/ 2 * n , resultEltType,
300300 /* scalableDims=*/ {true });
301301 // Vector type having half the number of elements as a row in the
302302 // accumulator/output tile and an integer element type with twice the bit
303303 // width.
304- auto accRow64Ty = VectorType::get (/* shape=*/ N / 2 , rewriter.getI64Type (),
304+ auto accRow64Ty = VectorType::get (/* shape=*/ n / 2 , rewriter.getI64Type (),
305305 /* scalableDims=*/ {true });
306306 // Vector type having the same the number of elements as a row in the
307307 // accumulator/output tile and an integer element type with twice the bit
308308 // width.
309- auto accRowX264Ty = VectorType::get (/* shape=*/ N , rewriter.getI64Type (),
309+ auto accRowX264Ty = VectorType::get (/* shape=*/ n , rewriter.getI64Type (),
310310 /* scalableDims=*/ {true });
311311
312312 Location loc = op.getLoc ();
313313
314314 // Extract LHS sub-tiles with logical shape <2xK>.
315315 SmallVector<Value> lhsTile;
316- for (int64_t i = 0 ; i < M ; i += 2 ) {
316+ for (int64_t i = 0 ; i < m ; i += 2 ) {
317317 // Extract two consecutive rows of the LHS tile.
318318 auto r0 =
319319 vector::ExtractOp::create (rewriter, loc, lhs, ArrayRef<int64_t >{i});
320320 auto r1 =
321321 vector::ExtractOp::create (rewriter, loc, lhs, ArrayRef<int64_t >{i + 1 });
322322 // Concatenate to obtain a 2 x K x <input-type> flattened sub-tile.
323- SmallVector<int64_t > shuffleIdx (2 * K );
323+ SmallVector<int64_t > shuffleIdx (2 * k );
324324 std::iota (shuffleIdx.begin (), shuffleIdx.end (), 0 );
325325 auto t = vector::ShuffleOp::create (rewriter, loc, r0, r1, shuffleIdx);
326326 // Turn it into a scalable vector.
@@ -337,13 +337,13 @@ Value VectorContractRewriter::lower(vector::ContractionOp op,
337337
338338 // Extract the RHS sub-tiles with logical shape <Kx[2]>.
339339 SmallVector<Value> rhsTile;
340- for (int64_t j = 0 ; j < N ; j += 2 )
340+ for (int64_t j = 0 ; j < n ; j += 2 )
341341 rhsTile.push_back (vector::ScalableExtractOp::create (
342- rewriter, loc, flatRhsType, rhs, j * K ));
342+ rewriter, loc, flatRhsType, rhs, j * k ));
343343
344344 // Extract and pack the ACC sub-tiles.
345345 SmallVector<Value> accTile;
346- for (int64_t i = 0 ; i < M ; i += 2 ) {
346+ for (int64_t i = 0 ; i < m ; i += 2 ) {
347347 // Extract two consecutive rows of the accumulator tile.
348348 auto r0 = vector::ExtractOp::create (rewriter, loc, op.getAcc (),
349349 ArrayRef<int64_t >{i});
@@ -370,28 +370,28 @@ Value VectorContractRewriter::lower(vector::ContractionOp op,
370370 vector::BitCastOp::create (rewriter, loc, accRowX2Ty, intrI64);
371371 }
372372 // Extract ACC sub-tiles.
373- for (int64_t j = 0 ; j < N ; j += 2 )
373+ for (int64_t j = 0 ; j < n ; j += 2 )
374374 accTile.push_back (vector::ScalableExtractOp::create (
375375 rewriter, loc, flatAccType, accTileVec, j * 2 ));
376376 }
377377
378378 // Emit sub-tile matrix multiplications.
379379 SmallVector<Value> outTile;
380- for (int64_t i = 0 ; i < M / 2 ; ++i)
381- for (int64_t j = 0 ; j < N / 2 ; ++j) {
382- Value mmla = createMMLA (rewriter, loc, accTile[i * N / 2 + j], lhsTile[i],
380+ for (int64_t i = 0 ; i < m / 2 ; ++i)
381+ for (int64_t j = 0 ; j < n / 2 ; ++j) {
382+ Value mmla = createMMLA (rewriter, loc, accTile[i * n / 2 + j], lhsTile[i],
383383 rhsTile[j]);
384384 outTile.push_back (mmla);
385385 }
386386
387387 // Unpack the OUT sub-tiles and insert into the result.
388388 Value result = ub::PoisonOp::create (rewriter, loc, op.getResultType ());
389- for (int64_t i = 0 ; i < M / 2 ; ++i) {
389+ for (int64_t i = 0 ; i < m / 2 ; ++i) {
390390 // Collect a number of sub-tiles in a row.
391391 Value row = ub::PoisonOp::create (rewriter, loc, accRowX2Ty);
392- for (int64_t j = 0 ; j < N / 2 ; ++j)
392+ for (int64_t j = 0 ; j < n / 2 ; ++j)
393393 row = vector::ScalableInsertOp::create (
394- rewriter, loc, outTile[i * N / 2 + j], row, j * 4 );
394+ rewriter, loc, outTile[i * n / 2 + j], row, j * 4 );
395395
396396 // Unpack the row to obtain two rows of the output. If we have the out
397397 // sub-tiles transposed we obtain two consecutive output rows by
@@ -432,18 +432,18 @@ class VectorContractRewriterI8MM : public VectorContractRewriter {
432432 VectorType lhsType = op.getLhsType ();
433433 VectorType rhsType = op.getRhsType ();
434434
435- M = lhsType.getDimSize (0 );
436- N = rhsType.getDimSize (0 );
437- K = rhsType.getDimSize (1 );
435+ m = lhsType.getDimSize (0 );
436+ n = rhsType.getDimSize (0 );
437+ k = rhsType.getDimSize (1 );
438438
439439 // Check the operands have the expected shape:
440440 // * for LHS: fixed vector MxK
441441 // * for RHS: scalable vector [N]xK
442442 // * K == 8
443443 // * M and N even and at least 2
444444 if (lhsType.isScalable () || !rhsType.getScalableDims ()[0 ] ||
445- rhsType.getScalableDims ()[1 ] || lhsType.getDimSize (1 ) != K || K != 8 ||
446- M < 2 || M % 2 != 0 || N < 2 || N % 2 != 0 ||
445+ rhsType.getScalableDims ()[1 ] || lhsType.getDimSize (1 ) != k || k != 8 ||
446+ m < 2 || m % 2 != 0 || n < 2 || n % 2 != 0 ||
447447 !rhsType.getScalableDims ()[0 ])
448448 return rewriter.notifyMatchFailure (op, " non-matching operand shape" );
449449
@@ -504,18 +504,18 @@ class VectorContractRewriterBfloat : public VectorContractRewriter {
504504 VectorType lhsType = op.getLhsType ();
505505 VectorType rhsType = op.getRhsType ();
506506
507- M = lhsType.getDimSize (0 );
508- N = rhsType.getDimSize (0 );
509- K = rhsType.getDimSize (1 );
507+ m = lhsType.getDimSize (0 );
508+ n = rhsType.getDimSize (0 );
509+ k = rhsType.getDimSize (1 );
510510
511511 // Check the operands have the expected shape:
512512 // * for LHS: fixed vector MxK
513513 // * for RHS: scalable vector [N]xK
514514 // * K == 4
515515 // * M and N even and at least 2
516516 if (lhsType.isScalable () || !rhsType.getScalableDims ()[0 ] ||
517- rhsType.getScalableDims ()[1 ] || lhsType.getDimSize (1 ) != K || K != 4 ||
518- M < 2 || M % 2 != 0 || N < 2 || N % 2 != 0 ||
517+ rhsType.getScalableDims ()[1 ] || lhsType.getDimSize (1 ) != k || k != 4 ||
518+ m < 2 || m % 2 != 0 || n < 2 || n % 2 != 0 ||
519519 !rhsType.getScalableDims ()[0 ])
520520 return rewriter.notifyMatchFailure (op, " non-matching operand shape" );
521521
0 commit comments