@@ -706,8 +706,6 @@ struct PrefetchOpConversion
     Attribute blockIOAttr =
         op->getAttr(TritonIntelGPUDialect::getBlockIOAttrName());
     if (!blockIOAttr) {
-      // TODO: Fallback to gather semantic prefetching. Simply erase the
-      // prefetching op which is not supported for now.
       rewriter.eraseOp(op);
       return success();
     }
@@ -727,33 +725,20 @@ struct PrefetchOpConversion
       // Swap the shape to make it row major and then get the tiling
       // size base on row major shape.
       std::swap(tensorShape[0], tensorShape[1]);
-
-      // Create the new tensor type with swapped row and col.
-      tensorType = RankedTensorType::get(
-          tensorShape, tensorType.getElementType(), tensorType.getEncoding());
     }
-
     unsigned numWarps = triton::gpu::lookupNumWarps(op);
 
-    SmallVector<unsigned, 2> shapePerWarp =
-        get2DPrefetchShapePerWarp(tensorType);
+    auto [tileHeightInElem, tileWidthInElem, warpsM, warpsN] =
+        get2DPrefetchWarpsPerCTA(tensorShape, eltTy, numWarps);
 
-    SmallVector<unsigned, 2> warpsPerCTA =
-        getWarpsPerCTA(tensorShape, shapePerWarp, numWarps);
+    auto llEncoding = getLinearLayout(
+        tensorShape, {tileHeightInElem, tileWidthInElem}, {warpsM, warpsN});
 
-    // To adjust the row shape per warp to fit the tensor shape and avoid
-    // duplication in prefetching.
-    unsigned factor =
-        mlir::ceil(shapePerWarp[0] * warpsPerCTA[0], (unsigned)tensorShape[0]);
-    shapePerWarp[0] = mlir::ceil(shapePerWarp[0], factor);
-
-    SmallVector<int64_t> numReps = {
-        mlir::ceil<int64_t>(tensorShape[0], shapePerWarp[0] * warpsPerCTA[0]),
-        mlir::ceil<int64_t>(tensorShape[1], shapePerWarp[1] * warpsPerCTA[1])};
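+    // Each tile is one 2D block prefetch; the tiles are divided evenly over
+    // all warps. Note: this is integer division, so it assumes the tensor
+    // size is a multiple of tileSizeInElem * numWarps.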
+    unsigned tileSizeInElem = tileHeightInElem * tileWidthInElem;
+    unsigned numTilesPerWarp =
+        (tensorShape[0] * tensorShape[1]) / (tileSizeInElem * numWarps);
 
     unsigned elemSizeInBits = eltTy.getIntOrFloatBitWidth();
-    unsigned tileWidthInElem = shapePerWarp[1];
-    unsigned tileHeightInElem = shapePerWarp[0];
     unsigned vBlocks = 1;
     switch (elemSizeInBits) {
     case 8:
@@ -774,12 +759,6 @@ struct PrefetchOpConversion
       break;
     }
 
-    Value warpId = rewriter.create<arith::IndexCastOp>(
-        loc, i32_ty,
-        rewriter.create<mlir::gpu::SubgroupIdOp>(loc, /*upperBound=*/nullptr));
-    SmallVector<Value> multiDimWarpId =
-        mlir::LLVM::delinearize(rewriter, loc, warpId, warpsPerCTA, {1, 0});
-
     auto [base, baseWidth, baseHeight, rowStride, colStride, offsetBaseX,
           offsetBaseY] =
         getValuesFromBlockPointerStruct(adaptor.getPtr(), rewriter);
@@ -788,6 +767,7 @@ struct PrefetchOpConversion
       // Swap the width/height and strides to the row major.
       std::swap(baseWidth, baseHeight);
       std::swap(colStride, rowStride);
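+      // The base offsets from the block pointer are also in the original
+      // (column-major) order, so swap them to keep X/Y consistent with the
+      // row-major view used below.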
+      std::swap(offsetBaseX, offsetBaseY);
     }
 
     baseWidth = b.mul(baseWidth, b.i64_val(eltTy.getIntOrFloatBitWidth() / 8));
@@ -799,46 +779,43 @@ struct PrefetchOpConversion
         b.mul(rowStride, b.i64_val(eltTy.getIntOrFloatBitWidth() / 8));
     rowStrideInBytes = b.trunc(i32_ty, rowStrideInBytes);
 
-    for (int row = 0; row < numReps[0]; ++row) {
-      for (int col = 0; col < numReps[1]; ++col) {
-        Value offsetX, offsetY;
-        offsetX = b.add(
-            // the offset of this warp.
-            b.mul(multiDimWarpId[1], b.i32_val(shapePerWarp[1])),
-            // add the replica offset with a warp stride.
-            b.i32_val(col * warpsPerCTA[1] * shapePerWarp[1]));
-        // Round the offset into to the tensor shape
-        offsetX = b.urem(offsetX, b.i32_val(tensorShape[1]));
-        offsetX = b.add(offsetX, offsetBaseX);
-        offsetY = b.add(
-            // the offset of this warp.
-            b.mul(multiDimWarpId[0], b.i32_val(shapePerWarp[0])),
-            // add the replica offset with a warp stride.
-            b.i32_val(row * warpsPerCTA[0] * shapePerWarp[0]));
-        // Round the offset into to the tensor shape
-        offsetY = b.urem(offsetY, b.i32_val(tensorShape[0]));
-        offsetY = b.add(offsetY, offsetBaseY);
-
-        auto newOp = rewriter.create<TritonGEN::Matrix2DBlockPrefetchOp>(
-            loc,
-            /*ptr*/ base,
-            /*base_width*/ baseWidth,
-            /*base_height*/ baseHeight,
-            /*base_pitch*/ rowStrideInBytes,
-            /*x*/ offsetX,
-            /*y*/ offsetY,
-            /*elem_size_in_bits*/ elemSizeInBits,
-            /*tile_width*/ tileWidthInElem,
-            /*tile_height*/ tileHeightInElem,
-            /*v_blocks*/ vBlocks,
-            /*cache_opt*/ TritonGEN::LoadCacheControl::L1C_L3C);
-        if (failed(newOp.verify())) {
-          // delete the op so that the verifier will not abort the pass
-          // pipeline later, as we can fail this path and try a different
-          // approach.
-          rewriter.eraseOp(newOp);
-          return failure();
-        }
+    MLIRContext *ctx = getContext();
+    StringAttr kOffset = S("offset");
+    StringAttr kWarp = S("warp");
+    StringAttr kBlock = S("block");
+
+    Value warpId = rewriter.create<arith::IndexCastOp>(
+        loc, i32_ty,
+        rewriter.create<mlir::gpu::SubgroupIdOp>(loc,
+                                                 /*upperBound=*/nullptr));
+
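+    // Walk the tiles owned by this warp. The linear layout maps the flat
+    // element offset and the warp id to (row, col) coordinates in the tensor:
+    // result dim 0 is the row (Y) and result dim 1 the column (X).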
+    for (unsigned tile = 0; tile < numTilesPerWarp; ++tile) {
+      unsigned off = tile * tileSizeInElem;
+      auto offsets = applyLinearLayout(
+          loc, rewriter, llEncoding,
+          {{kOffset, b.i32_val(off)}, {kWarp, warpId}, {kBlock, b.i32_val(0)}});
+      Value offsetX = b.add(offsets[1].second, offsetBaseX);
+      Value offsetY = b.add(offsets[0].second, offsetBaseY);
+
+      auto newOp = rewriter.create<TritonGEN::Matrix2DBlockPrefetchOp>(
+          loc,
+          /*ptr*/ base,
+          /*base_width*/ baseWidth,
+          /*base_height*/ baseHeight,
+          /*base_pitch*/ rowStrideInBytes,
+          /*x*/ offsetX,
+          /*y*/ offsetY,
+          /*elem_size_in_bits*/ elemSizeInBits,
+          /*tile_width*/ tileWidthInElem,
+          /*tile_height*/ tileHeightInElem,
+          /*v_blocks*/ vBlocks,
+          /*cache_opt*/ TritonGEN::LoadCacheControl::L1C_L3C);
+      if (failed(newOp.verify())) {
+        // delete the op so that the verifier will not abort the pass
+        // pipeline later, as we can fail this path and try a different
+        // approach.
+        rewriter.eraseOp(newOp);
+        return failure();
+      }
     }
 
@@ -1050,6 +1027,58 @@ struct PrefetchOpConversion
     rewriter.eraseOp(op);
     return success();
   }
+
+private:
+  // The tensor shape has to be in row-major order.
+  // Returns:
+  //   the prefetch tile shape per op in {M, N},
+  //   followed by the warps per CTA in {M, N}.
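+  //
+  // For example, a 256x64 fp16 tensor with 8 warps gives
+  // numColsPerPrefOps = min(64, 64/2) = 32, repNumN = 2, warpsNumN = 2,
+  // warpsNumM = 4, and 256/4 = 64 rows per warp clamped to 32, so the
+  // result is {32, 32, 4, 2}.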
+  std::tuple<unsigned, unsigned, unsigned, unsigned>
+  get2DPrefetchWarpsPerCTA(const ArrayRef<int64_t> tensorShape, Type eltTy,
+                           unsigned numWarps) const {
+    unsigned rank = tensorShape.size();
+    assert(rank >= 2 && "Only rank >= 2 tensor is supported for now");
+    unsigned dimM = rank - 2, dimN = rank - 1;
+    unsigned elemSizeInBits = eltTy.getIntOrFloatBitWidth();
+    unsigned elemSizeInBytes = elemSizeInBits / 8;
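+    // A single 2D block prefetch message covers at most 64 bytes per row, so
+    // cap the number of columns handled by one prefetch op accordingly.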
+    constexpr unsigned maxBytesPerRow = 64;
+    unsigned numColsPerPrefOps =
+        std::min<unsigned>(tensorShape[dimN], maxBytesPerRow / elemSizeInBytes);
+
+    unsigned repNumN =
+        mlir::ceil((unsigned)tensorShape[dimN], numColsPerPrefOps);
+    unsigned warpsNumN = std::min(numWarps, repNumN);
+    unsigned warpsNumM = mlir::ceil(numWarps, warpsNumN);
+
+    // Split the rows across the warps along M so the tensor is covered
+    // without duplicated prefetches, and clamp to the 32-row tile height.
+    unsigned rowNumPerWarp = mlir::ceil<unsigned>(tensorShape[dimM], warpsNumM);
+    unsigned numRowsPerPrefOps = std::min<unsigned>(rowNumPerWarp, 32);
+    SmallVector<unsigned, 2> tilePerPrefOps{numRowsPerPrefOps,
+                                            numColsPerPrefOps};
+
+    return {numRowsPerPrefOps, numColsPerPrefOps, warpsNumM, warpsNumN};
+  }
+
+  // Get the linear layout for the cooperative prefetching.
+  LinearLayout getLinearLayout(const ArrayRef<int64_t> tensorShape,
+                               const ArrayRef<unsigned> tileShape,
+                               const ArrayRef<unsigned> warpsPerCTA) const {
+    MLIRContext *ctx = getContext();
+    unsigned rank = warpsPerCTA.size();
+    assert(rank >= 2 && "Only rank >= 2 tensor is supported for now");
+    SmallVector<unsigned> order(rank);
+    for (size_t i = 0; i < warpsPerCTA.size(); ++i) {
+      // The fastest-changing dim comes first.
+      order[i] = rank - i - 1;
+    }
+    LinearLayout ctaLayout = identityStandardND(S("offset"), tileShape, order) *
+                             identityStandardND(S("warp"), warpsPerCTA, order);
+
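+    // combineCtaCgaWithShape repeats the per-CTA (tile x warp) layout as many
+    // times as needed to cover the full tensorShape; the CGA ("block")
+    // dimension is trivial here.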
+    return combineCtaCgaWithShape(std::move(ctaLayout),
+                                  CTALayoutAttr::getDefault(ctx, rank),
+                                  tensorShape);
+  }
 };
 
 struct LoadOpToBlockIOConversion