@@ -2528,17 +2528,6 @@ struct LoadOpToBlockIOConversion
25282528 if (tileHeight * tileWidth * packedElemSizeInBits / 8 < GRF_SIZE)
25292529 vBlocks = 1 ;
25302530
2531- // TODO: use the axis info to general the handling for both regular pointer
2532- // and block pointer.
2533- const bool memoryRowMajor = isMemoryRowMajor (op);
2534- unsigned contiguousDim = memoryRowMajor ? 1 : 0 ;
2535- const bool isTransposeRequired = contiguousDim != colDim;
2536-
2537- if (isTransposeRequired) {
2538- // TODO: support load column major data.
2539- return failure ();
2540- }
2541-
25422531 Location loc = op.getLoc ();
25432532 auto b = TritonLLVMOpBuilder (loc, rewriter);
25442533 MLIRContext *ctx = rewriter.getContext ();
@@ -2666,10 +2655,59 @@ struct LoadOpToBlockIOConversion
26662655 }
26672656 }
26682657
2658+ // TODO: use the axis info to generalize the handling for both regular
2659+ // pointer and block pointer.
2660+ const bool memoryRowMajor = isMemoryRowMajor (op);
2661+ unsigned contiguousDim = memoryRowMajor ? 1 : 0 ;
2662+ const bool isTransposeRequired = contiguousDim != colDim;
2663+
2664+ if (isTransposeRequired) {
2665+ if (numPackedVals > 1 )
2666+ return failure ();
2667+ if (elemSizeInBits > 32 )
2668+ return failure ();
2669+ if (tileWidth > 32 )
2670+ return failure (); // tileWidth is limited to 32 for transpose 2d load.
2671+
2672+ vBlocks = 1 ;
2673+
2674+ // use the d32 for transpose 2d load.
2675+ packedElemSizeInBits = 32 ;
2676+ numPackedVals = packedElemSizeInBits / elemSizeInBits;
2677+ if (numPackedVals > 1 && tileWidth != threadsPerWarp)
2678+ return failure (); // Couldn't use the transpose 2d load for un-packable
2679+ // along tile height dim.
2680+ tileHeight = std::min (tileHeight / numPackedVals, 8 );
2681+
2682+ if (tileHeight * tileWidth < threadsPerWarp)
2683+ return failure (); // The tile size is not large enough for IGC scalar
2684+ // backend vectorization.
2685+ // transpose the width and height of the tile
2686+ std::swap (tileHeight, tileWidth);
2687+ // if (oneMatrixPerLoadForBT) {
2688+ // // Only load 1 operand per inst on row.
2689+ // numOperandsPer2DLoadM = 1;
2690+ // tileHeight = elemsPerDPASInst[threadOrder[rank - 2]];
2691+ // } else {
2692+ // // We can decompose the matrix returned by transposed large 2d load
2693+ // // when threads per warp < column size. Otherwise we have to load one
2694+ // // operand per inst.
2695+ // // Note: the tileHeight and numOperandsPer2DLoadM are the column size
2696+ // // now.
2697+ // numOperandsPer2DLoadM =
2698+ // (threadsPerWarp <= tileHeight) ? repCluster[rank - 1] : 1;
2699+ // }
2700+ // // The transpose 2d load only support 1 operand per inst on column.
2701+ // // (vBlocks = 1)
2702+ // numOperandsPer2DloadN = 1;
2703+ // // TODO: support load column major data.
2704+ // return failure();
2705+ }
2706+
26692707 baseWidth = b.i32_val (
26702708 std::max (64u , vBlocks * tileWidth * (packedElemSizeInBits / 8 )));
26712709 // If the stride is 0, we want to load only the first row.
2672- int stride = getStride (ptr, 0 );
2710+ int stride = getStride (ptr, memoryRowMajor ? 0 : 1 );
26732711 baseHeightInt = (stride == 0 ? 1 : tileHeight);
26742712 baseHeight = b.i32_val (baseHeightInt);
26752713 pitch = getPitch (rewriter, ptr, elemSizeInBits, memoryRowMajor ? 0 : 1 );
@@ -2738,17 +2776,19 @@ struct LoadOpToBlockIOConversion
27382776 }
27392777 } break ;
27402778 case DpasEncodingAttr::OpIdx::OperandB: {
2741- assert (numPackedVals == 1 &&
2742- " invalid number of packed values for DPAS operand B." );
2779+ // assert(numPackedVals == 1 &&
2780+ // "invalid number of packed values for DPAS operand B.");
27432781 unsigned elemsPerLanePerDPASInst =
27442782 product<unsigned >(dpasLayout.getDPASInstShapeB ()) / threadsPerWarp;
27452783 // Block 2D contain at least one DotOp B.
27462784 if (numElemsPerLoad >= elemsPerLanePerDPASInst) {
27472785 unsigned opsPerChannel = dpasLayout.getOpsPerChannel ();
27482786 unsigned sysDepth = dpasLayout.getSystolicDepth ();
2749- if (tileHeight >= (opsPerChannel * sysDepth) &&
2750- ((opsPerChannel == 4 && elemSizeInBits == 8 ) ||
2751- (opsPerChannel == 2 && elemSizeInBits == 16 ))) {
2787+ if ((opsPerChannel == 4 && elemSizeInBits == 8 ) ||
2788+ (opsPerChannel == 2 && elemSizeInBits == 16 )) {
2789+ assert (!isTransposeRequired ||
2790+ opsPerChannel == numPackedVals &&
2791+ " invalid opsPerChannel for transposed DotOp B" );
27522792 // Use the VNNI packing format for DotOp B layout.
27532793 numValuesPerLoad = numElemsPerLoad / opsPerChannel;
27542794 packedType = i32_ty;
@@ -2815,8 +2855,8 @@ struct LoadOpToBlockIOConversion
28152855 /* tile_width*/ tileWidth,
28162856 /* tile_height*/ tileHeight,
28172857 /* v_blocks*/ vBlocks,
2818- /* transpose*/ false ,
2819- /* vnni_transform*/ useVNNIFormat);
2858+ /* transpose*/ isTransposeRequired ,
2859+ /* vnni_transform*/ !isTransposeRequired && useVNNIFormat);
28202860
28212861 // When strides[0] is 0, we only want to load the first row, so we
28222862 // set the base height to be 1. If tile height is bigger than 1,
0 commit comments