@@ -1439,6 +1439,14 @@ struct LoadOpConversion
 
     Type eltTy = tensorType.getElementType();
     unsigned elemSizeInBits = eltTy.getIntOrFloatBitWidth();
+
+    auto tileParams = Subgroup2DBlockEncodingAttr::getInstrShapeForLayout(
+        cast<DistributedEncodingTrait>(encoding), tensorType.getShape(),
+        memoryRowMajor, elemSizeInBits / 8, rewriter.getContext());
+    unsigned tileHeight = tileParams[0];
+    const unsigned tileWidth = tileParams[1];
+    const unsigned vBlocks = tileParams[2];
+
     DpasEncodingAttr dpasLayout = getDpasLayout(tensorType);
     const ArrayRef<int64_t> tensorShape = tensorType.getShape();
     unsigned numElems = getTotalElemsPerThread(resultType);
@@ -1476,8 +1484,7 @@ struct LoadOpConversion
 
     Value elemSizeInBytes = b.i32_val(elemSizeInBits / 8);
 
-    SmallVector<unsigned> elemsPerInstr = dpasLayout.getDPASInstShapeC();
-    int64_t elemsPerLane = product<unsigned>(elemsPerInstr) / threadsPerWarp;
+    const unsigned elemsPerLane = tileWidth * tileHeight / threadsPerWarp;
     Type load2DGenXType =
         LLVM::getVectorType(IntegerType::get(ctx, elemSizeInBits),
                             elemsPerLane); // make it opaque type.
@@ -1527,12 +1534,12 @@ struct LoadOpConversion
       for (int repM = 0; repM < repCluster[0]; ++repM) {
 
         Value offsetY =
-            b.add(warpId0Offset, b.i32_val(m * replicaStride[0] +
-                                           repM * elemsPerInstr[0]));
+            b.add(warpId0Offset,
+                  b.i32_val(m * replicaStride[0] + repM * tileHeight));
         for (int repN = 0; repN < repCluster[1]; ++repN) {
           Value offsetX =
-              b.add(warpId1Offset, b.i32_val(n * replicaStride[1] +
-                                             repN * elemsPerInstr[1]));
+              b.add(warpId1Offset,
+                    b.i32_val(n * replicaStride[1] + repN * tileWidth));
 
           auto load2dOp = rewriter.create<TritonGEN::Matrix2DBlockLoadOp>(
               loc, load2DGenXType,
@@ -1543,9 +1550,9 @@ struct LoadOpConversion
               /*x*/ b.trunc(i32_ty, offsetX),
               /*y*/ b.trunc(i32_ty, offsetY),
               /*elem_size_in_bits*/ elemSizeInBits,
-              /*tile_width*/ elemsPerInstr[1],
-              /*tile_height*/ elemsPerInstr[0],
-              /*v_blocks*/ 1,
+              /*tile_width*/ tileWidth,
+              /*tile_height*/ tileHeight,
+              /*v_blocks*/ vBlocks,
               /*transpose*/ false,
               /*vnni_transform*/ false);
           if (failed(load2dOp.verify())) {
@@ -1659,9 +1666,6 @@ struct LoadOpConversion
           offsetBaseY] =
         getValuesFromBlockPointerStruct(adaptor.getPtr(), rewriter);
 
-    unsigned tileWidth = elemsPerDPASInst[threadOrder[rank - 2]];
-    unsigned tileHeight = elemsPerDPASInst[threadOrder[rank - 1]];
-
     MLIRContext *ctx = rewriter.getContext();
     const StringAttr dimOuterStr = S("dim" + std::to_string(dimOuter));
     const StringAttr dimInnerStr = S("dim" + std::to_string(dimInner));
@@ -1739,7 +1743,6 @@ struct LoadOpConversion
       llvm::dbgs() << "tile layout done\n";
     });
 
-    unsigned vBlocks = 1;
     unsigned numOperandsOuterDimPerLoad = 1;
     unsigned numOperandsInnerDimPerLoad = 1;
 
@@ -1756,25 +1759,24 @@ struct LoadOpConversion
       if (!usePackedType)
         return failure();
 
-      std::swap(tileHeight, tileWidth);
-
       if (oneMatrixPerLoadForBT) {
         // Only load 1 operand per inst on row.
         numOperandsPer2DLoadM = 1;
+        tileHeight = elemsPerDPASInst[threadOrder[rank - 2]];
       } else {
         // We can decompose the matrix returned by transposed large 2d load
         // when threads per warp < column size. Otherwise we have to load one
         // operand per inst.
         // Note: the tileHeight and numOperandsPer2DLoadM are the column size
         // now.
-        numOperandsPer2DLoadM =
-            (threadsPerWarp <= tileHeight) ? repCluster[rank - 1] : 1;
+        numOperandsPer2DLoadM = (threadsPerWarp <= tileHeight) ? repCluster[rank - 1] : 1;
       }
       // The transpose 2d load only support 1 operand per inst on column.
       // (vBlocks = 1)
       numOperandsPer2DloadN = 1;
     }
 
+    // TODO: move this logic to the instr shape computation
     // PVC 2D load supports 32 rows at most. Load multiple dot operands in by
     // enlarging the tileHeight.
     numOperandsPer2DLoadM = std::min(numOperandsPer2DLoadM, 32 / tileHeight);
@@ -1785,7 +1787,6 @@ struct LoadOpConversion
     unsigned totalBytesPerRowPerDPASOp = tileWidth * elemSizeInBits / 8;
     numOperandsPer2DloadN =
         std::min(numOperandsPer2DloadN, 64 / totalBytesPerRowPerDPASOp);
-    vBlocks = numOperandsPer2DloadN;
 
     numOperandsOuterDimPerLoad =
         isOperandA ? numOperandsPer2DLoadM : numOperandsPer2DloadN;
@@ -1960,7 +1961,6 @@ struct LoadOpConversion
     if (isTransposeRequired) {
       // adjust the block io parameter to align HW's limitations on
       // transposing load.
-      tileWidth = tileWidth / (32 / originalElemBits);
       elemSizeInBits = 32;
     }
     Value elemSizeInBytes = b.i32_val(originalElemBits / 8);
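
For reference, a minimal standalone sketch of the per-lane sizing that the new tile parameters drive. It assumes, as the added code suggests, that getInstrShapeForLayout yields {tileHeight, tileWidth, vBlocks} in elements; the concrete values and the sub-group size below are made up for illustration, and no MLIR/Triton API is used.

// Illustrative sketch only; mirrors the elemsPerLane arithmetic in the diff.
#include <array>
#include <cassert>
#include <cstdio>

int main() {
  // Hypothetical tile parameters: {tileHeight, tileWidth, vBlocks}.
  const std::array<unsigned, 3> tileParams = {8, 16, 2};
  const unsigned tileHeight = tileParams[0];
  const unsigned tileWidth = tileParams[1];
  const unsigned vBlocks = tileParams[2];
  const unsigned threadsPerWarp = 16; // assumed sub-group size

  // One 2D block load returns tileWidth x tileHeight elements per block,
  // split evenly across the sub-group lanes.
  assert((tileWidth * tileHeight) % threadsPerWarp == 0);
  const unsigned elemsPerLane = tileWidth * tileHeight / threadsPerWarp;
  std::printf("tile %ux%u, vBlocks=%u -> %u elements per lane per block\n",
              tileHeight, tileWidth, vBlocks, elemsPerLane);
  return 0;
}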