Skip to content

Commit 542a9a6

Browse files
[BlockIOConversion] Added function to check address payload restriction (#5047)
This PR refactors the 2D block address payload restriction checks by extracting the validation logic into a dedicated function. The change improves code maintainability by eliminating duplicate validation code across load and store operations. Signed-off-by: Whitney Tsang <[email protected]>
1 parent c412ccc commit 542a9a6

File tree

1 file changed

+33
-25
lines changed

1 file changed

+33
-25
lines changed

third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 33 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,35 @@ struct BlockIOConversionBase : public LoadStoreConversionBase {
316316
hasDotDpasEncoding(tensorTy);
317317
}
318318

319+
static bool
320+
check2DBlockAddressPayloadRestriction(unsigned packedElemSizeInBits,
321+
unsigned tileWidth) {
322+
// Return false if tile width is not supported by HW.
323+
// Note: Tile width is not changeable.
324+
switch (packedElemSizeInBits) {
325+
case 8:
326+
if (tileWidth < 4 || tileWidth > 64)
327+
return false;
328+
break;
329+
case 16:
330+
if (tileWidth < 2 || tileWidth > 32)
331+
return false;
332+
break;
333+
case 32:
334+
if (tileWidth > 16)
335+
return false;
336+
break;
337+
case 64:
338+
if (tileWidth > 8)
339+
return false;
340+
break;
341+
default:
342+
// invalid element type for 2D block io.
343+
return false;
344+
}
345+
return true;
346+
}
347+
319348
template <
320349
typename OpTy,
321350
std::enable_if_t<llvm::is_one_of<OpTy, triton::gpu::intel::PrefetchOp,
@@ -2182,6 +2211,9 @@ struct LoadOpToBlockIOConversion
21822211
if (!pitch)
21832212
return failure();
21842213

2214+
if (!check2DBlockAddressPayloadRestriction(elemSizeInBits, tileWidth))
2215+
return failure();
2216+
21852217
// If the stride is 0, we want to load only the first row.
21862218
int stride = getStride(ptr, 0);
21872219
unsigned baseHeightInt = (stride == 0 ? 1 : tileHeight);
@@ -2715,32 +2747,8 @@ struct StoreOpToBlockIOConversion
27152747
unsigned elemSizeInBits = eltTy.getIntOrFloatBitWidth();
27162748
unsigned packedElemSizeInBits = elemSizeInBits * numPackedVals;
27172749
unsigned numElems = getTotalElemsPerThread(tensorType);
2718-
// 2D block store supports 64 bits element at most.
2719-
if (packedElemSizeInBits > 64)
2750+
if (!check2DBlockAddressPayloadRestriction(packedElemSizeInBits, tileWidth))
27202751
return failure();
2721-
// Tile width is not changeable. Return failure if it is not supported by
2722-
// HW.
2723-
switch (packedElemSizeInBits) {
2724-
case 8:
2725-
if (tileWidth < 4 || tileWidth > 64)
2726-
return failure();
2727-
break;
2728-
case 16:
2729-
if (tileWidth < 2 || tileWidth > 32)
2730-
return failure();
2731-
break;
2732-
case 32:
2733-
if (tileWidth > 16)
2734-
return failure();
2735-
break;
2736-
case 64:
2737-
if (tileWidth > 8)
2738-
return failure();
2739-
break;
2740-
default:
2741-
// invalid element type for 2D block store.
2742-
return failure();
2743-
}
27442752

27452753
// TODO: use the axis info to general the handling for both regular pointer
27462754
// and block pointer.

0 commit comments

Comments
 (0)