@@ -526,6 +526,21 @@ struct LoadOpConversion
526526 };
527527 auto opIdx = getOpIdx ();
528528
529+ std::optional<LinearLayout> llEncoding =
530+ cast<DistributedEncodingTrait>(encoding).toLinearLayout (
531+ tensorType.getShape ());
532+ assert (llEncoding.has_value () && " invalid dot layout to linear layout" );
533+ LinearEncodingAttr llAttr =
534+ LinearEncodingAttr::get (rewriter.getContext (), *llEncoding);
535+ SmallVector<unsigned > threadOrder = llAttr.getThreadOrder ();
536+ size_t rank = threadOrder.size ();
537+ const bool valueRowMajor =
538+ (threadOrder[rank - 2 ] == 1 && threadOrder[rank - 1 ] == 0 );
539+ assert ((valueRowMajor ||
540+ (threadOrder[rank - 2 ] == 0 && threadOrder[rank - 1 ] == 1 )) &&
541+ " Only row_major or column_major is allowed" );
542+ const bool isTransposeRequired = valueRowMajor ^ memoryRowMajor;
543+
529544 Type eltTy = tensorType.getElementType ();
530545 unsigned elemSizeInBits = eltTy.getIntOrFloatBitWidth ();
531546
@@ -539,15 +554,15 @@ struct LoadOpConversion
539554 SmallVector<int64_t > numReps =
540555 dpasLayout.getDPASRepetitions (tensorShape, opIdx);
541556 const SmallVector<unsigned > warpsPerCTA = dpasLayout.getWarpsPerCTA ();
542- SmallVector<unsigned > dpasOrder = triton::gpu::getOrder (dpasLayout);
557+ SmallVector<unsigned > dpasWarpsOrder = triton::gpu::getOrder (dpasLayout);
543558 int threadsPerWarp = triton::gpu::getWarpSize (dpasLayout);
544559
545560 Value warpId = rewriter.create <arith::IndexCastOp>(
546561 loc, i32_ty,
547562 rewriter.create <mlir::gpu::SubgroupIdOp>(loc, /* upperBound=*/ nullptr ));
548563
549564 SmallVector<Value> multiDimWarpId =
550- delinearize (rewriter, loc, warpId, warpsPerCTA, dpasOrder );
565+ delinearize (rewriter, loc, warpId, warpsPerCTA, dpasWarpsOrder );
551566
552567 if (hasDpasLayout) {
553568 // A block load with the DPAS layout but without the DotDpasLayout is
@@ -557,14 +572,6 @@ struct LoadOpConversion
557572 // aligns to the DPAS layout as the DPAS operation output layout
558573 // distributes rows across work items.
559574
560- size_t rank = dpasOrder.size ();
561- const bool valueRowMajor =
562- (dpasOrder[rank - 2 ] == 1 && dpasOrder[rank - 1 ] == 0 );
563- assert ((valueRowMajor ||
564- (dpasOrder[rank - 2 ] == 0 && dpasOrder[rank - 1 ] == 1 )) &&
565- " Only row_major or column_major is allowed" );
566- const bool isTransposeRequired = valueRowMajor ^ memoryRowMajor;
567-
568575 if (isTransposeRequired) {
569576 // TODO: this would likely require a shuffle to match the expected
570577 // ordering coming out of the DPAS layout and requires more
@@ -675,17 +682,6 @@ struct LoadOpConversion
675682 return success ();
676683 }
677684
678- DotOperandEncodingAttr dotLayout = getDotEncoding (tensorType).value ();
679- auto dotOrder = dotLayout.getThreadOrder ();
680-
681- size_t rank = dotOrder.size ();
682- const bool valueRowMajor =
683- (dotOrder[rank - 2 ] == 1 && dotOrder[rank - 1 ] == 0 );
684- assert ((valueRowMajor ||
685- (dotOrder[rank - 2 ] == 0 && dotOrder[rank - 1 ] == 1 )) &&
686- " Only row_major or column_major is allowed" );
687- const bool isTransposeRequired = valueRowMajor ^ memoryRowMajor;
688-
689685 bool isOperandA = (opIdx == DpasEncodingAttr::OpIdx::OperandA);
690686 SmallVector<unsigned > dpasInstShape = isOperandA
691687 ? dpasLayout.getDPASInstShapeA ()
@@ -749,8 +745,8 @@ struct LoadOpConversion
749745 offsetBaseY] =
750746 getValuesFromBlockPointerStruct (adaptor.getPtr (), rewriter);
751747
752- unsigned tileWidth = elemsPerDPASInst[dotOrder [rank - 2 ]];
753- unsigned tileHeight = elemsPerDPASInst[dotOrder [rank - 1 ]];
748+ unsigned tileWidth = elemsPerDPASInst[threadOrder [rank - 2 ]];
749+ unsigned tileHeight = elemsPerDPASInst[threadOrder [rank - 1 ]];
754750 unsigned vBlocks = 1 ;
755751 unsigned numOperandsOuterDimPerLoad = 1 ;
756752 unsigned numOperandsInnerDimPerLoad = 1 ;
0 commit comments