Skip to content

Commit f2b464f

Browse files
committed
Transposed 2d load.
Signed-off-by: Lu,Chengjun <[email protected]>
1 parent c684d08 commit f2b464f

File tree

1 file changed

+59
-19
lines changed

1 file changed

+59
-19
lines changed

third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 59 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2528,17 +2528,6 @@ struct LoadOpToBlockIOConversion
25282528
if (tileHeight * tileWidth * packedElemSizeInBits / 8 < GRF_SIZE)
25292529
vBlocks = 1;
25302530

2531-
// TODO: use the axis info to generalize the handling for both regular pointer
2532-
// and block pointer.
2533-
const bool memoryRowMajor = isMemoryRowMajor(op);
2534-
unsigned contiguousDim = memoryRowMajor ? 1 : 0;
2535-
const bool isTransposeRequired = contiguousDim != colDim;
2536-
2537-
if (isTransposeRequired) {
2538-
// TODO: support load column major data.
2539-
return failure();
2540-
}
2541-
25422531
Location loc = op.getLoc();
25432532
auto b = TritonLLVMOpBuilder(loc, rewriter);
25442533
MLIRContext *ctx = rewriter.getContext();
@@ -2666,10 +2655,59 @@ struct LoadOpToBlockIOConversion
26662655
}
26672656
}
26682657

2658+
// TODO: use the axis info to generalize the handling for both regular pointer
2659+
// and block pointer.
2660+
const bool memoryRowMajor = isMemoryRowMajor(op);
2661+
unsigned contiguousDim = memoryRowMajor ? 1 : 0;
2662+
const bool isTransposeRequired = contiguousDim != colDim;
2663+
2664+
if (isTransposeRequired) {
2665+
if (numPackedVals > 1)
2666+
return failure();
2667+
if (elemSizeInBits > 32)
2668+
return failure();
2669+
if (tileWidth > 32)
2670+
return failure(); // tileWidth is limited to 32 for transpose 2d load.
2671+
2672+
vBlocks = 1;
2673+
2674+
// Use d32 for the transposed 2D load.
2675+
packedElemSizeInBits = 32;
2676+
numPackedVals = packedElemSizeInBits / elemSizeInBits;
2677+
if (numPackedVals > 1 && tileWidth != threadsPerWarp)
2678+
return failure(); // Couldn't use the transpose 2d load for un-packable
2679+
// along tile height dim.
2680+
tileHeight = std::min(tileHeight / numPackedVals, 8);
2681+
2682+
if (tileHeight * tileWidth < threadsPerWarp)
2683+
return failure(); // The tile size is not large enough for IGC scalar
2684+
// backend vectorization.
2685+
// transpose the width and height of the tile
2686+
std::swap(tileHeight, tileWidth);
2687+
// if (oneMatrixPerLoadForBT) {
2688+
// // Only load 1 operand per inst on row.
2689+
// numOperandsPer2DLoadM = 1;
2690+
// tileHeight = elemsPerDPASInst[threadOrder[rank - 2]];
2691+
// } else {
2692+
// // We can decompose the matrix returned by transposed large 2d load
2693+
// // when threads per warp < column size. Otherwise we have to load one
2694+
// // operand per inst.
2695+
// // Note: the tileHeight and numOperandsPer2DLoadM are the column size
2696+
// // now.
2697+
// numOperandsPer2DLoadM =
2698+
// (threadsPerWarp <= tileHeight) ? repCluster[rank - 1] : 1;
2699+
// }
2700+
// // The transpose 2d load only supports 1 operand per inst on column.
2701+
// // (vBlocks = 1)
2702+
// numOperandsPer2DloadN = 1;
2703+
// // TODO: support load column major data.
2704+
// return failure();
2705+
}
2706+
26692707
baseWidth = b.i32_val(
26702708
std::max(64u, vBlocks * tileWidth * (packedElemSizeInBits / 8)));
26712709
// If the stride is 0, we want to load only the first row.
2672-
int stride = getStride(ptr, 0);
2710+
int stride = getStride(ptr, memoryRowMajor ? 0 : 1);
26732711
baseHeightInt = (stride == 0 ? 1 : tileHeight);
26742712
baseHeight = b.i32_val(baseHeightInt);
26752713
pitch = getPitch(rewriter, ptr, elemSizeInBits, memoryRowMajor ? 0 : 1);
@@ -2738,17 +2776,19 @@ struct LoadOpToBlockIOConversion
27382776
}
27392777
} break;
27402778
case DpasEncodingAttr::OpIdx::OperandB: {
2741-
assert(numPackedVals == 1 &&
2742-
"invalid number of packed values for DPAS operand B.");
2779+
// assert(numPackedVals == 1 &&
2780+
// "invalid number of packed values for DPAS operand B.");
27432781
unsigned elemsPerLanePerDPASInst =
27442782
product<unsigned>(dpasLayout.getDPASInstShapeB()) / threadsPerWarp;
27452783
// Block 2D contains at least one DotOp B.
27462784
if (numElemsPerLoad >= elemsPerLanePerDPASInst) {
27472785
unsigned opsPerChannel = dpasLayout.getOpsPerChannel();
27482786
unsigned sysDepth = dpasLayout.getSystolicDepth();
2749-
if (tileHeight >= (opsPerChannel * sysDepth) &&
2750-
((opsPerChannel == 4 && elemSizeInBits == 8) ||
2751-
(opsPerChannel == 2 && elemSizeInBits == 16))) {
2787+
if ((opsPerChannel == 4 && elemSizeInBits == 8) ||
2788+
(opsPerChannel == 2 && elemSizeInBits == 16)) {
2789+
assert(!isTransposeRequired ||
2790+
opsPerChannel == numPackedVals &&
2791+
"invalid opsPerChannel for transposed DotOp B");
27522792
// Use the VNNI packing format for DotOp B layout.
27532793
numValuesPerLoad = numElemsPerLoad / opsPerChannel;
27542794
packedType = i32_ty;
@@ -2815,8 +2855,8 @@ struct LoadOpToBlockIOConversion
28152855
/*tile_width*/ tileWidth,
28162856
/*tile_height*/ tileHeight,
28172857
/*v_blocks*/ vBlocks,
2818-
/*transpose*/ false,
2819-
/*vnni_transform*/ useVNNIFormat);
2858+
/*transpose*/ isTransposeRequired,
2859+
/*vnni_transform*/ !isTransposeRequired && useVNNIFormat);
28202860

28212861
// When strides[0] is 0, we only want to load the first row, so we
28222862
// set the base height to be 1. If tile height is bigger than 1,

0 commit comments

Comments
 (0)