@@ -1403,83 +1403,6 @@ LinearLayout chooseDsReadB64TrLayout(Attribute enc, ArrayRef<int64_t> shape,
14031403 return chooseDotDsReadB64TrLayout (dot, shape, elemBitWidth);
14041404}
14051405
1406- // Warp-level block scaling (sm_120, m16n8k32)
1407- // Reference: NVIDIA PTX ISA "Warp-level block scaling"
1408- // https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-block-scaling
1409- //
1410- // Semantics:
1411- // D = (A * SF_A) * (B * SF_B) + C
1412- // scale_vec::1X -> SF_A shape Mx1 (per-row), SF_B shape 1xN (per-col)
1413- //
1414- // Providers (within each warp quad of 4 lanes):
1415- // - A scales are provided by a lane-pair selected by thread-id-a ∈ {0,1}
1416- // (0 => lanes {0,1}, 1 => lanes {2,3} in the quad).
1417- // - B scales are provided by a single lane selected by thread-id-b ∈
1418- // {0,1,2,3}.
1419- //
1420- // Byte selectors (which subfield of the 32-bit metadata is used):
1421- // - 1X: 1 byte => byte-id ∈ {0,1,2,3}
1422- //
1423- // Implementation notes:
1424- // - We support only scale_vec::1X for now.
1425- // - We choose a fixed provider for A (thread-id-a = 0) and B (thread-id-b =
1426- // 0)
1427- // - In this implementation, each lane in a quad has the same scale factor.
1428- LinearLayout getSM120DotScaledScaleLayout (
1429- MLIRContext *ctx, int dotOperandIdx, ArrayRef<int64_t > dotOperandShape,
1430- ArrayRef<unsigned > tilesPerWarp, ArrayRef<unsigned > warpsPerCTA,
1431- unsigned mmaInstrM, unsigned mmaInstrN, CTALayoutAttr ctaLayoutAttr) {
1432- unsigned rank = dotOperandShape.size ();
1433- auto outDims = standardOutDimNames (ctx, rank);
1434-
1435- StringAttr kRegister = StringAttr::get (ctx, " register" );
1436- StringAttr kLane = StringAttr::get (ctx, " lane" );
1437- StringAttr kWarp = StringAttr::get (ctx, " warp" );
1438-
1439- const unsigned mIndex = 0 ;
1440- const unsigned nIndex = 1 ;
1441- const int instrM = mmaInstrM;
1442- const int instrN = mmaInstrN;
1443- const int kSize = dotOperandShape[1 ];
1444- const int mWarps = warpsPerCTA[mIndex ];
1445- const int nWarps = warpsPerCTA[nIndex];
1446- const int totalWarps = mWarps * nWarps;
1447- const unsigned mRep_warp = tilesPerWarp[mIndex ];
1448- const unsigned nRep_warp = tilesPerWarp[nIndex];
1449- const unsigned kRep = std::min<unsigned >(kSize , 2 );
1450-
1451- std::vector<std::vector<int32_t >> registerBase;
1452- std::vector<std::vector<int32_t >> laneBase;
1453- std::vector<std::vector<int32_t >> warpBase;
1454- if (dotOperandIdx == 0 ) { // per-row A-scale
1455- laneBase = {{0 , 8 }, {0 , 0 }, {0 , 1 }, {0 , 2 }, {0 , 4 }};
1456- for (int offset = instrM * mWarps ; offset < instrM * mWarps * mRep_warp ;
1457- offset <<= 1 )
1458- registerBase.push_back ({0 , offset});
1459- for (int w = mWarps ; w < totalWarps; w <<= 1 )
1460- warpBase.push_back ({0 , 0 });
1461- for (int offset = instrM; offset < instrM * mWarps ; offset <<= 1 )
1462- warpBase.push_back ({0 , offset});
1463- } else { // per-col B-scale
1464- laneBase = {{0 , 0 }, {0 , 0 }, {0 , 1 }, {0 , 2 }, {0 , 4 }};
1465- if (nRep_warp > 1 )
1466- registerBase.push_back ({0 , nWarps * instrN});
1467- for (int k = 1 ; k < kRep ; k += 1 )
1468- registerBase.push_back ({1 << (k - 1 ), 0 });
1469- for (int offset = instrN; offset < instrN * nWarps; offset <<= 1 )
1470- warpBase.push_back ({0 , offset});
1471- for (int w = nWarps; w < totalWarps; w <<= 1 )
1472- warpBase.push_back ({0 , 0 });
1473- }
1474-
1475- const unsigned kIdx = (dotOperandShape[0 ] == 1 ) ? 0 : 1 ;
1476- const unsigned mnIdx = 1 - kIdx ;
1477- LinearLayout ctaLayout (
1478- {{kRegister , registerBase}, {kLane , laneBase}, {kWarp , warpBase}},
1479- {outDims[kIdx ], outDims[mnIdx]});
1480- return combineCtaCgaWithShape (ctaLayout, ctaLayoutAttr, dotOperandShape);
1481- }
1482-
14831406LinearLayout chooseScaledMfmaScaleLayout (MLIRContext *ctx, int dotOperandIdx,
14841407 ArrayRef<int64_t > dotOperandShape,
14851408 unsigned mfmaMDim,
0 commit comments