@@ -1401,6 +1401,83 @@ LinearLayout chooseDsReadB64TrLayout(Attribute enc, ArrayRef<int64_t> shape,
14011401 return chooseDotDsReadB64TrLayout (dot, shape, elemBitWidth);
14021402}
14031403
1404+ // Warp-level block scaling (sm_120, m16n8k32)
1405+ // Reference: NVIDIA PTX ISA "Warp-level block scaling"
1406+ // https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-block-scaling
1407+ //
1408+ // Semantics:
1409+ // D = (A * SF_A) * (B * SF_B) + C
1410+ // scale_vec::1X -> SF_A shape Mx1 (per-row), SF_B shape 1xN (per-col)
1411+ //
1412+ // Providers (within each warp quad of 4 lanes):
1413+ // - A scales are provided by a lane-pair selected by thread-id-a ∈ {0,1}
1414+ // (0 => lanes {0,1}, 1 => lanes {2,3} in the quad).
1415+ // - B scales are provided by a single lane selected by thread-id-b ∈
1416+ // {0,1,2,3}.
1417+ //
1418+ // Byte selectors (which subfield of the 32-bit metadata is used):
1419+ // - 1X: 1 byte => byte-id ∈ {0,1,2,3}
1420+ //
1421+ // Implementation notes:
1422+ // - We support only scale_vec::1X for now.
1423+ // - We choose a fixed provider for A (thread-id-a = 0) and B (thread-id-b =
1424+ // 0)
1425+ // - In this implementation, each lane in a quad has the same scale factor.
1426+ LinearLayout getSM120DotScaledScaleLayout (
1427+ MLIRContext *ctx, int dotOperandIdx, ArrayRef<int64_t > dotOperandShape,
1428+ ArrayRef<unsigned > tilesPerWarp, ArrayRef<unsigned > warpsPerCTA,
1429+ unsigned mmaInstrM, unsigned mmaInstrN, CTALayoutAttr ctaLayoutAttr) {
1430+ unsigned rank = dotOperandShape.size ();
1431+ auto outDims = standardOutDimNames (ctx, rank);
1432+
1433+ StringAttr kRegister = StringAttr::get (ctx, " register" );
1434+ StringAttr kLane = StringAttr::get (ctx, " lane" );
1435+ StringAttr kWarp = StringAttr::get (ctx, " warp" );
1436+
1437+ const unsigned mIndex = 0 ;
1438+ const unsigned nIndex = 1 ;
1439+ const int instrM = mmaInstrM;
1440+ const int instrN = mmaInstrN;
1441+ const int kSize = dotOperandShape[1 ];
1442+ const int mWarps = warpsPerCTA[mIndex ];
1443+ const int nWarps = warpsPerCTA[nIndex];
1444+ const int totalWarps = mWarps * nWarps;
1445+ const unsigned mRep_warp = tilesPerWarp[mIndex ];
1446+ const unsigned nRep_warp = tilesPerWarp[nIndex];
1447+ const unsigned kRep = std::min<unsigned >(kSize , 2 );
1448+
1449+ std::vector<std::vector<int32_t >> registerBase;
1450+ std::vector<std::vector<int32_t >> laneBase;
1451+ std::vector<std::vector<int32_t >> warpBase;
1452+ if (dotOperandIdx == 0 ) { // per-row A-scale
1453+ laneBase = {{0 , 8 }, {0 , 0 }, {0 , 1 }, {0 , 2 }, {0 , 4 }};
1454+ for (int offset = instrM * mWarps ; offset < instrM * mWarps * mRep_warp ;
1455+ offset <<= 1 )
1456+ registerBase.push_back ({0 , offset});
1457+ for (int w = mWarps ; w < totalWarps; w <<= 1 )
1458+ warpBase.push_back ({0 , 0 });
1459+ for (int offset = instrM; offset < instrM * mWarps ; offset <<= 1 )
1460+ warpBase.push_back ({0 , offset});
1461+ } else { // per-col B-scale
1462+ laneBase = {{0 , 0 }, {0 , 0 }, {0 , 1 }, {0 , 2 }, {0 , 4 }};
1463+ if (nRep_warp > 1 )
1464+ registerBase.push_back ({0 , nWarps * instrN});
1465+ for (int k = 1 ; k < kRep ; k += 1 )
1466+ registerBase.push_back ({1 << (k - 1 ), 0 });
1467+ for (int offset = instrN; offset < instrN * nWarps; offset <<= 1 )
1468+ warpBase.push_back ({0 , offset});
1469+ for (int w = nWarps; w < totalWarps; w <<= 1 )
1470+ warpBase.push_back ({0 , 0 });
1471+ }
1472+
1473+ const unsigned kIdx = (dotOperandShape[0 ] == 1 ) ? 0 : 1 ;
1474+ const unsigned mnIdx = 1 - kIdx ;
1475+ LinearLayout ctaLayout (
1476+ {{kRegister , registerBase}, {kLane , laneBase}, {kWarp , warpBase}},
1477+ {outDims[kIdx ], outDims[mnIdx]});
1478+ return combineCtaCgaWithShape (ctaLayout, ctaLayoutAttr, dotOperandShape);
1479+ }
1480+
14041481LinearLayout chooseScaledMfmaScaleLayout (MLIRContext *ctx, int dotOperandIdx,
14051482 ArrayRef<int64_t > dotOperandShape,
14061483 unsigned mfmaMDim,
0 commit comments