@@ -1515,32 +1515,133 @@ chooseMfmaLikeStoreLayout(RankedTensorType valType) {
    return {};
  auto mfmaLayout = cast<AMDMfmaEncodingAttr>(valType.getEncoding());
- // We currently only support transposed [B]F16 MFMA32x32 on CDNA4.
+ // We currently only support transposed [B]F16 MFMA32x32 and MFMA16x16 on
+ // CDNA4.
  bool isMfma32 = mfmaLayout.getMDim() == 32 && mfmaLayout.getNDim() == 32;
+ bool isMfma16 = mfmaLayout.getMDim() == 16 && mfmaLayout.getNDim() == 16;
+
+ auto valShape = valType.getShape();
+ // For mfma16x16, to use the in-wavefront swap, we need to make sure that
+ // when there are multiple tiles, they all live in one wavefront, which means
+ // warpsPerCTA = [numWarps, 1] and at least two tiles along the N dim. For
+ // now, this is only possible for FA-like kernels, since during mfma
+ // generation the warpsPerCTA of the head dot in the chain is reshaped to
+ // [numWarps, 1].
+ // TODO: This transformation cannot be applied to GEMM-like kernels yet;
+ // support will be added later.
+ bool validForMfma16 = isMfma16 && valShape.back() >= 16 * 2 &&
+                       mfmaLayout.getWarpsPerCTA().back() == 1;
+
  Type elemType = valType.getElementType();
  if (!(valType.getRank() == 2 && (elemType.isF16() || elemType.isBF16()) &&
        mfmaLayout.getVersionMajor() == 4 && mfmaLayout.getIsTransposed() &&
-       isMfma32))
+       (isMfma32 || validForMfma16)))
    return {};
- auto valShape = valType.getShape();
  LinearLayout mfmaLL = mfmaLayout.toLinearLayout(valShape);
  auto mfmaOutDims = llvm::to_vector(mfmaLL.getOutDimNames());
  StringAttr dimM = mfmaOutDims[0];
  StringAttr dimN = mfmaOutDims[1];
-
  auto swapLL = LinearLayout::empty();
  // The rows are kept as-is with an identity linear layout.
  swapLL *= LinearLayout::identity1D(valShape[0], dimM, dimM);
- // In transposed mfma32 layout, each thread holds 4 consecutive values along N
- // dim. We want to exchange column 4-7 (owned by thread 32-63) and column 8-11
- // (owned by thread 0-31) every 16 columns to make each thread holds 8
- // elements. This would mean exchange the 2nd and 3rd basis vector from an
- // identity linear layout.
+ /*
+ clang-format off
+ In the transposed mfma32 layout, each thread holds 4 consecutive values along
+ the N dim. We want to exchange columns 4-7 (owned by threads 32-63, BLK0) and
+ columns 8-11 (owned by threads 0-31, BLK1) every 16 columns so that each
+ thread holds 8 elements. This means exchanging the 2nd and 3rd basis vectors
+ of an identity linear layout on tensor elements.
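+
+ For example, with 16 columns the identity N bases are {1}, {2}, {4}, {8};
+ after swapping the 2nd and 3rd they become {1}, {2}, {8}, {4}, so bits 2 and
+ 3 of the column index are exchanged: columns 4-7 map to 8-11 and vice versa,
+ while columns 0-3 and 12-15 stay in place.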
+
+ Correspondingly, in the transposed mfma16 layout, the output of mfma16x16
+ looks like:
+
+                                        N/register
+ M/Lane        v0       v1       v2       v3       v4       v5       v6       v7
+             -------------------------------------------------------------------------
+ row0:  0-15 | tile-0 | tile-0 | tile-0 | tile-0 | tile-1 | tile-1 | tile-1 | tile-1 |
+             -------------------------------------------------------------------------
+ row1: 16-31 | tile-0 | tile-0 | tile-0 | tile-0 | tile-1 | tile-1 | tile-1 | tile-1 |
+             -------------------------------------------------------------------------
+ row2: 32-47 | tile-0 | tile-0 | tile-0 | tile-0 | tile-1 | tile-1 | tile-1 | tile-1 |
+             -------------------------------------------------------------------------
+ row3: 48-63 | tile-0 | tile-0 | tile-0 | tile-0 | tile-1 | tile-1 | tile-1 | tile-1 |
+             -------------------------------------------------------------------------
+ which means columns v0 to v3 come from one output of mfma16x16 (tile-0) and
+ columns v4 to v7 come from another output of mfma16x16 (tile-1).
+
+ The following graph is the same as the one above, except the tile numbers are
+ replaced with coordinates in the tensor:
+                            N/register
+         -----------------------------------------------
+ M/lane  |(0, 0)   ... (0, 3)   | (0, 16)  ... (0, 19)  |
+         |....                  | sub-tensor-0          |
+         |(15, 0)  ... (15, 3)  | (15, 16) ... (15, 19) |
+         -----------------------------------------------
+         |(0, 4)   ... (0, 7)   | (0, 20)  ... (0, 23)  |
+         |sub-tensor-1          | ....                  |
+         |(15, 4)  ... (15, 7)  | (15, 20) ... (15, 23) |
+         -----------------------------------------------
+         |(0, 8)   ... (0, 11)  | (0, 24)  ... (0, 27)  |
+         |....                  | sub-tensor-2          |
+         |(15, 8)  ... (15, 11) | (15, 24) ... (15, 27) |
+         -----------------------------------------------
+         |(0, 12)  ... (0, 15)  | (0, 28)  ... (0, 31)  |
+         |sub-tensor-3          | ....                  |
+         |(15, 12) ... (15, 15) | (15, 28) ... (15, 31) |
+         -----------------------------------------------
+ The basis vectors for lane and register are:
+ Register = {{0, 1}, {0, 2}}
+ Lane     = {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {0, 4}, {0, 8}}
+ With this layout, only 4xfp16 values can be packed into the final global
+ store.
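+ (The register bases {0, 1} and {0, 2} give only 4 consecutive columns per
+ thread, i.e. 4 x 16 bits = 64 bits per store.)
+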
+ To use 128-bit global stores, we need to pack 8 elements, which means the
+ layout should look like:
+                                        N/register
+ M/Lane        v0       v1       v2       v3       v4       v5       v6       v7
+             -------------------------------------------------------------------------
+ row0:  0-15 | tile-0 | tile-0 | tile-0 | tile-0 | tile-0 | tile-0 | tile-0 | tile-0 |
+             -------------------------------------------------------------------------
+ row1: 16-31 | tile-1 | tile-1 | tile-1 | tile-1 | tile-1 | tile-1 | tile-1 | tile-1 |
+             -------------------------------------------------------------------------
+ row2: 32-47 | tile-0 | tile-0 | tile-0 | tile-0 | tile-0 | tile-0 | tile-0 | tile-0 |
+             -------------------------------------------------------------------------
+ row3: 48-63 | tile-1 | tile-1 | tile-1 | tile-1 | tile-1 | tile-1 | tile-1 | tile-1 |
+             -------------------------------------------------------------------------
+
+ The following graph is the same as the one above, except the tile numbers are
+ replaced with coordinates in the tensor:
+                            N/register
+         -----------------------------------------------
+         |(0, 0)   ... (0, 3)   | (0, 4)   ... (0, 7)   |
+         |....                  | sub-tensor-1          |
+         |(15, 0)  ... (15, 3)  | (15, 4)  ... (15, 7)  |
+         -----------------------------------------------
+         |(0, 16)  ... (0, 19)  | (0, 20)  ... (0, 23)  |
+         |sub-tensor-0          | ....                  |
+         |(15, 16) ... (15, 19) | (15, 20) ... (15, 23) |
+         -----------------------------------------------
+         |(0, 8)   ... (0, 11)  | (0, 12)  ... (0, 15)  |
+         |....                  | sub-tensor-3          |
+         |(15, 8)  ... (15, 11) | (15, 12) ... (15, 15) |
+         -----------------------------------------------
+         |(0, 24)  ... (0, 27)  | (0, 28)  ... (0, 31)  |
+         |sub-tensor-2          | ....                  |
+         |(15, 24) ... (15, 27) | (15, 28) ... (15, 31) |
+         -----------------------------------------------
+ which means we need to exchange sub-tensor-0 with sub-tensor-1, and
+ sub-tensor-2 with sub-tensor-3. The basis vectors for lane and register then
+ become:
+ Register = {{0, 1}, {0, 2}, {0, 4}}
+ Lane     = {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {0, 16}, {0, 8}}
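+ (Now the register bases {0, 1}, {0, 2} and {0, 4} give 8 consecutive columns
+ per thread, i.e. 8 x 16 bits = 128 bits per store.)
+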
+ The steps to get this layout are: first, check that the last dim of
+ warpsPerCTA is 1, so that v_permlane16 can be used. Then exchange the 2nd and
+ 4th basis vectors of an identity linear layout and compose it with the
+ original mfma16 LL.
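+ For example, with 32 columns the identity N bases are {1}, {2}, {4}, {8},
+ {16}; after swapping the 2nd and 4th they become {1}, {2}, {16}, {8}, {4},
+ so bits 2 and 4 of the column index are exchanged: columns 4-7 swap with
+ columns 16-19 and columns 12-15 swap with columns 24-27, while columns 0-3,
+ 8-11, 20-23 and 28-31 stay in place.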
+ clang-format on
+ */
+ auto destIdxInBases = isMfma32 ? 3 : 4;
  std::vector<std::vector<int32_t>> dimNBases(mfmaLL.getOutDimSizeLog2(dimN));
  std::generate(dimNBases.begin(), dimNBases.end(),
                [i = 0]() mutable { return std::vector<int32_t>{1 << i++}; });
- std::swap(dimNBases[2], dimNBases[3]);
+ std::swap(dimNBases[2], dimNBases[destIdxInBases]);
  swapLL *= LinearLayout({{dimN, dimNBases}}, {dimN});

  return mfmaLL.compose(swapLL);
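
To sanity-check what the swapped bases do to column indices, here is a minimal
standalone sketch (plain C++, not Triton's LinearLayout implementation; the
XOR-of-selected-bases semantics below mirrors how a linear layout maps an input
index to an output index). It prints the column permutation produced by the
mfma32 swap (2nd and 3rd bases) and the mfma16 swap (2nd and 4th bases):

    #include <cstdint>
    #include <cstdio>
    #include <utility>
    #include <vector>

    // XOR together the bases selected by the set bits of `idx`. For an
    // identity layout (bases are distinct powers of two) this is simply a
    // permutation of the index bits.
    static int32_t applyBases(const std::vector<int32_t> &bases, int32_t idx) {
      int32_t out = 0;
      for (size_t b = 0; b < bases.size(); ++b)
        if (idx & (1 << b))
          out ^= bases[b];
      return out;
    }

    int main() {
      // Identity bases for a 32-column N dim.
      std::vector<int32_t> mfma32Bases = {1, 2, 4, 8, 16};
      std::vector<int32_t> mfma16Bases = {1, 2, 4, 8, 16};
      std::swap(mfma32Bases[2], mfma32Bases[3]); // mfma32: 2nd <-> 3rd basis.
      std::swap(mfma16Bases[2], mfma16Bases[4]); // mfma16: 2nd <-> 4th basis.

      // mfma32: columns 4-7 <-> 8-11 within every 16-column block.
      // mfma16: columns 4-7 <-> 16-19 and 12-15 <-> 24-27.
      for (int32_t col = 0; col < 32; ++col)
        std::printf("col %2d -> mfma32 swap %2d, mfma16 swap %2d\n", col,
                    applyBases(mfma32Bases, col), applyBases(mfma16Bases, col));
      return 0;
    }

Running this reproduces the exchanges described in the comment above: for
mfma32, columns 4-7 and 8-11 trade places in each 16-column block; for mfma16,
columns 4-7 trade with 16-19 and 12-15 with 24-27, matching the sub-tensor
diagrams.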