@@ -499,7 +499,9 @@ struct LoadOpConversion
499499 auto tensorType = cast<RankedTensorType>(resultType);
500500
501501 // Only lower loadOp with dpas layout encoding.
502- if (!hasDotDpasEncoding (tensorType))
502+ auto encoding = tensorType.getEncoding ();
503+ const bool hasDpasLayout = isa<DpasEncodingAttr>(encoding);
504+ if (!hasDpasLayout && !hasDotDpasEncoding (tensorType))
503505 return failure ();
504506
505507 Attribute blockIOAttr =
@@ -514,20 +516,24 @@ struct LoadOpConversion
514516 " Only row_major or column_major is supported" );
515517 const bool memoryRowMajor = (memoryLayoutInfo == " row_major" );
516518
517- DotOperandEncodingAttr dotLayout = getDotEncoding (tensorType).value ();
518- auto dotOrder = dotLayout.getThreadOrder ();
519- size_t rank = dotOrder.size ();
520- const bool valueRowMajor =
521- (dotOrder[rank - 2 ] == 1 && dotOrder[rank - 1 ] == 0 );
522- assert ((valueRowMajor ||
523- (dotOrder[rank - 2 ] == 0 && dotOrder[rank - 1 ] == 1 )) &&
524- " Only row_major or column_major is allowed" );
525- const bool isTransposeRequired = valueRowMajor ^ memoryRowMajor;
526-
527- auto dpasLayout = cast<DpasEncodingAttr>(dotLayout.getParent ());
519+ auto getOpIdx = [&]() -> DpasEncodingAttr::OpIdx {
520+ if (hasDpasLayout) {
521+ return DpasEncodingAttr::OpIdx::OperandC;
522+ } else {
523+ auto dotLayout = getDotEncoding (tensorType).value ();
524+ return static_cast <DpasEncodingAttr::OpIdx>(dotLayout.getOpIdx ());
525+ }
526+ };
527+ auto opIdx = getOpIdx ();
528528
529- auto opIdx = static_cast <DpasEncodingAttr::OpIdx>(dotLayout.getOpIdx ());
530529 Type eltTy = tensorType.getElementType ();
530+ unsigned elemSizeInBits = eltTy.getIntOrFloatBitWidth ();
531+
532+ auto dpasLayout = hasDpasLayout
533+ ? cast<DpasEncodingAttr>(encoding)
534+ : cast<DpasEncodingAttr>(
535+ getDotEncoding (tensorType).value ().getParent ());
536+
531537 const ArrayRef<int64_t > tensorShape = tensorType.getShape ();
532538 unsigned numElems = getTotalElemsPerThread (resultType);
533539 SmallVector<int64_t > numReps =
@@ -543,6 +549,143 @@ struct LoadOpConversion
543549 SmallVector<Value> multiDimWarpId =
544550 delinearize (rewriter, loc, warpId, warpsPerCTA, dpasOrder);
545551
552+ if (hasDpasLayout) {
553+ // A block load with the DPAS layout but without the DotDpasLayout is
554+ // expected to follow the ordering of the DPAS output. For a 2D block
555+ // load, the rows are distributed across work items/SIMD lanes and the
556+ // column vectors are available for each work item to process. This layout
557+ // aligns to the DPAS layout as the DPAS operation output layout
558+ // distributes rows across work items.
559+
560+ size_t rank = dpasOrder.size ();
561+ const bool valueRowMajor =
562+ (dpasOrder[rank - 2 ] == 1 && dpasOrder[rank - 1 ] == 0 );
563+ assert ((valueRowMajor ||
564+ (dpasOrder[rank - 2 ] == 0 && dpasOrder[rank - 1 ] == 1 )) &&
565+ " Only row_major or column_major is allowed" );
566+ const bool isTransposeRequired = valueRowMajor ^ memoryRowMajor;
567+
568+ if (isTransposeRequired) {
569+ // TODO: this would likely require a shuffle to match the expected
570+ // ordering coming out of the DPAS layout and requires more
571+ // investigation
572+ return failure ();
573+ }
574+
575+ MLIRContext *ctx = rewriter.getContext ();
576+
577+ Value elemSizeInBytes = i32_val (elemSizeInBits / 8 );
578+
579+ SmallVector<unsigned > elemsPerInstr = dpasLayout.getDPASInstShapeC ();
580+ int64_t elemsPerLane = product<unsigned >(elemsPerInstr) / threadsPerWarp;
581+ Type load2DGenXType =
582+ LLVM::getFixedVectorType (IntegerType::get (ctx, elemSizeInBits),
583+ elemsPerLane); // make it opaque type.
584+
585+ auto [base, baseWidth, baseHeight, rowStride, colStride, offsetBaseX,
586+ offsetBaseY] =
587+ getValuesFromBlockPointerStruct (adaptor.getPtr (), rewriter);
588+ baseWidth = trunc (i32_ty, baseWidth);
589+ baseHeight = trunc (i32_ty, baseHeight);
590+
591+ auto pitch = trunc (i32_ty, rowStride);
592+
593+ SmallVector<unsigned > repClusterShape = dpasLayout.getShapeC ();
594+ unsigned outerDimWarpNum =
595+ std::min<unsigned >(warpsPerCTA[rank - 2 ],
596+ mlir::ceil<unsigned >(tensorShape[rank - 2 ],
597+ repClusterShape[rank - 2 ]));
598+ unsigned innerDimWarpNum =
599+ std::min<unsigned >(warpsPerCTA[rank - 1 ],
600+ mlir::ceil<unsigned >(tensorShape[rank - 1 ],
601+ repClusterShape[rank - 1 ]));
602+ Value outerDimWarpId =
603+ urem (multiDimWarpId[rank - 2 ], i32_val (outerDimWarpNum));
604+ Value innerDimWarpId =
605+ urem (multiDimWarpId[rank - 1 ], i32_val (innerDimWarpNum));
606+ int64_t numRepOuter = numReps[1 ];
607+ int64_t numRepInner = numReps[2 ];
608+
609+ std::array<unsigned , 2 > replicaStride = {
610+ outerDimWarpNum * repClusterShape[rank - 2 ],
611+ innerDimWarpNum * repClusterShape[rank - 1 ]};
612+ std::array<unsigned , 2 > warpStride = {repClusterShape[rank - 2 ],
613+ repClusterShape[rank - 1 ]};
614+
615+ Value dimWarpId0 = mul (outerDimWarpId, i32_val (warpStride[0 ]));
616+ Value dimWarpId1 = mul (innerDimWarpId, i32_val (warpStride[1 ]));
617+ Value warpId0Offset = add (dimWarpId0, offsetBaseY);
618+ Value warpId1Offset = add (dimWarpId1, offsetBaseX);
619+
620+ ArrayRef<unsigned > repCluster = dpasLayout.getRepCluster ();
621+ unsigned valOffset = 0 ;
622+
623+ SmallVector<Value> unpackedLoadedVals;
624+
625+ for (int m = 0 ; m < numRepOuter; ++m) {
626+ for (int n = 0 ; n < numRepInner; ++n) {
627+ for (int repM = 0 ; repM < repCluster[0 ]; ++repM) {
628+
629+ Value offsetY =
630+ add (warpId0Offset,
631+ i32_val (m * replicaStride[0 ] + repM * elemsPerInstr[0 ]));
632+ for (int repN = 0 ; repN < repCluster[1 ]; ++repN) {
633+ Value offsetX =
634+ add (warpId1Offset,
635+ i32_val (n * replicaStride[1 ] + repN * elemsPerInstr[1 ]));
636+
637+ auto load2dOp = rewriter.create <TritonGEN::Matrix2DBlockLoadOp>(
638+ loc, load2DGenXType,
639+ /* ptr*/ base,
640+ /* base_width*/ mul (baseWidth, elemSizeInBytes),
641+ /* base_height*/ baseHeight,
642+ /* base_pitch*/ mul (pitch, elemSizeInBytes),
643+ /* x*/ trunc (i32_ty, offsetX),
644+ /* y*/ trunc (i32_ty, offsetY),
645+ /* elem_size_in_bits*/ elemSizeInBits,
646+ /* tile_width*/ elemsPerInstr[1 ],
647+ /* tile_height*/ elemsPerInstr[0 ],
648+ /* v_blocks*/ 1 ,
649+ /* transpose*/ false ,
650+ /* vnni_transform*/ false );
651+ if (failed (load2dOp.verify ())) {
652+ // Explicitly invoke verifier because `triton_gen` ops are
653+ // immediately lowered further to a builtin call.
654+ return failure ();
655+ }
656+
657+ Value ret = bitcast (
658+ load2dOp, LLVM::getFixedVectorType (eltTy, elemsPerLane));
659+
660+ for (size_t i = 0 ; i < elemsPerLane; i++) {
661+ Value loaded = extract_element (eltTy, ret, i32_val (i));
662+ unpackedLoadedVals.push_back (loaded);
663+ }
664+ }
665+ }
666+ }
667+ }
668+
669+ TritonGPUToLLVMTypeConverter *typeConverter = getTypeConverter ();
670+ Type llvmResultStructTy = typeConverter->convertType (op.getType ());
671+ Value resultStruct = packLLElements (
672+ loc, typeConverter, unpackedLoadedVals, rewriter, llvmResultStructTy);
673+ rewriter.replaceOp (op, {resultStruct});
674+
675+ return success ();
676+ }
677+
678+ DotOperandEncodingAttr dotLayout = getDotEncoding (tensorType).value ();
679+ auto dotOrder = dotLayout.getThreadOrder ();
680+
681+ size_t rank = dotOrder.size ();
682+ const bool valueRowMajor =
683+ (dotOrder[rank - 2 ] == 1 && dotOrder[rank - 1 ] == 0 );
684+ assert ((valueRowMajor ||
685+ (dotOrder[rank - 2 ] == 0 && dotOrder[rank - 1 ] == 1 )) &&
686+ " Only row_major or column_major is allowed" );
687+ const bool isTransposeRequired = valueRowMajor ^ memoryRowMajor;
688+
546689 bool isOperandA = (opIdx == DpasEncodingAttr::OpIdx::OperandA);
547690 SmallVector<unsigned > dpasInstShape = isOperandA
548691 ? dpasLayout.getDPASInstShapeA ()
@@ -573,11 +716,11 @@ struct LoadOpConversion
573716 // input operands to DPAS.
574717 // TODO: add support for int4 and int2.
575718 unsigned opsPerChannel = dpasLayout.getOpsPerChannel ();
576- unsigned elemBits = eltTy. getIntOrFloatBitWidth ();
577- if (( opsPerChannel == 4 && elemBits == 8 ) ||
578- (opsPerChannel == 2 && elemBits == 16 ) ||
579- (opsPerChannel == 1 && elemBits == 32 )) {
580- loadResultElemType = (isOperandA && elemBits != 32 ) ? i16_ty : i32_ty;
719+ if ((opsPerChannel == 4 && elemSizeInBits == 8 ) ||
720+ ( opsPerChannel == 2 && elemSizeInBits == 16 ) ||
721+ (opsPerChannel == 1 && elemSizeInBits == 32 )) {
722+ loadResultElemType =
723+ (isOperandA && elemSizeInBits != 32 ) ? i16_ty : i32_ty;
581724 packedElemsPerLanePerDPASInst =
582725 isOperandA ? elemsPerLanePerDPASInst / (opsPerChannel == 4 ? 2 : 1 )
583726 : elemsPerLanePerDPASInst / opsPerChannel;
@@ -651,7 +794,7 @@ struct LoadOpConversion
651794
652795 // PVC 2D load supports 64 bytes per row at most. Load multiple dot operands
653796 // by enlarging the vBlocks.
654- unsigned totalBytesPerRowPerDPASOp = tileWidth * elemBits / 8 ;
797+ unsigned totalBytesPerRowPerDPASOp = tileWidth * elemSizeInBits / 8 ;
655798 numOperandsPer2DloadN =
656799 std::min (numOperandsPer2DloadN, 64 / totalBytesPerRowPerDPASOp);
657800 vBlocks = numOperandsPer2DloadN;
@@ -695,12 +838,12 @@ struct LoadOpConversion
695838 baseWidth = trunc (i32_ty, baseWidth);
696839 baseHeight = trunc (i32_ty, baseHeight);
697840
698- unsigned originalElemBits = elemBits ;
841+ const unsigned originalElemBits = elemSizeInBits ;
699842 if (isTransposeRequired) {
700843 // adjust the block io parameter to align HW's limitations on
701844 // transposing load.
702845 tileWidth = tileWidth / (32 / originalElemBits);
703- elemBits = 32 ;
846+ elemSizeInBits = 32 ;
704847 }
705848 Value elemSizeInBytes = i32_val (originalElemBits / 8 );
706849
@@ -747,14 +890,14 @@ struct LoadOpConversion
747890 /* base_pitch*/ mul (pitch, elemSizeInBytes),
748891 /* x*/ trunc (i32_ty, offsetX),
749892 /* y*/ trunc (i32_ty, offsetY),
750- /* elem_size_in_bits*/ elemBits ,
893+ /* elem_size_in_bits*/ elemSizeInBits ,
751894 /* tile_width*/ tileWidth,
752895 /* tile_height*/ tileHeight,
753896 /* v_blocks*/ vBlocks,
754897 /* transpose*/ isTransposeRequired,
755898 /* vnni_transform*/
756899 (usePackedType && !isOperandA && !isTransposeRequired &&
757- eltTy. getIntOrFloatBitWidth () != 32 ));
900+ originalElemBits != 32 ));
758901 if (failed (load2dOp.verify ())) {
759902 // Explicitly invoke verifier because `triton_gen` ops are
760903 // immediately lowered further to a builtin call.
0 commit comments