Skip to content

Commit 1f6ac21

Browse files
Fix Flex Attn UT failures (#5138)
Before 92eb442, `getStride(ptr, 0)` would be 1 for column major, i.e, pitch would be smaller than 64 bytes (HW restriction), so `getPitch` would return `nullptr`, and 2d block load would not be generated. This PR goes back to the original behavior and early return for column major. Flex Attn UT: https://github.com/intel/intel-xpu-backend-for-triton/actions/runs/17833144533/job/50703282622 (GOOD) Flex Attn benchmark: https://github.com/intel/intel-xpu-backend-for-triton/actions/runs/17836148975/job/50713713366 (GOOD) Fixes #5129 --------- Signed-off-by: Whitney Tsang <[email protected]>
1 parent 7dfe5d8 commit 1f6ac21

File tree

1 file changed

+8
-9
lines changed

1 file changed

+8
-9
lines changed

third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1867,9 +1867,12 @@ struct LoadOpToBlockIOConversion
18671867
Type resultType = op.getType();
18681868
auto tensorType = cast<RankedTensorType>(resultType);
18691869

1870-
// Step 1: Right now we only support 2D rank matrix of row major or column
1871-
// major.
1870+
// Step 1: Right now we only support 2D rank matrix of row major.
18721871
const bool memoryRowMajor = isMemoryRowMajor(op);
1872+
// FIXME: Add support of column major.
1873+
if (!memoryRowMajor)
1874+
return failure();
1875+
18731876
DpasEncodingAttr::OpIdx opIdx = getOpIdx(tensorType);
18741877

18751878
Attribute encoding = tensorType.getEncoding();
@@ -1912,12 +1915,8 @@ struct LoadOpToBlockIOConversion
19121915
// only support rank of 2 for now.
19131916
return failure();
19141917
}
1915-
const bool valueRowMajor =
1916-
(threadOrder[rank - 2] == 1 && threadOrder[rank - 1] == 0);
1917-
assert((valueRowMajor ||
1918-
(threadOrder[rank - 2] == 0 && threadOrder[rank - 1] == 1)) &&
1919-
"Only row_major or column_major is allowed");
1920-
const bool isTransposeRequired = valueRowMajor ^ memoryRowMajor;
1918+
unsigned contiguousDim = memoryRowMajor ? 1 : 0;
1919+
const bool isTransposeRequired = contiguousDim != colDim;
19211920

19221921
// Step 2: Right now we only support DPAS related layout to simplify the
19231922
// lowering.
@@ -2246,7 +2245,7 @@ struct LoadOpToBlockIOConversion
22462245
return failure();
22472246

22482247
// If the stride is 0, we want to load only the first row.
2249-
int stride = getStride(ptr, 0);
2248+
int stride = getStride(ptr, memoryRowMajor ? 0 : 1);
22502249
unsigned baseHeightInt = (stride == 0 ? 1 : tileHeight);
22512250
Value baseHeight = b.i32_val(baseHeightInt);
22522251
Value baseWidth =

0 commit comments

Comments
 (0)