@@ -424,6 +424,11 @@ struct BlockIOConversionBase : public LoadStoreConversionBase {
   int rowDim;
   int colDim;
   std::optional<SetVector<unsigned>> regPackedBases;
+
+  bool isValid() const {
+    return tileHeight >= 0 && tileWidth >= 0 && numElemPerPackedVal >= 0 &&
+           vBlocks >= 0 && rowDim >= 0 && colDim >= 0;
+  }
 };

 // Return the tileHeight, tileWidth, numElemPerPackedVal, vBlocks, rowDim and
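The new `isValid()` helper makes the sentinel-field convention of `BlockIOTileSizeInfo` explicit: `getBlockIOTileSize` reports an unusable layout by leaving dimensions negative, and callers can now reject a partially populated result on any field instead of probing a single one (previously only `colDim < 0` was checked). A minimal standalone sketch of the pattern, with hypothetical `-1` defaults (the real struct's initializers are not shown in this diff):

```cpp
#include <cstdio>

// Stand-in for the struct in the diff; regPackedBases omitted for brevity.
struct BlockIOTileSizeInfo {
  int tileHeight = -1; // -1 is assumed here to mean "not determined"
  int tileWidth = -1;
  int numElemPerPackedVal = -1;
  int vBlocks = -1;
  int rowDim = -1;
  int colDim = -1;

  bool isValid() const {
    return tileHeight >= 0 && tileWidth >= 0 && numElemPerPackedVal >= 0 &&
           vBlocks >= 0 && rowDim >= 0 && colDim >= 0;
  }
};

int main() {
  BlockIOTileSizeInfo info; // as if getBlockIOTileSize found no tile shape
  if (!info.isValid()) {
    std::puts("no valid tile shape for 2D block IO");
    return 0;
  }
  std::printf("tile %d x %d\n", info.tileHeight, info.tileWidth);
}
```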
@@ -1870,7 +1875,8 @@ struct LoadOpToBlockIOConversion
     std::optional<LinearLayout> llEncoding =
         cast<DistributedEncodingTrait>(encoding).toLinearLayout(
             tensorType.getShape());
-    assert(llEncoding.has_value() && "invalid dot layout to linear layout");
+    assert(llEncoding.has_value() &&
+           "unexpected failure when getting linear layout");
     auto llAttr = LinearEncodingAttr::get(rewriter.getContext(), *llEncoding);
     SmallVector<unsigned> threadOrder(llAttr.getThreadOrder());
     size_t rank = threadOrder.size();
@@ -2720,36 +2726,31 @@ struct StoreOpToBlockIOConversion
     if (!isBlockIOCandidate(op, enableBlockIOForAllLayout))
       return failure();

-    Location loc = op.getLoc();
-    auto b = TritonLLVMOpBuilder(loc, rewriter);
-    Type resultType = op.getValue().getType();
-    auto tensorType = cast<RankedTensorType>(resultType);
-    MLIRContext *ctx = rewriter.getContext();
-
     // Get the max tile shape supported by the layout.
+    auto tensorType = cast<RankedTensorType>(op.getValue().getType());
     Attribute encoding = tensorType.getEncoding();
     std::optional<LinearLayout> llEncoding =
         cast<DistributedEncodingTrait>(encoding).toLinearLayout(
             tensorType.getShape());
     assert(llEncoding.has_value() &&
            "unexpected failure when getting linear layout");

-    auto [tileHeight, tileWidth, numPackedVals, vBlocks, rowDim, colDim,
-          regPackedBases] =
-        getBlockIOTileSize<8 /*MAX_TILE_HEIGHT*/>(*llEncoding);
-    // Limit vBlocks to 1.
-    vBlocks = 1;
-    // No valid tile shape for 2D block IO.
-    if (colDim < 0)
+    BlockIOTileSizeInfo sizeInfo =
+        getBlockIOTileSize<8 /*MAX_TILE_HEIGHT*/>(llEncoding.value());
+    if (!sizeInfo.isValid())
       return failure();
+    auto [tileHeight, tileWidth, numPackedVals, vBlocks, rowDim, colDim,
+          regPackedBases] = sizeInfo;

     Type eltTy = getTypeConverter()->convertType(tensorType.getElementType());
     unsigned elemSizeInBits = eltTy.getIntOrFloatBitWidth();
     unsigned packedElemSizeInBits = elemSizeInBits * numPackedVals;
-    unsigned numElems = getTotalElemsPerThread(tensorType);
     if (!check2DBlockAddressPayloadRestriction(packedElemSizeInBits, tileWidth))
       return failure();

+    // Limit vBlocks to 1.
+    vBlocks = 1;
+
     // TODO: use the axis info to generalize the handling for both regular
     // pointer and block pointer.
     const bool memoryRowMajor = isMemoryRowMajor(op);
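One detail worth noting in the rewrite above: the result is now bound to a named `sizeInfo` object so that `isValid()` can run before C++17 structured bindings decompose it (the previous code unpacked first and then tested the lone `colDim` binding). Once decomposed, the bindings can no longer be validated as a group. A self-contained illustration with an invented two-field struct:

```cpp
#include <optional>

struct TileSize {
  int height = -1, width = -1; // -1: not determined (assumed convention)
  bool isValid() const { return height >= 0 && width >= 0; }
};

// Illustrative producer: on failure, returns the sentinel-filled default.
TileSize computeTile(bool layoutSupported) {
  return layoutSupported ? TileSize{8, 16} : TileSize{};
}

std::optional<int> tileArea(bool layoutSupported) {
  TileSize info = computeTile(layoutSupported);
  if (!info.isValid())         // validate the whole result first...
    return std::nullopt;
  auto [height, width] = info; // ...then unpack via structured bindings
  return height * width;
}
```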
@@ -2759,6 +2760,9 @@ struct StoreOpToBlockIOConversion
       return failure();
     }

+    Location loc = op.getLoc();
+    auto b = TritonLLVMOpBuilder(loc, rewriter);
+    MLIRContext *ctx = rewriter.getContext();
     Value warpId = rewriter.create<arith::IndexCastOp>(
         loc, i32_ty,
         rewriter.create<mlir::gpu::SubgroupIdOp>(loc,
@@ -2770,6 +2774,7 @@ struct StoreOpToBlockIOConversion
     Value baseWidth, baseHeight, pitch, offsetBaseX, offsetBaseY;

     Value ptr = op.getPtr();
+    unsigned numElems = getTotalElemsPerThread(tensorType);
     bool isBlockPointer = isTensorPointerType(ptr.getType());
     if (isBlockPointer) {
       auto [base, width, height, rowStride, colStride, offsetX, offsetY] =
@@ -2794,7 +2799,6 @@ struct StoreOpToBlockIOConversion
            "the number of pointer values is not matched with the number of "
            "elements");

-    unsigned maskConstancyHor = 1, maskConstancyVer = 1;
     Value llMask = adaptor.getMask();
     // Get the LLVM values for mask
     if (llMask) {
@@ -2806,23 +2810,22 @@ struct StoreOpToBlockIOConversion
       auto axisInfo = const_cast<triton::intel::ModuleAxisInfoAnalysis &>(
                           axisAnalysisPass)
                           .getAxisInfo(mask);
+      unsigned maskConstancyHor = 1, maskConstancyVer = 1;
       if (axisInfo) {
         maskConstancyHor = axisInfo->getConstancy(colDim);
         maskConstancyVer = axisInfo->getConstancy(rowDim);
-      } else {
-        maskConstancyHor = 1;
-        maskConstancyVer = 1;
+        // The mask constancy has to be a power of 2 for block IO.
+        if (!llvm::isPowerOf2_64(maskConstancyHor) ||
+            !llvm::isPowerOf2_64(maskConstancyVer))
+          return failure();
       }
-    } else {
-      // no mask
-      maskConstancyHor = std::numeric_limits<unsigned>::max();
-      maskConstancyVer = std::numeric_limits<unsigned>::max();
-    }

-    // Check the constancy of the mask support to load the memory in 2D block.
-    if (!(maskConstancyHor >= (tileWidth * numPackedVals) &&
-          maskConstancyVer >= tileHeight))
-      return failure();
+      // Check that the mask constancy is sufficient to cover the 2D block
+      // being stored.
+      if (!(maskConstancyHor >= (tileWidth * numPackedVals) &&
+            maskConstancyVer >= tileHeight))
+        return failure();
+    }

     baseWidth = b.i32_val(
         std::max(64u, vBlocks * tileWidth * (packedElemSizeInBits / 8)));
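As a worked example of the `baseWidth` computation just above (tile numbers assumed for illustration): with `vBlocks = 1`, `tileWidth = 16`, and 32-bit packed elements, the width is `1 * 16 * (32 / 8) = 64` bytes, exactly the floor that `std::max(64u, ...)` enforces; a 16-bit element with `numPackedVals = 1` would yield `1 * 16 * 2 = 32` bytes and be clamped up to 64.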
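Stepping back to the mask handling above, the new logic imposes two requirements on the mask's axis-info constancy: each dimension's constancy must be a power of two (`llvm::isPowerOf2_64`, from `llvm/Support/MathExtras.h`, performs exactly this test), and it must cover at least one tile in each dimension. A standalone sketch of the combined predicate (tile parameters below are made-up illustrative values):

```cpp
#include <cstdint>
#include <cstdio>

// Same test as llvm::isPowerOf2_64: nonzero with exactly one bit set.
static bool isPowerOf2(uint64_t v) { return v != 0 && (v & (v - 1)) == 0; }

// True when a masked 2D block store is legal under the diff's rules.
static bool maskAllowsBlockIO(unsigned constancyHor, unsigned constancyVer,
                              unsigned tileWidth, unsigned numPackedVals,
                              unsigned tileHeight) {
  if (!isPowerOf2(constancyHor) || !isPowerOf2(constancyVer))
    return false; // block IO requires power-of-2 constancy
  // The constant region must span a whole tile in each dimension.
  return constancyHor >= tileWidth * numPackedVals &&
         constancyVer >= tileHeight;
}

int main() {
  // Assume tileWidth = 16, numPackedVals = 2, tileHeight = 8.
  std::printf("%d\n", maskAllowsBlockIO(32, 8, 16, 2, 8)); // 1: legal
  std::printf("%d\n", maskAllowsBlockIO(24, 8, 16, 2, 8)); // 0: 24 not pow2
  std::printf("%d\n", maskAllowsBlockIO(16, 8, 16, 2, 8)); // 0: 16 < 32
}
```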