Commit 0fa61fd
[LoadOpToBlockIOConversion] Fix codegen when baseHeight < tileHeight (#5264)
This PR fixes a code generation issue in LoadOpToBlockIOConversion when baseHeight < tileHeight. The fix adds bounds checking with umin operations so that out-of-bounds accesses no longer yield spurious zero values.

Fixes #5250, pytorch/helion#795

Signed-off-by: Whitney Tsang <[email protected]>
1 parent 5d61262 commit 0fa61fd
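To illustrate the clamping described in the commit message, here is a minimal sketch, not the actual lowering code: it mimics how taking a umin against baseHeight - 1 keeps the row offset of each 2D block load in bounds when baseHeight < tileHeight. The snake_case names (warp_offset, replica_stride, tile_height, base_height) are illustrative placeholders modeled on the identifiers in the patch.

def clamped_row_offset(warp_offset, m, replica_stride, rep_m, tile_height, base_height):
    # Unclamped offset, mirroring warpId0Offset + m * replicaStride[0] + repM * tileHeight.
    raw = warp_offset + m * replica_stride + rep_m * tile_height
    # Clamp to the last valid row, mirroring b.umin(b.sub(baseHeight, b.i32_val(1)), ...).
    return min(base_height - 1, raw)

# Example: an 8-row surface loaded with 16-row tiles; the second replica would
# otherwise address row 16, past the end of the 8-row surface.
print(clamped_row_offset(0, 0, 16, 0, 16, 8))  # 0
print(clamped_row_offset(0, 0, 16, 1, 16, 8))  # 7 instead of 16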

File tree

2 files changed: +106 -3 lines changed

Lines changed: 100 additions & 0 deletions
@@ -0,0 +1,100 @@
from __future__ import annotations

import torch
import triton
import triton.language as tl
from torch import Tensor
from torch._inductor.runtime import triton_helpers
from typing import Callable
from helion.runtime import default_launcher as _default_launcher

DEVICE = 'xpu'


@triton.jit
def _helion_matmul(x, y, epilogue_closure_0, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr,
                   _BLOCK_SIZE_2: tl.constexpr):
    num_pid_m = tl.cdiv(1024, _BLOCK_SIZE_0)
    num_pid_n = tl.cdiv(1024, _BLOCK_SIZE_1)
    inner_2d_pid = tl.program_id(0)
    num_pid_in_group = 64 * num_pid_n
    group_id = inner_2d_pid // num_pid_in_group
    first_pid_m = group_id * 64
    group_size_m = min(num_pid_m - first_pid_m, 64)
    pid_0 = first_pid_m + inner_2d_pid % num_pid_in_group % group_size_m
    pid_1 = inner_2d_pid % num_pid_in_group // group_size_m
    offset_0 = pid_0 * _BLOCK_SIZE_0
    offset_1 = pid_1 * _BLOCK_SIZE_1
    acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32)
    for offset_2 in tl.range(0, 1024, _BLOCK_SIZE_2):
        acc_copy = acc
        load = tl.load(
            tl.make_block_ptr(x, [1024, 1024], [1024, 1], [offset_0, offset_2], [_BLOCK_SIZE_0, _BLOCK_SIZE_2], [1, 0]),
            boundary_check=[0, 1], padding_option='zero')
        load_1 = tl.load(
            tl.make_block_ptr(y, [1024, 1024], [1024, 1], [offset_2, offset_1], [_BLOCK_SIZE_2, _BLOCK_SIZE_1], [1, 0]),
            boundary_check=[0, 1], padding_option='zero')
        acc = tl.dot(tl.cast(load, tl.float16), tl.cast(load_1, tl.float16), acc=acc_copy, input_precision='tf32',
                     out_dtype=tl.float32)
    load_2 = tl.load(
        tl.make_block_ptr(epilogue_closure_0, [1, 1024], [1024, 1], [0, offset_1], [1, _BLOCK_SIZE_1], [1, 0]))
    v_0 = tl.cast(load_2, tl.float32)
    v_1 = acc + v_0
    v_2 = tl.full([], 0, tl.int32)
    v_3 = triton_helpers.maximum(v_2, v_1)
    v_4 = tl.cast(v_3, tl.float16)
    tl.store(
        tl.make_block_ptr(out, [1024, 1024], [1024, 1], [offset_0, offset_1], [_BLOCK_SIZE_0, _BLOCK_SIZE_1], [1, 0]),
        v_4, boundary_check=[0, 1])


def matmul(x, y, epilogue: Callable[[Tensor, tuple[Tensor, ...]], Tensor] = lambda acc, tile: acc, *,
           _launcher=_default_launcher):
    """
    Performs matrix multiplication of x and y with an optional epilogue function.
    Args:
        x (Tensor): Left matrix of shape [m, k].
        y (Tensor): Right matrix of shape [k, n].
        epilogue (Callable, optional): Function applied to the accumulator and tile indices
            after the matmul. Defaults to identity (no change).
    Returns:
        Tensor: Resulting matrix of shape [m, n].
    """
    m, k = x.size()
    k2, n = y.size()
    assert k == k2, f'size mismatch {k} != {k2}'
    out = torch.empty([m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device)
    _BLOCK_SIZE_0 = 64
    _BLOCK_SIZE_1 = 64
    _BLOCK_SIZE_2 = 16
    _launcher(_helion_matmul, (triton.cdiv(1024, _BLOCK_SIZE_0) * triton.cdiv(1024, _BLOCK_SIZE_1), ), x, y,
              epilogue.__closure__[0].cell_contents, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=2,
              num_stages=4)
    return out


bias = torch.ones([1, 1024], device=DEVICE, dtype=torch.float16)
args = (
    torch.ones([1024, 1024], device=DEVICE, dtype=torch.float16),
    torch.ones([1024, 1024], device=DEVICE, dtype=torch.float16),
    lambda acc, tile: torch.relu(acc + bias[tile]),
)

bias.fill_(0.7)
args[0].fill_(0.1)
args[1].fill_(0.2)


def make_epilogue(bias):

    def epilogue(acc, tile):
        return acc + bias[tile[0], tile[1]]

    return epilogue


epilogue = make_epilogue(bias)

out = matmul(args[0], args[1], epilogue)
torch.xpu.synchronize()
torch.testing.assert_close(out, torch.relu(args[0] @ args[1] + bias))

third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 6 additions & 3 deletions
@@ -1158,8 +1158,9 @@ struct LoadOpToBlockIOConversion
     for (int repM = 0; repM < repCluster[0]; ++repM) {

       Value offsetY =
-          b.add(warpId0Offset,
-                b.i32_val(m * replicaStride[0] + repM * tileHeight));
+          b.umin(b.sub(baseHeight, b.i32_val(1)),
+                 b.add(warpId0Offset, b.i32_val(m * replicaStride[0] +
+                                                repM * tileHeight)));
       for (int repN = 0; repN < repCluster[1]; ++repN) {
         Value offsetX =
             b.add(warpId1Offset,
@@ -1191,7 +1192,9 @@ struct LoadOpToBlockIOConversion
             b.bitcast(load2dOp, LLVM::getVectorType(eltTy, elemsPerLane));

         for (size_t i = 0; i < elemsPerLane; ++i) {
-          Value loaded = b.extract_element(eltTy, ret, b.i32_val(i));
+          Value loaded = b.extract_element(
+              eltTy, ret,
+              b.umin(b.sub(baseHeight, b.i32_val(1)), b.i32_val(i)));
           unpackedLoadedVals.push_back(loaded);
         }
       }
