Skip to content

Commit efff84d

Browse files
committed
Transposed 2d load.
Signed-off-by: Lu,Chengjun <[email protected]>
1 parent c464166 commit efff84d

File tree

1 file changed

+48
-16
lines changed

1 file changed

+48
-16
lines changed

third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 48 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -373,11 +373,11 @@ struct BlockIOConversionBase : public LoadStoreConversionBase {
373373

374374
// Returns the pitch (stride in bytes) of \p ptr.
375375
Value getPitch(ConversionPatternRewriter &rewriter, Value ptr,
376-
unsigned elemSizeInBits) const {
376+
unsigned elemSizeInBits, unsigned dim = 0) const {
377377
Location loc = ptr.getLoc();
378378
auto b = TritonLLVMOpBuilder(loc, rewriter);
379379

380-
int stride = getStride(ptr, 0);
380+
int stride = getStride(ptr, dim);
381381
// If the stride is 0, we assume a minimum pitch of 64 bytes.
382382
constexpr int MIN_PITCH = 64;
383383
if (stride == 0)
@@ -1884,17 +1884,6 @@ struct LoadOpToBlockIOConversion
18841884
// HW issue for vblock = 4
18851885
vBlocks = vBlocks == 4 ? 1 : vBlocks;
18861886

1887-
// TODO: use the axis info to general the handling for both regular pointer
1888-
// and block pointer.
1889-
const bool memoryRowMajor = isMemoryRowMajor(op);
1890-
unsigned contiguousDim = memoryRowMajor ? 1 : 0;
1891-
const bool isTransposeRequired = contiguousDim != colDim;
1892-
1893-
if (isTransposeRequired) {
1894-
// TODO: support load column major data.
1895-
return failure();
1896-
}
1897-
18981887
Location loc = op.getLoc();
18991888
MLIRContext *ctx = op.getContext();
19001889
auto b = TritonLLVMOpBuilder(loc, rewriter);
@@ -2012,13 +2001,56 @@ struct LoadOpToBlockIOConversion
20122001
otherElems = unpackLLElements(loc, llOther, rewriter);
20132002
}
20142003

2004+
// TODO: use the axis info to generalize the handling for both regular pointer
2005+
// and block pointer.
2006+
const bool memoryRowMajor = isMemoryRowMajor(op);
2007+
unsigned contiguousDim = memoryRowMajor ? 1 : 0;
2008+
const bool isTransposeRequired = contiguousDim != colDim;
2009+
2010+
if (isTransposeRequired) {
2011+
if (numPackedVals > 1)
2012+
return failure();
2013+
if (elemSizeInBits > 32)
2014+
return failure();
2015+
if (tileWidth > 32)
2016+
return failure(); // tileWidth is limited to 32 for transpose 2d load.
2017+
2018+
vBlocks = 1;
2019+
2020+
// Use d32 (32-bit packed elements) for the transposed 2D load.
2021+
packedElemSizeInBits = 32;
2022+
numPackedVals = packedElemSizeInBits / elemSizeInBits;
2023+
tileHeight = std::min(tileHeight / numPackedVals, 8);
2024+
2025+
// transpose the width and height of the tile
2026+
std::swap(tileHeight, tileWidth);
2027+
// if (oneMatrixPerLoadForBT) {
2028+
// // Only load 1 operand per inst on row.
2029+
// numOperandsPer2DLoadM = 1;
2030+
// tileHeight = elemsPerDPASInst[threadOrder[rank - 2]];
2031+
// } else {
2032+
// // We can decompose the matrix returned by transposed large 2d load
2033+
// // when threads per warp < column size. Otherwise we have to load one
2034+
// // operand per inst.
2035+
// // Note: the tileHeight and numOperandsPer2DLoadM are the column size
2036+
// // now.
2037+
// numOperandsPer2DLoadM =
2038+
// (threadsPerWarp <= tileHeight) ? repCluster[rank - 1] : 1;
2039+
// }
2040+
// // The transposed 2D load only supports 1 operand per inst on column.
2041+
// // (vBlocks = 1)
2042+
// numOperandsPer2DloadN = 1;
2043+
// // TODO: support load column major data.
2044+
// return failure();
2045+
}
2046+
20152047
baseWidth = b.i32_val(
20162048
std::max(64u, vBlocks * tileWidth * (packedElemSizeInBits / 8)));
20172049
// If the stride is 0, we want to load only the first row.
2018-
int stride = getStride(ptr, 0);
2050+
int stride = getStride(ptr, memoryRowMajor ? 0 : 1);
20192051
baseHeightInt = (stride == 0 ? 1 : tileHeight);
20202052
baseHeight = b.i32_val(baseHeightInt);
2021-
pitch = getPitch(rewriter, ptr, elemSizeInBits);
2053+
pitch = getPitch(rewriter, ptr, elemSizeInBits, memoryRowMajor ? 0 : 1);
20222054
if (!pitch)
20232055
return failure();
20242056

@@ -2161,7 +2193,7 @@ struct LoadOpToBlockIOConversion
21612193
/*tile_width*/ tileWidth,
21622194
/*tile_height*/ tileHeight,
21632195
/*v_blocks*/ vBlocks,
2164-
/*transpose*/ false,
2196+
/*transpose*/ isTransposeRequired,
21652197
/*vnni_transform*/ useVNNIFormat);
21662198

21672199
// When strides[0] is 0, we only want to load the first row, so we

0 commit comments

Comments
 (0)