@@ -1877,6 +1877,33 @@ struct LoadOpToBlockIOConversion
18771877 tensorType.getShape ());
18781878 assert (llEncoding.has_value () &&
18791879 " unexpected failure when getting linear layout" );
1880+
1881+ constexpr unsigned MAX_TILE_HEIGHT = 32 ;
1882+ BlockIOTileSizeInfo sizeInfo =
1883+ getBlockIOTileSize<MAX_TILE_HEIGHT>(llEncoding.value ());
1884+ if (!sizeInfo.isValid ())
1885+ return failure ();
1886+ auto [tileHeight, tileWidth, numPackedVals, vBlocks, rowDim, colDim,
1887+ regPackedBases] = sizeInfo;
1888+
1889+ Type eltTy = getTypeConverter ()->convertType (tensorType.getElementType ());
1890+ unsigned elemSizeInBits = eltTy.getIntOrFloatBitWidth ();
1891+ unsigned packedElemSizeInBits = elemSizeInBits * numPackedVals;
1892+ if (!check2DBlockAddressPayloadRestriction (packedElemSizeInBits, tileWidth))
1893+ return failure ();
1894+
1895+ // 2D block load supports 64 bytes per row at most.
1896+ constexpr int MAX_WIDTH = 64 ;
1897+ unsigned totalBytesPerRowPerMatrix = tileWidth * packedElemSizeInBits / 8 ;
1898+ if (totalBytesPerRowPerMatrix > MAX_WIDTH)
1899+ return failure ();
1900+
1901+ // Load multiple dot operands by enlarging the vBlocks.
1902+ vBlocks = std::min (vBlocks,
1903+ static_cast <int >(MAX_WIDTH / totalBytesPerRowPerMatrix));
1904+ // The hardware supports at most 4 vBlocks.
1905+ vBlocks = std::min (vBlocks, 4 );
1906+
18801907 auto llAttr = LinearEncodingAttr::get (rewriter.getContext (), *llEncoding);
18811908 SmallVector<unsigned > threadOrder (llAttr.getThreadOrder ());
18821909 size_t rank = threadOrder.size ();
@@ -1893,8 +1920,6 @@ struct LoadOpToBlockIOConversion
18931920
18941921 // Step 2: Right now we only support DPAS related layout to simplify the
18951922 // lowering.
1896- Type eltTy = getTypeConverter ()->convertType (tensorType.getElementType ());
1897- unsigned elemSizeInBits = eltTy.getIntOrFloatBitWidth ();
18981923 DpasEncodingAttr dpasLayout = getDpasLayout (tensorType);
18991924 const ArrayRef<int64_t > tensorShape = tensorType.getShape ();
19001925 unsigned numElems = getTotalElemsPerThread (resultType);
@@ -2022,9 +2047,9 @@ struct LoadOpToBlockIOConversion
20222047 std::min<unsigned >(warpsPerCTA[dimInner], innerDimRequiredWarpNum);
20232048
20242049 // Step 3: Get the tile size of load.
2025- unsigned tileWidth = dpasInstShape[threadOrder[rank - 2 ]];
2026- unsigned tileHeight = dpasInstShape[threadOrder[rank - 1 ]];
2027- unsigned vBlocks = 1 ;
2050+ tileWidth = dpasInstShape[threadOrder[rank - 2 ]];
2051+ tileHeight = dpasInstShape[threadOrder[rank - 1 ]];
2052+ vBlocks = 1 ;
20282053 unsigned numOperandsOuterDimPerLoad = 1 ;
20292054 unsigned numOperandsInnerDimPerLoad = 1 ;
20302055 unsigned maskConstancyHor = 1 , maskConstancyVer = 1 ;
@@ -2151,11 +2176,12 @@ struct LoadOpToBlockIOConversion
21512176
21522177 // PVC 2D load supports 32 rows at most. Load multiple dot operands in by
21532178 // enlarging the tileHeight.
2154- numOperandsPer2DLoadM = std::min (numOperandsPer2DLoadM, 32 / tileHeight);
2179+ numOperandsPer2DLoadM =
2180+ std::min (numOperandsPer2DLoadM,
2181+ static_cast <unsigned >(MAX_TILE_HEIGHT / tileHeight));
21552182
21562183 // PVC 2D load supports 64 bytes per row at most. Load multiple dot operands
21572184 // by enlarging the vBlocks.
2158- constexpr int MAX_WIDTH = 64 ;
21592185 unsigned totalBytesPerRowPerDPASOp = tileWidth * elemSizeInBits / 8 ;
21602186 if (totalBytesPerRowPerDPASOp > MAX_WIDTH)
21612187 return failure ();
@@ -2217,9 +2243,6 @@ struct LoadOpToBlockIOConversion
22172243 if (!pitch)
22182244 return failure ();
22192245
2220- if (!check2DBlockAddressPayloadRestriction (elemSizeInBits, tileWidth))
2221- return failure ();
2222-
22232246 // If the stride is 0, we want to load only the first row.
22242247 int stride = getStride (ptr, 0 );
22252248 unsigned baseHeightInt = (stride == 0 ? 1 : tileHeight);
0 commit comments