Skip to content

Commit ccd14a9

Browse files
[GEMM] Allow 2D block io when M == 1 (#4540)
When M is 1, every element of `offs_am` is 0 (any index modulo 1 is 0), so `a_ptrs` has stride `[0, 1]`.

```
offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
```

This PR changes the MaterializeBlockPointer pass to allow a stride of 0. This further improves GEMM tensor-of-pointer performance by 8%.

![Screenshot 2025-06-19 102705](https://github.com/user-attachments/assets/d5197a85-a188-4e1f-8db6-9549bc566564)

CI: https://github.com/intel/intel-xpu-backend-for-triton/actions/runs/15748650337

Signed-off-by: Whitney Tsang <[email protected]>
1 parent 5bd0ff8 commit ccd14a9

File tree

3 files changed

+33
-8
lines changed

3 files changed

+33
-8
lines changed

test/TritonIntelGPU/materialize-block-pointer.mlir

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,3 +135,20 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
135135
tt.return
136136
}
137137
}
138+
139+
// -----
140+
141+
// COM: Ensure pointer with stride [0, 1] is considered as row major.
142+
#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 4], warpsPerCTA = [32, 1], order = [1, 0]}>
143+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, "ttg.threads-per-warp" = 16 : i32, ttig.support_sg_2d_block} {
144+
tt.func public @tensor_of_ptr(%arg0: !tt.ptr<bf16> {tt.divisibility = 16 : i32}) {
145+
%18 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
146+
%19 = tt.expand_dims %18 {axis = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x32xi32, #blocked>
147+
%20 = tt.splat %arg0 : !tt.ptr<bf16> -> tensor<1x32x!tt.ptr<bf16>, #blocked>
148+
%21 = tt.addptr %20, %19 : tensor<1x32x!tt.ptr<bf16>, #blocked>, tensor<1x32xi32, #blocked>
149+
%22 = tt.broadcast %21 : tensor<1x32x!tt.ptr<bf16>, #blocked> -> tensor<256x32x!tt.ptr<bf16>, #blocked>
150+
// CHECK: tt.load {{.*}} {ttig.block_io = "row_major"}
151+
%50 = tt.load %22 : tensor<256x32x!tt.ptr<bf16>, #blocked>
152+
tt.return
153+
}
154+
}

third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ struct LoadStoreConversionBase {
139139
const triton::intel::ModuleAxisInfoAnalysis &axisAnalysisPass)
140140
: targetInfo(targetInfo), axisAnalysisPass(axisAnalysisPass) {}
141141

142-
unsigned getStride(Value ptr, unsigned dim) const {
142+
int getStride(Value ptr, unsigned dim) const {
143143
AxisInfo *axisInfo =
144144
const_cast<triton::intel::ModuleAxisInfoAnalysis &>(axisAnalysisPass)
145145
.getAxisInfo(ptr);
@@ -349,8 +349,12 @@ struct BlockIOConversionBase : public LoadStoreConversionBase {
349349
Location loc = ptr.getLoc();
350350
auto b = TritonLLVMOpBuilder(loc, rewriter);
351351

352-
unsigned stride = getStride(ptr, 0);
353-
if (stride != -1)
352+
int stride = getStride(ptr, 0);
353+
// If the stride is 0, we assume a minimum pitch of 64 bytes.
354+
constexpr int MIN_PITCH = 64;
355+
if (stride == 0)
356+
return b.i32_val(MIN_PITCH);
357+
else if (stride != -1)
354358
return b.i32_val(stride * elemSizeInBits / 8);
355359

356360
// ptrs[{0, 0}] and ptrs[{1, 0}] are currently used to calculate the
@@ -685,7 +689,9 @@ struct PrefetchOpConversion
685689
if (!rowStrideInBytes)
686690
return failure();
687691

688-
Value baseHeight = b.i32_val(tileHeightInElem);
692+
// If the stride is 0, we want to load only the first row.
693+
int stride = getStride(op.getPtr(), 0);
694+
Value baseHeight = b.i32_val(stride == 0 ? 1 : tileHeightInElem);
689695
Value offsetBaseX = b.i32_val(0);
690696
Value offsetBaseY = b.i32_val(0);
691697

@@ -1140,7 +1146,10 @@ struct LoadOpToBlockIOConversion
11401146
if (!pitch)
11411147
return failure();
11421148

1143-
Value baseHeight = b.i32_val(tileHeight);
1149+
// If the stride is 0, we want to load only the first row.
1150+
int stride = getStride(ptr, 0);
1151+
Value baseHeight = b.i32_val(stride == 0 ? 1 : tileHeight);
1152+
11441153
StringAttr kRegister = str_attr("register");
11451154
StringAttr kLane = str_attr("lane");
11461155
StringAttr kWarp = str_attr("warp");

third_party/intel/lib/TritonIntelGPUTransforms/MaterializeBlockPointer.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -186,9 +186,8 @@ struct TritonIntelGPUMaterializeBlockPointerPass
186186
}
187187

188188
// Value -1 is used to represent the unknown stride.
189-
if (axisInfo->getStride(otherDim) <= 0) {
190-
LDBG("Found unknown or non positive stride: "
191-
<< axisInfo->getStride(otherDim));
189+
if (axisInfo->getStride(otherDim) < 0) {
190+
LDBG("Found unknown stride: " << axisInfo->getStride(otherDim));
192191
return false;
193192
}
194193

0 commit comments

Comments
 (0)