Skip to content

Commit f9885fb

Browse files
[LoadStoreOpToLLVM] Improve StoreOpToBlockIOConversion (#5072)
This PR improves the StoreOpToBlockIOConversion implementation by adding validation logic and reorganizing code structure. The changes focus on enhancing error handling and code readability while maintaining the same core functionality.

- Adds validation method isValid() to BlockIOTileSizeInfo struct
- Reorganizes variable declarations to improve code flow and reduce scope
- Improves mask constancy validation with power-of-2 checks

Signed-off-by: Whitney Tsang <[email protected]>
1 parent 07ac3a4 commit f9885fb

File tree

1 file changed

+31
-28
lines changed

1 file changed

+31
-28
lines changed

third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 31 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -424,6 +424,11 @@ struct BlockIOConversionBase : public LoadStoreConversionBase {
424424
int rowDim;
425425
int colDim;
426426
std::optional<SetVector<unsigned>> regPackedBases;
427+
428+
bool isValid() const {
429+
return tileHeight >= 0 && tileWidth >= 0 && numElemPerPackedVal >= 0 &&
430+
vBlocks >= 0 && rowDim >= 0 && colDim >= 0;
431+
}
427432
};
428433

429434
// Return the tileHeight, tileWidth, numElemPerPackedVal, vBlocks, row Dim and
@@ -1870,7 +1875,8 @@ struct LoadOpToBlockIOConversion
18701875
std::optional<LinearLayout> llEncoding =
18711876
cast<DistributedEncodingTrait>(encoding).toLinearLayout(
18721877
tensorType.getShape());
1873-
assert(llEncoding.has_value() && "invalid dot layout to linear layout");
1878+
assert(llEncoding.has_value() &&
1879+
"unexpected failure when getting linear layout");
18741880
auto llAttr = LinearEncodingAttr::get(rewriter.getContext(), *llEncoding);
18751881
SmallVector<unsigned> threadOrder(llAttr.getThreadOrder());
18761882
size_t rank = threadOrder.size();
@@ -2720,36 +2726,31 @@ struct StoreOpToBlockIOConversion
27202726
if (!isBlockIOCandidate(op, enableBlockIOForAllLayout))
27212727
return failure();
27222728

2723-
Location loc = op.getLoc();
2724-
auto b = TritonLLVMOpBuilder(loc, rewriter);
2725-
Type resultType = op.getValue().getType();
2726-
auto tensorType = cast<RankedTensorType>(resultType);
2727-
MLIRContext *ctx = rewriter.getContext();
2728-
27292729
// Get the max tile shape supported by the layout.
2730+
auto tensorType = cast<RankedTensorType>(op.getValue().getType());
27302731
Attribute encoding = tensorType.getEncoding();
27312732
std::optional<LinearLayout> llEncoding =
27322733
cast<DistributedEncodingTrait>(encoding).toLinearLayout(
27332734
tensorType.getShape());
27342735
assert(llEncoding.has_value() &&
27352736
"unexpected failure when getting linear layout");
27362737

2737-
auto [tileHeight, tileWidth, numPackedVals, vBlocks, rowDim, colDim,
2738-
regPackedBases] =
2739-
getBlockIOTileSize<8 /*MAX_TILE_HEIGHT*/>(*llEncoding);
2740-
// Limit vBlock to 1
2741-
vBlocks = 1;
2742-
// no valid tile shape for 2D block IO.
2743-
if (colDim < 0)
2738+
BlockIOTileSizeInfo sizeInfo =
2739+
getBlockIOTileSize<8 /*MAX_TILE_HEIGHT*/>(llEncoding.value());
2740+
if (!sizeInfo.isValid())
27442741
return failure();
2742+
auto [tileHeight, tileWidth, numPackedVals, vBlocks, rowDim, colDim,
2743+
regPackedBases] = sizeInfo;
27452744

27462745
Type eltTy = getTypeConverter()->convertType(tensorType.getElementType());
27472746
unsigned elemSizeInBits = eltTy.getIntOrFloatBitWidth();
27482747
unsigned packedElemSizeInBits = elemSizeInBits * numPackedVals;
2749-
unsigned numElems = getTotalElemsPerThread(tensorType);
27502748
if (!check2DBlockAddressPayloadRestriction(packedElemSizeInBits, tileWidth))
27512749
return failure();
27522750

2751+
// Limit vBlock to 1
2752+
vBlocks = 1;
2753+
27532754
// TODO: use the axis info to general the handling for both regular pointer
27542755
// and block pointer.
27552756
const bool memoryRowMajor = isMemoryRowMajor(op);
@@ -2759,6 +2760,9 @@ struct StoreOpToBlockIOConversion
27592760
return failure();
27602761
}
27612762

