@@ -1439,6 +1439,14 @@ struct LoadOpConversion
14391439
14401440 Type eltTy = tensorType.getElementType ();
14411441 unsigned elemSizeInBits = eltTy.getIntOrFloatBitWidth ();
1442+
1443+ auto tileParams = Subgroup2DBlockEncodingAttr::getInstrShapeForLayout (
1444+ cast<DistributedEncodingTrait>(encoding), tensorType.getShape (),
1445+ memoryRowMajor, elemSizeInBits / 8 , rewriter.getContext ());
1446+ unsigned tileHeight = tileParams[0 ];
1447+ const unsigned tileWidth = tileParams[1 ];
1448+ const unsigned vBlocks = tileParams[2 ];
1449+
14421450 DpasEncodingAttr dpasLayout = getDpasLayout (tensorType);
14431451 const ArrayRef<int64_t > tensorShape = tensorType.getShape ();
14441452 unsigned numElems = getTotalElemsPerThread (resultType);
@@ -1476,8 +1484,7 @@ struct LoadOpConversion
14761484
14771485 Value elemSizeInBytes = b.i32_val (elemSizeInBits / 8 );
14781486
1479- SmallVector<unsigned > elemsPerInstr = dpasLayout.getDPASInstShapeC ();
1480- int64_t elemsPerLane = product<unsigned >(elemsPerInstr) / threadsPerWarp;
1487+ const unsigned elemsPerLane = tileWidth * tileHeight / threadsPerWarp;
14811488 Type load2DGenXType =
14821489 LLVM::getVectorType (IntegerType::get (ctx, elemSizeInBits),
14831490 elemsPerLane); // make it opaque type.
@@ -1527,12 +1534,12 @@ struct LoadOpConversion
15271534 for (int repM = 0 ; repM < repCluster[0 ]; ++repM) {
15281535
15291536 Value offsetY =
1530- b.add (warpId0Offset, b. i32_val (m * replicaStride[ 0 ] +
1531- repM * elemsPerInstr[ 0 ] ));
1537+ b.add (warpId0Offset,
1538+ b. i32_val (m * replicaStride[ 0 ] + repM * tileHeight ));
15321539 for (int repN = 0 ; repN < repCluster[1 ]; ++repN) {
15331540 Value offsetX =
1534- b.add (warpId1Offset, b. i32_val (n * replicaStride[ 1 ] +
1535- repN * elemsPerInstr[ 1 ] ));
1541+ b.add (warpId1Offset,
1542+ b. i32_val (n * replicaStride[ 1 ] + repN * tileWidth ));
15361543
15371544 auto load2dOp = rewriter.create <TritonGEN::Matrix2DBlockLoadOp>(
15381545 loc, load2DGenXType,
@@ -1543,9 +1550,9 @@ struct LoadOpConversion
15431550 /* x*/ b.trunc (i32_ty, offsetX),
15441551 /* y*/ b.trunc (i32_ty, offsetY),
15451552 /* elem_size_in_bits*/ elemSizeInBits,
1546- /* tile_width*/ elemsPerInstr[ 1 ] ,
1547- /* tile_height*/ elemsPerInstr[ 0 ] ,
1548- /* v_blocks*/ 1 ,
1553+ /* tile_width*/ tileWidth ,
1554+ /* tile_height*/ tileHeight ,
1555+ /* v_blocks*/ vBlocks ,
15491556 /* transpose*/ false ,
15501557 /* vnni_transform*/ false );
15511558 if (failed (load2dOp.verify ())) {
@@ -1659,9 +1666,6 @@ struct LoadOpConversion
16591666 offsetBaseY] =
16601667 getValuesFromBlockPointerStruct (adaptor.getPtr (), rewriter);
16611668
1662- unsigned tileWidth = elemsPerDPASInst[threadOrder[rank - 2 ]];
1663- unsigned tileHeight = elemsPerDPASInst[threadOrder[rank - 1 ]];
1664-
16651669 MLIRContext *ctx = rewriter.getContext ();
16661670 const StringAttr dimOuterStr = S (" dim" + std::to_string (dimOuter));
16671671 const StringAttr dimInnerStr = S (" dim" + std::to_string (dimInner));
@@ -1739,7 +1743,6 @@ struct LoadOpConversion
17391743 llvm::dbgs () << " tile layout done\n " ;
17401744 });
17411745
1742- unsigned vBlocks = 1 ;
17431746 unsigned numOperandsOuterDimPerLoad = 1 ;
17441747 unsigned numOperandsInnerDimPerLoad = 1 ;
17451748
@@ -1756,11 +1759,10 @@ struct LoadOpConversion
17561759 if (!usePackedType)
17571760 return failure ();
17581761
1759- std::swap (tileHeight, tileWidth);
1760-
17611762 if (oneMatrixPerLoadForBT) {
17621763 // Only load 1 operand per inst on row.
17631764 numOperandsPer2DLoadM = 1 ;
1765+ tileHeight = elemsPerDPASInst[threadOrder[rank - 2 ]];
17641766 } else {
17651767 // We can decompose the matrix returned by transposed large 2d load
17661768 // when threads per warp < column size. Otherwise we have to load one
@@ -1775,6 +1777,7 @@ struct LoadOpConversion
17751777 numOperandsPer2DloadN = 1 ;
17761778 }
17771779
1780+ // TODO: move this logic to the instr shape computation
17781781 // PVC 2D load supports 32 rows at most. Load multiple dot operands in by
17791782 // enlarging the tileHeight.
17801783 numOperandsPer2DLoadM = std::min (numOperandsPer2DLoadM, 32 / tileHeight);
@@ -1785,7 +1788,6 @@ struct LoadOpConversion
17851788 unsigned totalBytesPerRowPerDPASOp = tileWidth * elemSizeInBits / 8 ;
17861789 numOperandsPer2DloadN =
17871790 std::min (numOperandsPer2DloadN, 64 / totalBytesPerRowPerDPASOp);
1788- vBlocks = numOperandsPer2DloadN;
17891791
17901792 numOperandsOuterDimPerLoad =
17911793 isOperandA ? numOperandsPer2DLoadM : numOperandsPer2DloadN;
@@ -1960,7 +1962,6 @@ struct LoadOpConversion
19601962 if (isTransposeRequired) {
19611963 // adjust the block io parameter to align HW's limitations on
19621964 // transposing load.
1963- tileWidth = tileWidth / (32 / originalElemBits);
19641965 elemSizeInBits = 32 ;
19651966 }
19661967 Value elemSizeInBytes = b.i32_val (originalElemBits / 8 );
0 commit comments