@@ -1533,37 +1533,39 @@ LinearLayout chooseScaledMfmaScaleLayout(
15331533 return newLL;
15341534}
15351535
1536- LinearLayout chooseMfmaLikeStoreLayout (AMDMfmaEncodingAttr mfmaLayout,
1537- ArrayRef<int64_t > shape) {
1538- assert (shape.size () == 2 && mfmaLayout.getMDim () == 32 &&
1539- mfmaLayout.getNDim () == 32 && mfmaLayout.getIsTransposed ());
1540-
1541- MLIRContext *ctx = mfmaLayout.getContext ();
1542- StringAttr kRegister = S (" register" );
1543- StringAttr kLane = S (" lane" );
1544- StringAttr kWarp = S (" warp" );
1545- StringAttr kBlock = S (" block" );
1546-
1547- SmallVector<unsigned > order = getDefaultMmaOrder (mfmaLayout);
1548- auto standardOutDims = standardOutDimNames (ctx, 2 );
1549- // We make each thread handle 8 consecutive elements to enable 128-bit
1550- // global stores for [b]f16 types and keep the thread pattern in each lane
1551- // similar to the canonical mfmaLayout.
1552- LinearLayout mfma8Layout = LinearLayout::empty ();
1553- mfma8Layout =
1554- LinearLayout ({{kRegister , {{1 , 0 }, {2 , 0 }, {4 , 0 }}},
1555- {kLane , {{0 , 1 }, {0 , 2 }, {0 , 4 }, {0 , 8 }, {0 , 16 }, {8 , 0 }}},
1556- {kWarp , {}},
1557- {kBlock , {}}},
1558- {standardOutDims[order[0 ]], standardOutDims[order[1 ]]});
1559-
1560- LinearLayout warpLayout =
1561- identityStandardND (kWarp , mfmaLayout.getWarpsPerCTA (), order);
1562- LinearLayout ctaLayout = mfma8Layout.transposeOuts (standardOutDims) *
1563- warpLayout.transposeOuts (standardOutDims);
1564- mfma8Layout =
1565- combineCtaCgaWithShape (ctaLayout, mfmaLayout.getCTALayout (), shape);
1566- return mfma8Layout;
1536+ std::optional<LinearLayout>
1537+ chooseMfmaLikeStoreLayout (RankedTensorType valType) {
1538+ auto mfmaLayout = cast<AMDMfmaEncodingAttr>(valType.getEncoding ());
1539+
1540+ // We currently only support transposed [B]F16 MFMA32x32 on CDNA4.
1541+ bool isMfma32 = mfmaLayout.getMDim () == 32 && mfmaLayout.getNDim () == 32 ;
1542+ Type elemType = valType.getElementType ();
1543+ if (!(valType.getRank () == 2 && (elemType.isF16 () || elemType.isBF16 ()) &&
1544+ mfmaLayout.getVersionMajor () == 4 && mfmaLayout.getIsTransposed () &&
1545+ isMfma32))
1546+ return {};
1547+
1548+ auto valShape = valType.getShape ();
1549+ LinearLayout mfmaLL = mfmaLayout.toLinearLayout (valShape);
1550+ auto mfmaOutDims = llvm::to_vector (mfmaLL.getOutDimNames ());
1551+ StringAttr dimM = mfmaOutDims[0 ];
1552+ StringAttr dimN = mfmaOutDims[1 ];
1553+
1554+ auto swapLL = LinearLayout::empty ();
1555+ // The rows are kept as is with an identity linear layout.
1556+ swapLL *= LinearLayout::identity1D (valShape[0 ], dimM, dimM);
1557+ // In transposed mfma32 layout, each thread holds 4 consecutive values along N
1558+ // dim. We want to exchange column 4-7 (owned by thread 32-63) and column 8-11
1559+ // (owned by thread 0-31) every 16 columns to make each thread holds 8
1560+ // elements. This would mean exchange the 2nd and 3rd basis vector from an
1561+ // identity linear layout.
1562+ std::vector<std::vector<int32_t >> dimNBases (mfmaLL.getOutDimSizeLog2 (dimN));
1563+ std::generate (dimNBases.begin (), dimNBases.end (),
1564+ [i = 0 ]() mutable { return std::vector<int32_t >{1 << i++}; });
1565+ std::swap (dimNBases[2 ], dimNBases[3 ]);
1566+ swapLL *= LinearLayout ({{dimN, dimNBases}}, {dimN});
1567+
1568+ return mfmaLL.compose (swapLL);
15671569}
15681570
15691571LinearLayout getScaleTMEMStoreLinearLayout (RankedTensorType scaleType,
0 commit comments