@@ -373,11 +373,11 @@ struct BlockIOConversionBase : public LoadStoreConversionBase {
373
373
374
374
// Returns the pitch (stride in bytes) of \p ptr along dimension \p dim.
375
375
Value getPitch (ConversionPatternRewriter &rewriter, Value ptr,
376
- unsigned elemSizeInBits) const {
376
+ unsigned elemSizeInBits, unsigned dim = 0 ) const {
377
377
Location loc = ptr.getLoc ();
378
378
auto b = TritonLLVMOpBuilder (loc, rewriter);
379
379
380
- int stride = getStride (ptr, 0 );
380
+ int stride = getStride (ptr, dim );
381
381
// If the stride is 0, we assume a minimum pitch of 64 bytes.
382
382
constexpr int MIN_PITCH = 64 ;
383
383
if (stride == 0 )
@@ -1884,17 +1884,6 @@ struct LoadOpToBlockIOConversion
1884
1884
// HW issue for vblock = 4
1885
1885
vBlocks = vBlocks == 4 ? 1 : vBlocks;
1886
1886
1887
- // TODO: use the axis info to general the handling for both regular pointer
1888
- // and block pointer.
1889
- const bool memoryRowMajor = isMemoryRowMajor (op);
1890
- unsigned contiguousDim = memoryRowMajor ? 1 : 0 ;
1891
- const bool isTransposeRequired = contiguousDim != colDim;
1892
-
1893
- if (isTransposeRequired) {
1894
- // TODO: support load column major data.
1895
- return failure ();
1896
- }
1897
-
1898
1887
Location loc = op.getLoc ();
1899
1888
MLIRContext *ctx = op.getContext ();
1900
1889
auto b = TritonLLVMOpBuilder (loc, rewriter);
@@ -2012,13 +2001,56 @@ struct LoadOpToBlockIOConversion
2012
2001
otherElems = unpackLLElements (loc, llOther, rewriter);
2013
2002
}
2014
2003
2004
+ // TODO: use the axis info to generalize the handling for both regular pointer
2005
+ // and block pointer.
2006
+ const bool memoryRowMajor = isMemoryRowMajor (op);
2007
+ unsigned contiguousDim = memoryRowMajor ? 1 : 0 ;
2008
+ const bool isTransposeRequired = contiguousDim != colDim;
2009
+
2010
+ if (isTransposeRequired) {
2011
+ if (numPackedVals > 1 )
2012
+ return failure ();
2013
+ if (elemSizeInBits > 32 )
2014
+ return failure ();
2015
+ if (tileWidth > 32 )
2016
+ return failure (); // tileWidth is limited to 32 for transpose 2d load.
2017
+
2018
+ vBlocks = 1 ;
2019
+
2020
+ // use the d32 for transpose 2d load.
2021
+ packedElemSizeInBits = 32 ;
2022
+ numPackedVals = packedElemSizeInBits / elemSizeInBits;
2023
+ tileHeight = std::min (tileHeight / numPackedVals, 8 );
2024
+
2025
+ // transpose the width and height of the tile
2026
+ std::swap (tileHeight, tileWidth);
2027
+ // if (oneMatrixPerLoadForBT) {
2028
+ // // Only load 1 operand per inst on row.
2029
+ // numOperandsPer2DLoadM = 1;
2030
+ // tileHeight = elemsPerDPASInst[threadOrder[rank - 2]];
2031
+ // } else {
2032
+ // // We can decompose the matrix returned by transposed large 2d load
2033
+ // // when threads per warp < column size. Otherwise we have to load one
2034
+ // // operand per inst.
2035
+ // // Note: the tileHeight and numOperandsPer2DLoadM are the column size
2036
+ // // now.
2037
+ // numOperandsPer2DLoadM =
2038
+ // (threadsPerWarp <= tileHeight) ? repCluster[rank - 1] : 1;
2039
+ // }
2040
+ // // The transpose 2d load only support 1 operand per inst on column.
2041
+ // // (vBlocks = 1)
2042
+ // numOperandsPer2DloadN = 1;
2043
+ // // TODO: support load column major data.
2044
+ // return failure();
2045
+ }
2046
+
2015
2047
baseWidth = b.i32_val (
2016
2048
std::max (64u , vBlocks * tileWidth * (packedElemSizeInBits / 8 )));
2017
2049
// If the stride is 0, we want to load only the first row.
2018
- int stride = getStride (ptr, 0 );
2050
+ int stride = getStride (ptr, memoryRowMajor ? 0 : 1 );
2019
2051
baseHeightInt = (stride == 0 ? 1 : tileHeight);
2020
2052
baseHeight = b.i32_val (baseHeightInt);
2021
- pitch = getPitch (rewriter, ptr, elemSizeInBits);
2053
+ pitch = getPitch (rewriter, ptr, elemSizeInBits, memoryRowMajor ? 0 : 1 );
2022
2054
if (!pitch)
2023
2055
return failure ();
2024
2056
@@ -2161,7 +2193,7 @@ struct LoadOpToBlockIOConversion
2161
2193
/* tile_width*/ tileWidth,
2162
2194
/* tile_height*/ tileHeight,
2163
2195
/* v_blocks*/ vBlocks,
2164
- /* transpose*/ false ,
2196
+ /* transpose*/ isTransposeRequired ,
2165
2197
/* vnni_transform*/ useVNNIFormat);
2166
2198
2167
2199
// When strides[0] is 0, we only want to load the first row, so we
0 commit comments