@@ -1407,83 +1407,6 @@ LinearLayout chooseDsReadB64TrLayout(Attribute enc, ArrayRef<int64_t> shape,
14071407 return chooseDotDsReadB64TrLayout (dot, shape, elemBitWidth);
14081408}
14091409
1410- // Warp-level block scaling (sm_120, m16n8k32)
1411- // Reference: NVIDIA PTX ISA "Warp-level block scaling"
1412- // https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-block-scaling
1413- //
1414- // Semantics:
1415- // D = (A * SF_A) * (B * SF_B) + C
1416- // scale_vec::1X -> SF_A shape Mx1 (per-row), SF_B shape 1xN (per-col)
1417- //
1418- // Providers (within each warp quad of 4 lanes):
1419- // - A scales are provided by a lane-pair selected by thread-id-a ∈ {0,1}
1420- // (0 => lanes {0,1}, 1 => lanes {2,3} in the quad).
1421- // - B scales are provided by a single lane selected by thread-id-b ∈
1422- // {0,1,2,3}.
1423- //
1424- // Byte selectors (which subfield of the 32-bit metadata is used):
1425- // - 1X: 1 byte => byte-id ∈ {0,1,2,3}
1426- //
1427- // Implementation notes:
1428- // - We support only scale_vec::1X for now.
1429- // - We choose a fixed provider for A (thread-id-a = 0) and B (thread-id-b =
1430- // 0)
1431- // - In this implementation, each lane in a quad has the same scale factor.
1432- LinearLayout getSM120DotScaledScaleLayout (
1433- MLIRContext *ctx, int dotOperandIdx, ArrayRef<int64_t > dotOperandShape,
1434- ArrayRef<unsigned > tilesPerWarp, ArrayRef<unsigned > warpsPerCTA,
1435- unsigned mmaInstrM, unsigned mmaInstrN, CTALayoutAttr ctaLayoutAttr) {
1436- unsigned rank = dotOperandShape.size ();
1437- auto outDims = standardOutDimNames (ctx, rank);
1438-
1439- StringAttr kRegister = StringAttr::get (ctx, " register" );
1440- StringAttr kLane = StringAttr::get (ctx, " lane" );
1441- StringAttr kWarp = StringAttr::get (ctx, " warp" );
1442-
1443- const unsigned mIndex = 0 ;
1444- const unsigned nIndex = 1 ;
1445- const int instrM = mmaInstrM;
1446- const int instrN = mmaInstrN;
1447- const int kSize = dotOperandShape[1 ];
1448- const int mWarps = warpsPerCTA[mIndex ];
1449- const int nWarps = warpsPerCTA[nIndex];
1450- const int totalWarps = mWarps * nWarps;
1451- const unsigned mRep_warp = tilesPerWarp[mIndex ];
1452- const unsigned nRep_warp = tilesPerWarp[nIndex];
1453- const unsigned kRep = std::min<unsigned >(kSize , 2 );
1454-
1455- std::vector<std::vector<int32_t >> registerBase;
1456- std::vector<std::vector<int32_t >> laneBase;
1457- std::vector<std::vector<int32_t >> warpBase;
1458- if (dotOperandIdx == 0 ) { // per-row A-scale
1459- laneBase = {{0 , 8 }, {0 , 0 }, {0 , 1 }, {0 , 2 }, {0 , 4 }};
1460- for (int offset = instrM * mWarps ; offset < instrM * mWarps * mRep_warp ;
1461- offset <<= 1 )
1462- registerBase.push_back ({0 , offset});
1463- for (int w = mWarps ; w < totalWarps; w <<= 1 )
1464- warpBase.push_back ({0 , 0 });
1465- for (int offset = instrM; offset < instrM * mWarps ; offset <<= 1 )
1466- warpBase.push_back ({0 , offset});
1467- } else { // per-col B-scale
1468- laneBase = {{0 , 0 }, {0 , 0 }, {0 , 1 }, {0 , 2 }, {0 , 4 }};
1469- if (nRep_warp > 1 )
1470- registerBase.push_back ({0 , nWarps * instrN});
1471- for (int k = 1 ; k < kRep ; k += 1 )
1472- registerBase.push_back ({1 << (k - 1 ), 0 });
1473- for (int offset = instrN; offset < instrN * nWarps; offset <<= 1 )
1474- warpBase.push_back ({0 , offset});
1475- for (int w = nWarps; w < totalWarps; w <<= 1 )
1476- warpBase.push_back ({0 , 0 });
1477- }
1478-
1479- const unsigned kIdx = (dotOperandShape[0 ] == 1 ) ? 0 : 1 ;
1480- const unsigned mnIdx = 1 - kIdx ;
1481- LinearLayout ctaLayout (
1482- {{kRegister , registerBase}, {kLane , laneBase}, {kWarp , warpBase}},
1483- {outDims[kIdx ], outDims[mnIdx]});
1484- return combineCtaCgaWithShape (ctaLayout, ctaLayoutAttr, dotOperandShape);
1485- }
1486-
14871410LinearLayout chooseScaledMfmaScaleLayout (MLIRContext *ctx, int dotOperandIdx,
14881411 ArrayRef<int64_t > dotOperandShape,
14891412 unsigned mfmaMDim,
0 commit comments