Skip to content

Commit a0e532a

Browse files
[LoadStoreOpToLLVM] Improve LoadOpToBlockIOConversion (#5068)
Use `getBlockIOTileSize` function to calculate tile size to perform early exit. Refactoring of existing calculation logic will be done in following PRs. Signed-off-by: Whitney Tsang <[email protected]>
1 parent ce08f57 commit a0e532a

File tree

1 file changed

+33
-10
lines changed

1 file changed

+33
-10
lines changed

third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 33 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1877,6 +1877,33 @@ struct LoadOpToBlockIOConversion
18771877
tensorType.getShape());
18781878
assert(llEncoding.has_value() &&
18791879
"unexpected failure when getting linear layout");
1880+
1881+
constexpr unsigned MAX_TILE_HEIGHT = 32;
1882+
BlockIOTileSizeInfo sizeInfo =
1883+
getBlockIOTileSize<MAX_TILE_HEIGHT>(llEncoding.value());
1884+
if (!sizeInfo.isValid())
1885+
return failure();
1886+
auto [tileHeight, tileWidth, numPackedVals, vBlocks, rowDim, colDim,
1887+
regPackedBases] = sizeInfo;
1888+
1889+
Type eltTy = getTypeConverter()->convertType(tensorType.getElementType());
1890+
unsigned elemSizeInBits = eltTy.getIntOrFloatBitWidth();
1891+
unsigned packedElemSizeInBits = elemSizeInBits * numPackedVals;
1892+
if (!check2DBlockAddressPayloadRestriction(packedElemSizeInBits, tileWidth))
1893+
return failure();
1894+
1895+
// 2D block load supports 64 bytes per row at most.
1896+
constexpr int MAX_WIDTH = 64;
1897+
unsigned totalBytesPerRowPerMatrix = tileWidth * packedElemSizeInBits / 8;
1898+
if (totalBytesPerRowPerMatrix > MAX_WIDTH)
1899+
return failure();
1900+
1901+
// Load multiple dot operands by enlarging the vBlocks.
1902+
vBlocks = std::min(vBlocks,
1903+
static_cast<int>(MAX_WIDTH / totalBytesPerRowPerMatrix));
1904+
// vBlocks has HW limitation of 4.
1905+
vBlocks = std::min(vBlocks, 4);
1906+
18801907
auto llAttr = LinearEncodingAttr::get(rewriter.getContext(), *llEncoding);
18811908
SmallVector<unsigned> threadOrder(llAttr.getThreadOrder());
18821909
size_t rank = threadOrder.size();
@@ -1893,8 +1920,6 @@ struct LoadOpToBlockIOConversion
18931920

18941921
// Step 2: Right now we only support DPAS related layout to simplify the
18951922
// lowering.
1896-
Type eltTy = getTypeConverter()->convertType(tensorType.getElementType());
1897-
unsigned elemSizeInBits = eltTy.getIntOrFloatBitWidth();
18981923
DpasEncodingAttr dpasLayout = getDpasLayout(tensorType);
18991924
const ArrayRef<int64_t> tensorShape = tensorType.getShape();
19001925
unsigned numElems = getTotalElemsPerThread(resultType);
@@ -2022,9 +2047,9 @@ struct LoadOpToBlockIOConversion
20222047
std::min<unsigned>(warpsPerCTA[dimInner], innerDimRequiredWarpNum);
20232048

20242049
// Step 3: Get the tile size of load.
2025-
unsigned tileWidth = dpasInstShape[threadOrder[rank - 2]];
2026-
unsigned tileHeight = dpasInstShape[threadOrder[rank - 1]];
2027-
unsigned vBlocks = 1;
2050+
tileWidth = dpasInstShape[threadOrder[rank - 2]];
2051+
tileHeight = dpasInstShape[threadOrder[rank - 1]];
2052+
vBlocks = 1;
20282053
unsigned numOperandsOuterDimPerLoad = 1;
20292054
unsigned numOperandsInnerDimPerLoad = 1;
20302055
unsigned maskConstancyHor = 1, maskConstancyVer = 1;
@@ -2151,11 +2176,12 @@ struct LoadOpToBlockIOConversion
21512176

21522177
// PVC 2D load supports 32 rows at most. Load multiple dot operands in by
21532178
// enlarging the tileHeight.
2154-
numOperandsPer2DLoadM = std::min(numOperandsPer2DLoadM, 32 / tileHeight);
2179+
numOperandsPer2DLoadM =
2180+
std::min(numOperandsPer2DLoadM,
2181+
static_cast<unsigned>(MAX_TILE_HEIGHT / tileHeight));
21552182

21562183
// PVC 2D load supports 64 bytes per row at most. Load multiple dot operands
21572184
// by enlarging the vBlocks.
2158-
constexpr int MAX_WIDTH = 64;
21592185
unsigned totalBytesPerRowPerDPASOp = tileWidth * elemSizeInBits / 8;
21602186
if (totalBytesPerRowPerDPASOp > MAX_WIDTH)
21612187
return failure();
@@ -2217,9 +2243,6 @@ struct LoadOpToBlockIOConversion
22172243
if (!pitch)
22182244
return failure();
22192245

2220-
if (!check2DBlockAddressPayloadRestriction(elemSizeInBits, tileWidth))
2221-
return failure();
2222-
22232246
// If the stride is 0, we want to load only the first row.
22242247
int stride = getStride(ptr, 0);
22252248
unsigned baseHeightInt = (stride == 0 ? 1 : tileHeight);

0 commit comments

Comments
 (0)