@@ -598,122 +598,6 @@ emitOffsetForBlockedLayout(const BlockedEncodingAttr &blockedLayout,
598598// Mma layout indices
599599// -----------------------------------------------------------------------
600600
601- inline SmallVector<Value>
602- emitBaseIndexWithinCTAForMmaLayoutV1 (Location loc, RewriterBase &rewriter,
603- const NvidiaMmaEncodingAttr &mmaLayout,
604- RankedTensorType type) {
605- auto shape = type.getShape ();
606- auto wpt = mmaLayout.getWarpsPerCTA ();
607- static constexpr std::array<int , 3 > fpw{{2 , 2 , 1 }};
608- auto [isARow, isBRow, isAVec4, isBVec4, _] =
609- mmaLayout.decodeVoltaLayoutStates ();
610-
611- Value thread = getThreadId (rewriter, loc);
612- auto *ctx = thread.getContext ();
613- Value _1 = i32_val (1 );
614- Value _2 = i32_val (2 );
615- Value _4 = i32_val (4 );
616- Value _16 = i32_val (16 );
617- Value _32 = i32_val (32 );
618- Value _fpw0 = i32_val (fpw[0 ]);
619- Value _fpw1 = i32_val (fpw[1 ]);
620-
621- // A info
622- auto aRep = mmaLayout.getMMAv1Rep (0 );
623- auto aSpw = mmaLayout.getMMAv1ShapePerWarp (0 );
624- // B info
625- auto bSpw = mmaLayout.getMMAv1ShapePerWarp (1 );
626- auto bRep = mmaLayout.getMMAv1Rep (1 );
627-
628- SmallVector<int , 2 > rep ({aRep[0 ], bRep[1 ]});
629- SmallVector<int , 2 > spw ({aSpw[0 ], bSpw[1 ]});
630- SmallVector<unsigned , 2 > shapePerCTA ({spw[0 ] * wpt[0 ], spw[1 ] * wpt[1 ]});
631-
632- Value lane = urem (thread, _32);
633- Value warp = udiv (thread, _32);
634-
635- Value warp0 = urem (warp, i32_val (wpt[0 ]));
636- Value warp12 = udiv (warp, i32_val (wpt[0 ]));
637- Value warp1 = urem (warp12, i32_val (wpt[1 ]));
638-
639- // warp offset
640- Value offWarpM = mul (warp0, i32_val (spw[0 ]));
641- Value offWarpN = mul (warp1, i32_val (spw[1 ]));
642- // quad offset
643- Value offQuadM = mul (udiv (and_ (lane, _16), _4), _fpw0);
644- Value offQuadN = mul (udiv (and_ (lane, _16), _4), _fpw1);
645- // pair offset
646- Value offPairM = udiv (urem (lane, _16), _4);
647- offPairM = urem (offPairM, _fpw0);
648- offPairM = mul (offPairM, _4);
649- Value offPairN = udiv (urem (lane, _16), _4);
650- offPairN = udiv (offPairN, _fpw0);
651- offPairN = urem (offPairN, _fpw1);
652- offPairN = mul (offPairN, _4);
653- offPairM = mul (offPairM, i32_val (rep[0 ] / 2 ));
654- offQuadM = mul (offQuadM, i32_val (rep[0 ] / 2 ));
655- offPairN = mul (offPairN, i32_val (rep[1 ] / 2 ));
656- offQuadN = mul (offQuadN, i32_val (rep[1 ] / 2 ));
657- // quad pair offset
658- Value offLaneM = add (offPairM, offQuadM);
659- Value offLaneN = add (offPairN, offQuadN);
660- // a, b offset
661- Value offsetAM = add (offWarpM, offLaneM);
662- Value offsetBN = add (offWarpN, offLaneN);
663- // m indices
664- Value offsetCM = add (and_ (lane, _1), offsetAM);
665- // n indices
666- Value offsetCN = add ((and_ (lane, _2)), (add (offWarpN, offPairN)));
667- return {offsetCM, offsetCN};
668- }
669-
670- inline SmallVector<SmallVector<unsigned >>
671- emitOffsetForMmaLayoutV1 (const NvidiaMmaEncodingAttr &mmaLayout,
672- RankedTensorType type) {
673- auto shape = type.getShape ();
674-
675- auto [isARow, isBRow, isAVec4, isBVec4, _] =
676- mmaLayout.decodeVoltaLayoutStates ();
677-
678- // TODO: seems like the pattern below to get `rep`/`spw` appears quite often
679- // A info
680- auto aRep = mmaLayout.getMMAv1Rep (0 );
681- auto aSpw = mmaLayout.getMMAv1ShapePerWarp (0 );
682- // B info
683- auto bSpw = mmaLayout.getMMAv1ShapePerWarp (1 );
684- auto bRep = mmaLayout.getMMAv1Rep (1 );
685-
686- auto wpt = mmaLayout.getWarpsPerCTA ();
687- static constexpr std::array<int , 3 > fpw{{2 , 2 , 1 }};
688- SmallVector<int , 2 > rep ({aRep[0 ], bRep[1 ]});
689- SmallVector<int , 2 > spw ({aSpw[0 ], bSpw[1 ]});
690- SmallVector<unsigned , 2 > shapePerCTA ({spw[0 ] * wpt[0 ], spw[1 ] * wpt[1 ]});
691-
692- SmallVector<unsigned > idxM;
693- for (unsigned m = 0 ; m < shape[0 ]; m += shapePerCTA[0 ])
694- for (unsigned mm = 0 ; mm < rep[0 ]; ++mm)
695- idxM.push_back (m + mm * 2 );
696-
697- SmallVector<unsigned > idxN;
698- for (int n = 0 ; n < shape[1 ]; n += shapePerCTA[1 ]) {
699- for (int nn = 0 ; nn < rep[1 ]; ++nn) {
700- idxN.push_back (n + nn / 2 * 4 + (nn % 2 ) * 2 * fpw[1 ] * rep[1 ]);
701- idxN.push_back (n + nn / 2 * 4 + (nn % 2 ) * 2 * fpw[1 ] * rep[1 ] + 1 );
702- }
703- }
704-
705- SmallVector<SmallVector<unsigned >> ret;
706- for (unsigned x1 : idxN) { // N
707- for (unsigned x0 : idxM) { // M
708- SmallVector<unsigned > idx (2 );
709- idx[0 ] = x0; // M
710- idx[1 ] = x1; // N
711- ret.push_back (std::move (idx));
712- }
713- }
714- return ret;
715- }
716-
717601inline SmallVector<SmallVector<unsigned >>
718602emitOffsetForMmaLayoutV2 (const NvidiaMmaEncodingAttr &mmaLayout,
719603 RankedTensorType type) {
@@ -1179,9 +1063,6 @@ emitBaseIndexForLayoutImpl(Location loc, RewriterBase &rewriter,
11791063 result = emitBaseIndexWithinCTAForBlockedLayout (loc, rewriter,
11801064 blockedLayout, type);
11811065 } else if (auto mmaLayout = mlir::dyn_cast<NvidiaMmaEncodingAttr>(layout)) {
1182- if (mmaLayout.isVolta ())
1183- result =
1184- emitBaseIndexWithinCTAForMmaLayoutV1 (loc, rewriter, mmaLayout, type);
11851066 if (mmaLayout.isAmpere () || mmaLayout.isHopper ())
11861067 result = emitBaseIndexWithinCTAForMmaLayoutV2V3 (loc, rewriter, mmaLayout,
11871068 type);
@@ -1536,18 +1417,6 @@ inline Value packLLVector(Location loc, ValueRange vals,
15361417 return vec;
15371418}
15381419
1539- inline bool isLayoutMmaV1 (Attribute layout) {
1540- bool isMmaV1 = false ;
1541- if (auto mmaLayout = dyn_cast<NvidiaMmaEncodingAttr>(layout)) {
1542- isMmaV1 = mmaLayout.isVolta ();
1543- }
1544- if (auto sliceLayout = dyn_cast<SliceEncodingAttr>(layout)) {
1545- isMmaV1 = isa<NvidiaMmaEncodingAttr>(sliceLayout.getParent ()) &&
1546- cast<NvidiaMmaEncodingAttr>(sliceLayout.getParent ()).isVolta ();
1547- }
1548- return isMmaV1;
1549- }
1550-
15511420} // namespace mlir
15521421
15531422#endif
0 commit comments