2763+
Location loc = op.getLoc();
2764+
auto b = TritonLLVMOpBuilder(loc, rewriter);
2765+
MLIRContext *ctx = rewriter.getContext();
27622766
Value warpId = rewriter.create<arith::IndexCastOp>(
27632767
loc, i32_ty,
27642768
rewriter.create<mlir::gpu::SubgroupIdOp>(loc,
@@ -2770,6 +2774,7 @@ struct StoreOpToBlockIOConversion
27702774
Value baseWidth, baseHeight, pitch, offsetBaseX, offsetBaseY;
27712775

27722776
Value ptr = op.getPtr();
2777+
unsigned numElems = getTotalElemsPerThread(tensorType);
27732778
bool isBlockPointer = isTensorPointerType(ptr.getType());
27742779
if (isBlockPointer) {
27752780
auto [base, width, height, rowStride, colStride, offsetX, offsetY] =
@@ -2794,7 +2799,6 @@ struct StoreOpToBlockIOConversion
27942799
"the number of pointer values is not matched with the number of "
27952800
"elements");
27962801

2797-
unsigned maskConstancyHor = 1, maskConstancyVer = 1;
27982802
Value llMask = adaptor.getMask();
27992803
// Get the LLVM values for mask
28002804
if (llMask) {
@@ -2806,23 +2810,22 @@ struct StoreOpToBlockIOConversion
28062810
auto axisInfo = const_cast<triton::intel::ModuleAxisInfoAnalysis &>(
28072811
axisAnalysisPass)
28082812
.getAxisInfo(mask);
2813+
unsigned maskConstancyHor = 1, maskConstancyVer = 1;
28092814
if (axisInfo) {
28102815
maskConstancyHor = axisInfo->getConstancy(colDim);
28112816
maskConstancyVer = axisInfo->getConstancy(rowDim);
2812-
} else {
2813-
maskConstancyHor = 1;
2814-
maskConstancyVer = 1;
2817+
// The mask constancy has to be power of 2 for block IO.
2818+
if (!llvm::isPowerOf2_64(maskConstancyHor) ||
2819+
!llvm::isPowerOf2_64(maskConstancyVer))
2820+
return failure();
28152821
}
2816-
} else {
2817-
// no mask
2818-
maskConstancyHor = std::numeric_limits<unsigned>::max();
2819-
maskConstancyVer = std::numeric_limits<unsigned>::max();
2820-
}
28212822

2822-
// Check the constancy of the mask support to load the memory in 2D block.
2823-
if (!(maskConstancyHor >= (tileWidth * numPackedVals) &&
2824-
maskConstancyVer >= tileHeight))
2825-
return failure();
2823+
// Check the constancy of the mask support to load the memory in 2D
2824+
// block.
2825+
if (!(maskConstancyHor >= (tileWidth * numPackedVals) &&
2826+
maskConstancyVer >= tileHeight))
2827+
return failure();
2828+
}
28262829

28272830
baseWidth = b.i32_val(
28282831
std::max(64u, vBlocks * tileWidth * (packedElemSizeInBits / 8)));

0 commit comments

Comments
 (0)