@@ -1511,32 +1511,133 @@ chooseMfmaLikeStoreLayout(RankedTensorType valType) {
    return {};
  auto mfmaLayout = cast<AMDMfmaEncodingAttr>(valType.getEncoding());

- // We currently only support transposed [B]F16 MFMA32x32 on CDNA4.
+ // We currently only support transposed [B]F16 MFMA32x32 and MFMA16x16 on
+ // CDNA4.
  bool isMfma32 = mfmaLayout.getMDim() == 32 && mfmaLayout.getNDim() == 32;
+ bool isMfma16 = mfmaLayout.getMDim() == 16 && mfmaLayout.getNDim() == 16;
+
+ auto valShape = valType.getShape();
+ // For mfma16x16, to use the in-wavefront swap, we need to make sure that the
+ // tiles used are in one wavefront if there are multiple tiles, which means
+ // warpsPerCTA = [numWarps, 1] and at least two tiles along the N dim. For
+ // now, this is only possible for FA-like kernels, since during mfma
+ // generation the warpsPerCTA of the head dot in the chain will be reshaped
+ // to [numWarps, 1].
+ // TODO: This transformation cannot be applied to gemm-like kernels yet; we
+ // will add support for them later.
+ bool validForMfma16 = isMfma16 && valShape.back() >= 16 * 2 &&
+                       mfmaLayout.getWarpsPerCTA().back() == 1;
+
  Type elemType = valType.getElementType();
  if (!(valType.getRank() == 2 && (elemType.isF16() || elemType.isBF16()) &&
        mfmaLayout.getVersionMajor() == 4 && mfmaLayout.getIsTransposed() &&
- isMfma32))
+ (isMfma32 || validForMfma16)))
    return {};

- auto valShape = valType.getShape();
  LinearLayout mfmaLL = mfmaLayout.toLinearLayout(valShape);
  auto mfmaOutDims = llvm::to_vector(mfmaLL.getOutDimNames());
  StringAttr dimM = mfmaOutDims[0];
  StringAttr dimN = mfmaOutDims[1];
-
  auto swapLL = LinearLayout::empty();
  // The rows are kept as is with an identity linear layout.
  swapLL *= LinearLayout::identity1D(valShape[0], dimM, dimM);
- // In transposed mfma32 layout, each thread holds 4 consecutive values along N
- // dim. We want to exchange column 4-7 (owned by thread 32-63) and column 8-11
- // (owned by thread 0-31) every 16 columns to make each thread holds 8
- // elements. This would mean exchange the 2nd and 3rd basis vector from an
- // identity linear layout.
+ /*
+ clang-format off
+ In the transposed mfma32 layout, each thread holds 4 consecutive values along
+ the N dim. We want to exchange columns 4-7 (owned by threads 32-63, BLK0) and
+ columns 8-11 (owned by threads 0-31, BLK1) every 16 columns so that each
+ thread holds 8 elements. This means exchanging the 2nd and 3rd basis vectors
+ of an identity linear layout on tensor elements.
+
+ Correspondingly, for the transposed mfma16 layout, the output of a transposed
+ mfma16x16 is:
+
+                                   N/register
+ M/Lane          v0       v1       v2       v3       v4       v5       v6       v7
+             -------------------------------------------------------------------------
+ row0: 0-15  | tile-0 | tile-0 | tile-0 | tile-0 | tile-1 | tile-1 | tile-1 | tile-1 |
+             -------------------------------------------------------------------------
+ row1: 16-31 | tile-0 | tile-0 | tile-0 | tile-0 | tile-1 | tile-1 | tile-1 | tile-1 |
+             -------------------------------------------------------------------------
+ row2: 32-47 | tile-0 | tile-0 | tile-0 | tile-0 | tile-1 | tile-1 | tile-1 | tile-1 |
+             -------------------------------------------------------------------------
+ row3: 48-63 | tile-0 | tile-0 | tile-0 | tile-0 | tile-1 | tile-1 | tile-1 | tile-1 |
+             -------------------------------------------------------------------------
+ which means the columns from v0 to v3 are in one output tile of mfma16x16 and
+ the columns from v4 to v7 are in another output tile of mfma16x16.
+
+ The following graph is the same as the one above, except the tile number is
+ replaced with coordinates in the tensor:
+                         N/register
+        -----------------------------------------------
+ M/lane |(0, 0)   ... (0, 3)  | (0, 16)  ... (0, 19)  |
+        |....                 | sub-tensor-0          |
+        |(15, 0)  ... (15, 3) | (15, 16) ... (15, 19) |
+        -----------------------------------------------
+        |(0, 4)   ... (0, 7)  | (0, 20)  ... (0, 23)  |
+        |sub-tensor-1         | ....                  |
+        |(15, 4)  ... (15, 7) | (15, 20) ... (15, 23) |
+        -----------------------------------------------
+        |(0, 8)   ... (0, 11) | (0, 24)  ... (0, 27)  |
+        |....                 | sub-tensor-2          |
+        |(15, 8)  ... (15, 11)| (15, 24) ... (15, 27) |
+        -----------------------------------------------
+        |(0, 12)  ... (0, 15) | (0, 28)  ... (0, 31)  |
+        |sub-tensor-3         | ....                  |
+        |(15, 12) ... (15, 15)| (15, 28) ... (15, 31) |
+        -----------------------------------------------
+ The basis vectors for register and lane are:
+ Register = {{0, 1}, {0, 2}, {0, 16}}
+ Lane     = {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {0, 4}, {0, 8}}
+ With this layout, only 4xfp16 can be packed in the final global store.
+
+ To use a 128-bit global store, we need to pack 8 elements, which means the
+ layout should look like:
+                                   N/register
+ M/Lane          v0       v1       v2       v3       v4       v5       v6       v7
+             -------------------------------------------------------------------------
+ row0: 0-15  | tile-0 | tile-0 | tile-0 | tile-0 | tile-0 | tile-0 | tile-0 | tile-0 |
+             -------------------------------------------------------------------------
+ row1: 16-31 | tile-1 | tile-1 | tile-1 | tile-1 | tile-1 | tile-1 | tile-1 | tile-1 |
+             -------------------------------------------------------------------------
+ row2: 32-47 | tile-0 | tile-0 | tile-0 | tile-0 | tile-0 | tile-0 | tile-0 | tile-0 |
+             -------------------------------------------------------------------------
+ row3: 48-63 | tile-1 | tile-1 | tile-1 | tile-1 | tile-1 | tile-1 | tile-1 | tile-1 |
+             -------------------------------------------------------------------------
+
+ The following graph is the same as the one above, except the tile number is
+ replaced with coordinates in the tensor:
+                         N/register
+        -----------------------------------------------
+ M/lane |(0, 0)   ... (0, 3)  | (0, 4)   ... (0, 7)   |
+        |....                 | sub-tensor-1          |
+        |(15, 0)  ... (15, 3) | (15, 4)  ... (15, 7)  |
+        -----------------------------------------------
+        |(0, 16)  ... (0, 19) | (0, 20)  ... (0, 23)  |
+        |sub-tensor-0         | ....                  |
+        |(15, 16) ... (15, 19)| (15, 20) ... (15, 23) |
+        -----------------------------------------------
+        |(0, 8)   ... (0, 11) | (0, 12)  ... (0, 15)  |
+        |....                 | sub-tensor-3          |
+        |(15, 8)  ... (15, 11)| (15, 12) ... (15, 15) |
+        -----------------------------------------------
+        |(0, 24)  ... (0, 27) | (0, 28)  ... (0, 31)  |
+        |sub-tensor-2         | ....                  |
+        |(15, 24) ... (15, 27)| (15, 28) ... (15, 31) |
+        -----------------------------------------------
+ which means we need to exchange sub-tensor-0 with sub-tensor-1, and
+ sub-tensor-2 with sub-tensor-3.
+ The basis vectors for register and lane are:
+ Register = {{0, 1}, {0, 2}, {0, 4}}
+ Lane     = {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {0, 16}, {0, 8}}
+
+ The steps to get this layout are: first we check that the last dim of
+ warpsPerCTA is 1, so we can use v_permlane16. Then we exchange the 2nd and
+ 4th basis vectors of an identity linear layout, which is then composed with
+ the original mfma16 LL.
+ clang-format on
+ */
+ auto destIdxInBases = isMfma32 ? 3 : 4;
  std::vector<std::vector<int32_t>> dimNBases(mfmaLL.getOutDimSizeLog2(dimN));
  std::generate(dimNBases.begin(), dimNBases.end(),
                [i = 0]() mutable { return std::vector<int32_t>{1 << i++}; });
- std::swap(dimNBases[2], dimNBases[3]);
+ std::swap(dimNBases[2], dimNBases[destIdxInBases]);
  swapLL *= LinearLayout({{dimN, dimNBases}}, {dimN});

  return mfmaLL.compose(swapLL);
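
To make the effect of the basis swap concrete, here is a small standalone C++ sketch (not part of the patch; the helper name applyBases, the driver, and the 32-column size are illustrative assumptions). It builds the identity bases over the N dim, applies the same index swap as the patch (2<->3 for mfma32, 2<->4 for mfma16), and prints the column permutation that results.

#include <cstdint>
#include <cstdio>
#include <utility>
#include <vector>

// A 1-D linear layout maps index n to the XOR of the basis vectors selected
// by the set bits of n (for an identity layout, bases[i] = 1 << i).
static int32_t applyBases(const std::vector<int32_t> &bases, int32_t n) {
  int32_t out = 0;
  for (size_t i = 0; i < bases.size(); ++i)
    if (n & (1 << i))
      out ^= bases[i];
  return out;
}

int main() {
  const int log2N = 5; // 32 columns along N, enough to show both swaps
  for (bool isMfma32 : {true, false}) {
    std::vector<int32_t> bases(log2N);
    for (int i = 0; i < log2N; ++i)
      bases[i] = 1 << i; // identity layout on the N dim
    // Same swap as the patch: dimNBases[2] <-> dimNBases[3] for mfma32,
    // dimNBases[2] <-> dimNBases[4] for mfma16.
    std::swap(bases[2], bases[isMfma32 ? 3 : 4]);
    std::printf("%s:\n", isMfma32 ? "mfma32 (swap bases 2,3)"
                                  : "mfma16 (swap bases 2,4)");
    for (int n = 0; n < (1 << log2N); ++n)
      if (int32_t m = applyBases(bases, n); m != n)
        std::printf("  column %2d -> %2d\n", n, m);
  }
  return 0;
}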
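As a sanity check against the diagrams in the comment: for mfma32 the sketch reports columns 4-7 and 8-11 exchanged within every 16-column group, and for mfma16 it reports columns 4-7 swapped with 16-19 and columns 12-15 with 24-27, i.e. the sub-tensor-0/1 and sub-tensor-2/3 exchange drawn above. Composing that permutation with the mfma layout via mfmaLL.compose(swapLL) is what leaves each lane holding 8 consecutive N elements for the 128-bit store.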