@@ -508,7 +508,7 @@ LinearLayout DPAStoLinearLayout(ArrayRef<int64_t> shape, Attribute layout,
508508 int systolicDepth = dpas.getSystolicDepth ();
509509 int repeatCount = dpas.getRepeatCount ();
510510 int executionSize = dpas.getExecutionSize ();
511- unsigned dimK, dimNonK ;
511+ unsigned KDim, nonKDim ;
512512 if (opIdx == 0 ) { // Operand A
513513 auto regBasesA = DPASRegBasesA (opsPerChannel, repeatCount, threadsPerWarp,
514514 systolicDepth);
@@ -517,16 +517,16 @@ LinearLayout DPAStoLinearLayout(ArrayRef<int64_t> shape, Attribute layout,
517517 tileLayout = LinearLayout ({{kRegister , regBasesA}, {kLane , laneBasesA}},
518518 ArrayRef (outDimNames).take_back (2 ));
519519 // A only repeats by repCluster[rank - 2]
520- dimNonK = rank - 2 ;
521- dimK = rank - 1 ;
522- tileLayout *= LinearLayout::identity1D (repCluster[dimNonK ], kRegister ,
523- outDimNames[dimNonK ]);
520+ nonKDim = rank - 2 ;
521+ KDim = rank - 1 ;
522+ tileLayout *= LinearLayout::identity1D (repCluster[nonKDim ], kRegister ,
523+ outDimNames[nonKDim ]);
524524
525525 // K-dimension is shared among warps
526526 tileLayout *=
527- LinearLayout::zeros1D (warpsPerCTA[dimK ], kWarp , outDimNames[dimK ]);
528- tileLayout *= LinearLayout::identity1D (warpsPerCTA[dimNonK ], kWarp ,
529- outDimNames[dimNonK ]);
527+ LinearLayout::zeros1D (warpsPerCTA[KDim ], kWarp , outDimNames[KDim ]);
528+ tileLayout *= LinearLayout::identity1D (warpsPerCTA[nonKDim ], kWarp ,
529+ outDimNames[nonKDim ]);
530530 if (rank == 3 )
531531 tileLayout *=
532532 LinearLayout::identity1D (warpsPerCTA[0 ], kWarp , outDimNames[0 ]);
@@ -539,16 +539,16 @@ LinearLayout DPAStoLinearLayout(ArrayRef<int64_t> shape, Attribute layout,
539539 tileLayout = LinearLayout ({{kRegister , regBasesB}, {kLane , laneBasesB}},
540540 ArrayRef (outDimNames).take_back (2 ));
541541 // B only repeats by repCluster[rank - 1]
542- dimNonK = rank - 1 ;
543- dimK = rank - 2 ;
544- tileLayout *= LinearLayout::identity1D (repCluster[dimNonK ], kRegister ,
545- outDimNames[dimNonK ]);
542+ nonKDim = rank - 1 ;
543+ KDim = rank - 2 ;
544+ tileLayout *= LinearLayout::identity1D (repCluster[nonKDim ], kRegister ,
545+ outDimNames[nonKDim ]);
546546
547547 // K-dimension is shared among warps
548- tileLayout *= LinearLayout::identity1D (warpsPerCTA[dimNonK ], kWarp ,
549- outDimNames[dimNonK ]);
548+ tileLayout *= LinearLayout::identity1D (warpsPerCTA[nonKDim ], kWarp ,
549+ outDimNames[nonKDim ]);
550550 tileLayout *=
551- LinearLayout::zeros1D (warpsPerCTA[dimK ], kWarp , outDimNames[dimK ]);
551+ LinearLayout::zeros1D (warpsPerCTA[KDim ], kWarp , outDimNames[KDim ]);
552552 if (rank == 3 )
553553 tileLayout *=
554554 LinearLayout::identity1D (warpsPerCTA[0 ], kWarp , outDimNames[0 ]);
@@ -561,18 +561,18 @@ LinearLayout DPAStoLinearLayout(ArrayRef<int64_t> shape, Attribute layout,
561561 // The per-inst layout is repeated at each repCluster.
562562 // Hence, multiply with the identity layouts starting from the
563563 // least significant dimension.
564- dimNonK = rank - 2 ;
565- dimK = rank - 1 ;
566- tileLayout *= LinearLayout::identity1D (repCluster[dimK ], kRegister ,
567- outDimNames[dimK ]);
568- tileLayout *= LinearLayout::identity1D (repCluster[dimNonK ], kRegister ,
569- outDimNames[dimNonK ]);
564+ nonKDim = rank - 2 ;
565+ KDim = rank - 1 ;
566+ tileLayout *= LinearLayout::identity1D (repCluster[KDim ], kRegister ,
567+ outDimNames[KDim ]);
568+ tileLayout *= LinearLayout::identity1D (repCluster[nonKDim ], kRegister ,
569+ outDimNames[nonKDim ]);
570570
571571 // // The identical layout is repeated among warps
572572 tileLayout *=
573- LinearLayout::identity1D (warpsPerCTA[dimK ], kWarp , outDimNames[dimK ]);
574- tileLayout *= LinearLayout::identity1D (warpsPerCTA[dimNonK ], kWarp ,
575- outDimNames[dimNonK ]);
573+ LinearLayout::identity1D (warpsPerCTA[KDim ], kWarp , outDimNames[KDim ]);
574+ tileLayout *= LinearLayout::identity1D (warpsPerCTA[nonKDim ], kWarp ,
575+ outDimNames[nonKDim ]);
576576 if (rank == 3 )
577577 tileLayout *=
578578 LinearLayout::identity1D (warpsPerCTA[0 ], kWarp , outDimNames[0 ]);
@@ -584,12 +584,12 @@ LinearLayout DPAStoLinearLayout(ArrayRef<int64_t> shape, Attribute layout,
584584 SmallVector<int64_t > numReps = dpas.getDPASRepetitions (shape, opIdx);
585585
586586 // numReps is always 3D, we should add 1 to dim id when rank is 2
587- int repDimK = rank == 2 ? dimK + 1 : dimK ;
588- int repDimNonK = rank == 2 ? dimNonK + 1 : dimNonK ;
587+ int repDimK = rank == 2 ? KDim + 1 : KDim ;
588+ int repDimNonK = rank == 2 ? nonKDim + 1 : nonKDim ;
589589 tileLayout *=
590- LinearLayout::identity1D (numReps[repDimK], kRegister , outDimNames[dimK ]);
590+ LinearLayout::identity1D (numReps[repDimK], kRegister , outDimNames[KDim ]);
591591 tileLayout *= LinearLayout::identity1D (numReps[repDimNonK], kRegister ,
592- outDimNames[dimNonK ]);
592+ outDimNames[nonKDim ]);
593593 if (rank == 3 )
594594 tileLayout *=
595595 LinearLayout::identity1D (numReps[0 ], kRegister , outDimNames[0 ]);
0 commit comments