@@ -1407,6 +1407,83 @@ LinearLayout chooseDsReadB64TrLayout(Attribute enc, ArrayRef<int64_t> shape,
14071407 return chooseDotDsReadB64TrLayout (dot, shape, elemBitWidth);
14081408}
14091409
1410+ // Warp-level block scaling (sm_120, m16n8k32)
1411+ // Reference: NVIDIA PTX ISA "Warp-level block scaling"
1412+ // https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-block-scaling
1413+ //
1414+ // Semantics:
1415+ // D = (A * SF_A) * (B * SF_B) + C
1416+ // scale_vec::1X -> SF_A shape Mx1 (per-row), SF_B shape 1xN (per-col)
1417+ //
1418+ // Providers (within each warp quad of 4 lanes):
1419+ // - A scales are provided by a lane-pair selected by thread-id-a ∈ {0,1}
1420+ // (0 => lanes {0,1}, 1 => lanes {2,3} in the quad).
1421+ // - B scales are provided by a single lane selected by thread-id-b ∈
1422+ // {0,1,2,3}.
1423+ //
1424+ // Byte selectors (which subfield of the 32-bit metadata is used):
1425+ // - 1X: 1 byte => byte-id ∈ {0,1,2,3}
1426+ //
1427+ // Implementation notes:
1428+ // - We support only scale_vec::1X for now.
1429+ // - We choose a fixed provider for A (thread-id-a = 0) and B (thread-id-b =
1430+ // 0)
1431+ // - In this implementation, each lane in a quad has the same scale factor.
1432+ LinearLayout getSM120DotScaledScaleLayout (
1433+ MLIRContext *ctx, int dotOperandIdx, ArrayRef<int64_t > dotOperandShape,
1434+ ArrayRef<unsigned > tilesPerWarp, ArrayRef<unsigned > warpsPerCTA,
1435+ unsigned mmaInstrM, unsigned mmaInstrN, CTALayoutAttr ctaLayoutAttr) {
1436+ unsigned rank = dotOperandShape.size ();
1437+ auto outDims = standardOutDimNames (ctx, rank);
1438+
1439+ StringAttr kRegister = StringAttr::get (ctx, " register" );
1440+ StringAttr kLane = StringAttr::get (ctx, " lane" );
1441+ StringAttr kWarp = StringAttr::get (ctx, " warp" );
1442+
1443+ const unsigned mIndex = 0 ;
1444+ const unsigned nIndex = 1 ;
1445+ const int instrM = mmaInstrM;
1446+ const int instrN = mmaInstrN;
1447+ const int kSize = dotOperandShape[1 ];
1448+ const int mWarps = warpsPerCTA[mIndex ];
1449+ const int nWarps = warpsPerCTA[nIndex];
1450+ const int totalWarps = mWarps * nWarps;
1451+ const unsigned mRep_warp = tilesPerWarp[mIndex ];
1452+ const unsigned nRep_warp = tilesPerWarp[nIndex];
1453+ const unsigned kRep = std::min<unsigned >(kSize , 2 );
1454+
1455+ std::vector<std::vector<int32_t >> registerBase;
1456+ std::vector<std::vector<int32_t >> laneBase;
1457+ std::vector<std::vector<int32_t >> warpBase;
1458+ if (dotOperandIdx == 0 ) { // per-row A-scale
1459+ laneBase = {{0 , 8 }, {0 , 0 }, {0 , 1 }, {0 , 2 }, {0 , 4 }};
1460+ for (int offset = instrM * mWarps ; offset < instrM * mWarps * mRep_warp ;
1461+ offset <<= 1 )
1462+ registerBase.push_back ({0 , offset});
1463+ for (int w = mWarps ; w < totalWarps; w <<= 1 )
1464+ warpBase.push_back ({0 , 0 });
1465+ for (int offset = instrM; offset < instrM * mWarps ; offset <<= 1 )
1466+ warpBase.push_back ({0 , offset});
1467+ } else { // per-col B-scale
1468+ laneBase = {{0 , 0 }, {0 , 0 }, {0 , 1 }, {0 , 2 }, {0 , 4 }};
1469+ if (nRep_warp > 1 )
1470+ registerBase.push_back ({0 , nWarps * instrN});
1471+ for (int k = 1 ; k < kRep ; k += 1 )
1472+ registerBase.push_back ({1 << (k - 1 ), 0 });
1473+ for (int offset = instrN; offset < instrN * nWarps; offset <<= 1 )
1474+ warpBase.push_back ({0 , offset});
1475+ for (int w = nWarps; w < totalWarps; w <<= 1 )
1476+ warpBase.push_back ({0 , 0 });
1477+ }
1478+
1479+ const unsigned kIdx = (dotOperandShape[0 ] == 1 ) ? 0 : 1 ;
1480+ const unsigned mnIdx = 1 - kIdx ;
1481+ LinearLayout ctaLayout (
1482+ {{kRegister , registerBase}, {kLane , laneBase}, {kWarp , warpBase}},
1483+ {outDims[kIdx ], outDims[mnIdx]});
1484+ return combineCtaCgaWithShape (ctaLayout, ctaLayoutAttr, dotOperandShape);
1485+ }
1486+
14101487LinearLayout chooseScaledMfmaScaleLayout (MLIRContext *ctx, int dotOperandIdx,
14111488 ArrayRef<int64_t > dotOperandShape,
14121489 unsigned mfmaMDim,
0 commit comments