@@ -1434,64 +1434,54 @@ chooseDsReadTrLayout(Attribute enc, ArrayRef<int64_t> shape,
14341434 }
14351435}
14361436
// Review notes for this hunk of chooseScaledWmmaScaleLayout. This is a
// rendered diff: lines whose number prefix ends in '-' are removed, lines
// ending in '+' are added, and lines carrying two fused numbers are
// unchanged context.
//
// Signature: the `dotOperandWarpBasis` parameter (an explicit list of warp
// bases) is replaced by `warpsPerCTA`; `ctx`, `dotOperandIdx` and
// `dotOperandShape` are kept.
//
// Removed implementation: sized the register layout from dotOperandShape[1],
// used hard-coded lane bases, reversed every warp basis when
// dotOperandIdx == 0, multiplied in identity1D register factors for dims
// {1, 0} to fill the shape, and finished with transposeOuts.
//
// Added implementation: builds `tileLayout` as a product of 1D layouts with
// scaleKWidth fixed at 4, builds the warp layout from `warpsPerCTA` through
// identityStandardND, and returns the result of combineCtaCgaWithShape.
//
// NOTE(review): no added or context line references the `basisT` alias; its
// only visible use was in the removed code, so it looks unused now -- confirm
// and drop it.
// NOTE(review): `rank` comes from dotOperandShape.size(), but the added code
// indexes warpsPerCTA[0] and warpsPerCTA[1] and passes a literal rank of 2 to
// CTALayoutAttr::getDefault -- presumably callers always pass rank-2 shapes;
// verify.
1437- LinearLayout chooseScaledWmmaScaleLayout (
1438- MLIRContext *ctx, int dotOperandIdx,
1439- const std::vector<std::vector<int32_t >> &dotOperandWarpBasis,
1440- ArrayRef<int64_t > dotOperandShape) {
1437+ LinearLayout chooseScaledWmmaScaleLayout (MLIRContext *ctx, int dotOperandIdx,
1438+ ArrayRef<unsigned > warpsPerCTA,
1439+ ArrayRef<int64_t > dotOperandShape) {
14411440 using basisT = std::vector<std::vector<int32_t >>;
14421441 unsigned rank = dotOperandShape.size ();
14431442 auto order = mlir::triton::gpu::getMatrixOrder (rank, /* rowMajor=*/ true );
1444- auto standardOutDims = standardOutDimNames (ctx, rank);
1443+ auto outDimNames = standardOutDimNames (ctx, rank);
1444+
14451445 StringAttr kRegister = StringAttr::get (ctx, " register" );
14461446 StringAttr kLane = StringAttr::get (ctx, " lane" );
14471447 StringAttr kWarp = StringAttr::get (ctx, " warp" );
14481448 StringAttr kBlock = StringAttr::get (ctx, " block" );
1449- unsigned int scaleKWidth = dotOperandShape[1 ];
1450- // Init register layout. Will be adjusted later
1451- auto regs =
1452- mlir::triton::identityStandardND (kRegister , {1 , scaleKWidth}, order);
1453- LinearLayout lanes = LinearLayout::empty ();
1449+
14541450 // In scaled dot, the shapes of operands(without batch dimension) are,
14551451 // respectively:
14561452 // - A: [M, K]
14571453 // - B: [K, N]
14581454 // - aScale: [M, K / 32 or 16]
14591455 // - bScale: [N, K / 32 or 16]
1460- //
1461- // To correctly feed A/B and its scale into instruction, we need to
1462- // distribute aScale/bScale among warps in the same way as A/B. But bScale
1463- // is not transposed like B. So we need to transpose the warp layout of
1464- // bScale.
1465- //
1466- // The tricky part is, our desired outputs are [dim0, dim1], but
1467- // at this position, the layouts are transposed to [dim1, dim0]. So
1468- // instead of reverse bScale's layout, we need to reverse aScale's. There
1469- // will be a transpose in the end to correct everything.
1470- basisT warps = dotOperandWarpBasis;
1471- if (dotOperandIdx == 0 ) {
1472- for (auto &basis : warps) {
1473- std::reverse (basis.begin (), basis.end ());
1474- }
1475- }
1456+ auto dimK = outDimNames[order[0 ]];
1457+ auto dimNonK = outDimNames[order[1 ]];
14761458
1477- lanes = LinearLayout ({{kLane , {{0 , 1 }, {0 , 2 }, {0 , 4 }, {0 , 8 }, {0 , 0 }}},
1478- {kWarp , warps},
1479- {kBlock , {}}},
1480- {standardOutDims[order[0 ]], standardOutDims[order[1 ]]});
1481- LinearLayout newLL = regs * lanes;
1459+ // Each lane holds kWidth=4 consecutive values along the k dim.
1460+ // The first 16 lanes are distributed along the non-k dim. We are not using
1461+ // the remaining 16 lanes, so just let them duplicate values of the first 16
1462+ // lanes. If the shape along the k dim is larger than kWidth, repeat this
1463+ // pattern to fill the k dim.
// NOTE(review): `kSize / scaleKWidth` below is integer division on the
// second shape entry, so it evaluates to 0 when that extent is smaller than
// 4 -- confirm that identity1D tolerates a size of 0 or that callers never
// pass such a shape.
1464+ unsigned scaleKWidth = 4 ;
1465+ auto kSize = dotOperandShape[1 ];
1466+ LinearLayout tileLayout =
1467+ LinearLayout::identity1D (scaleKWidth, kRegister , dimK) *
1468+ LinearLayout::identity1D (16 , kLane , dimNonK) *
1469+ LinearLayout::zeros1D (2 , kLane , dimK) *
1470+ LinearLayout::identity1D (kSize / scaleKWidth, kRegister , dimK);
14821471
1483- // Adjust register-level layout to fill the shape, at this level, both
1484- // aScale and bScale should align with A operand.
1485- SmallVector< int , 2 > repOrder = { 1 , 0 };
1486- for ( auto d : repOrder) {
1487- auto outDim = standardOutDims[d];
1488- auto dimSize = newLL. getOutDimSize (outDim) ;
1489- newLL *= LinearLayout::identity1D (dotOperandShape[d] / dimSize, kRegister ,
1490- outDim );
1491- }
1492- newLL = newLL .transposeOuts (standardOutDims );
// For dotOperandIdx == 1 the two warpsPerCTA entries are swapped and the warp
// order is {0, 1}; for any other index they are used as given with order
// {1, 0}. This presumably plays the role of the per-basis reversal in the
// removed code, which applied to dotOperandIdx == 0 instead -- verify the
// two are equivalent.
1472+ auto warpsPerCTANew = (dotOperandIdx == 1 )
1473+ ? SmallVector{warpsPerCTA[ 1 ], warpsPerCTA[ 0 ]}
1474+ : SmallVector{warpsPerCTA[ 0 ], warpsPerCTA[ 1 ] };
1475+
1476+ auto warpOrder = (dotOperandIdx == 1 ) ? SmallVector< unsigned >{ 0 , 1 }
1477+ : SmallVector< unsigned >{ 1 , 0 } ;
1478+ LinearLayout warpLayout =
1479+ identityStandardND ( kWarp , warpsPerCTANew, warpOrder );
1480+ LinearLayout ctaLayout = tileLayout. transposeOuts (outDimNames) *
1481+ warpLayout .transposeOuts (outDimNames );
14931482
1494- return newLL;
1483+ return combineCtaCgaWithShape (
1484+ ctaLayout, CTALayoutAttr::getDefault (ctx, /* rank=*/ 2 ), dotOperandShape);
14951485}
14961486
14971487// PTX ISA - Warp-level MMA Block Scaling
0 commit comments