Skip to content

Commit 058eb80

Browse files
authored
fix: fix an int32 overflow bug in destindex_copy_kv (#907)
1 parent 1009039 commit 058eb80

File tree

2 files changed: +65 / -29 lines changed

2 files changed: +65 / -29 lines changed

lightllm/common/basemodel/triton_kernel/destindex_copy_kv.py

Lines changed: 54 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -6,19 +6,24 @@
66

77
@triton.jit
88
def _fwd_kernel_destindex_copy_kv(
9-
K, Dest_loc,
9+
K,
10+
Dest_loc,
1011
Out,
11-
stride_k_bs, stride_k_h, stride_k_d,
12-
stride_o_bs, stride_o_h, stride_o_d,
12+
stride_k_bs,
13+
stride_k_h,
14+
stride_k_d,
15+
stride_o_bs,
16+
stride_o_h,
17+
stride_o_d,
1318
head_num,
1419
BLOCK_DMODEL: tl.constexpr,
15-
BLOCK_HEAD: tl.constexpr
20+
BLOCK_HEAD: tl.constexpr,
1621
):
1722
cur_index = tl.program_id(0)
1823
offs_h = tl.arange(0, BLOCK_HEAD)
1924
offs_d = tl.arange(0, BLOCK_DMODEL)
2025

21-
dest_index = tl.load(Dest_loc + cur_index)
26+
dest_index = tl.load(Dest_loc + cur_index).to(tl.int64)
2227

2328
k_ptrs = K + cur_index * stride_k_bs + stride_k_h * offs_h[:, None] + stride_k_d * offs_d[None, :]
2429
o_ptrs = Out + dest_index * stride_o_bs + stride_o_h * offs_h[:, None] + stride_o_d * offs_d[None, :]
@@ -39,9 +44,15 @@ def destindex_copy_kv(K, DestLoc, Out):
3944
num_warps = 1
4045

4146
_fwd_kernel_destindex_copy_kv[grid](
42-
K, DestLoc, Out,
43-
K.stride(0), K.stride(1), K.stride(2),
44-
Out.stride(0), Out.stride(1), Out.stride(2),
47+
K,
48+
DestLoc,
49+
Out,
50+
K.stride(0),
51+
K.stride(1),
52+
K.stride(2),
53+
Out.stride(0),
54+
Out.stride(1),
55+
Out.stride(2),
4556
head_num,
4657
BLOCK_DMODEL=head_dim,
4758
BLOCK_HEAD=BLOCK_HEAD,
@@ -53,23 +64,35 @@ def destindex_copy_kv(K, DestLoc, Out):
5364

5465
@triton.jit
5566
def _fwd_kernel_destindex_copy_quantize_kv(
56-
K, Dest_loc, Out, Out_scale,
57-
stride_k_bs, stride_k_h, stride_k_d,
58-
stride_o_bs, stride_o_h, stride_o_d,
59-
stride_os_bs, stride_os_h, stride_os_d,
67+
K,
68+
Dest_loc,
69+
Out,
70+
Out_scale,
71+
stride_k_bs,
72+
stride_k_h,
73+
stride_k_d,
74+
stride_o_bs,
75+
stride_o_h,
76+
stride_o_d,
77+
stride_os_bs,
78+
stride_os_h,
79+
stride_os_d,
6080
head_num,
6181
BLOCK_DMODEL: tl.constexpr,
62-
BLOCK_HEAD: tl.constexpr
82+
BLOCK_HEAD: tl.constexpr,
6383
):
6484
cur_index = tl.program_id(0)
6585
offs_h = tl.arange(0, BLOCK_HEAD)
6686
offs_d = tl.arange(0, BLOCK_DMODEL)
6787

68-
dest_index = tl.load(Dest_loc + cur_index)
69-
src_data = tl.load(K + cur_index * stride_k_bs + offs_h[:, None] * stride_k_h + stride_k_d * offs_d[None, :],
70-
mask=offs_h[:, None] < head_num, other=0.0)
88+
dest_index = tl.load(Dest_loc + cur_index).to(tl.int64)
89+
src_data = tl.load(
90+
K + cur_index * stride_k_bs + offs_h[:, None] * stride_k_h + stride_k_d * offs_d[None, :],
91+
mask=offs_h[:, None] < head_num,
92+
other=0.0,
93+
)
7194
abs_data = tl.abs(src_data)
72-
data_scale = (tl.max(abs_data, axis=1) / 127.).to(Out_scale.dtype.element_ty)[:, None]
95+
data_scale = (tl.max(abs_data, axis=1) / 127.0).to(Out_scale.dtype.element_ty)[:, None]
7396
q_src_data = (src_data / data_scale).to(tl.int8)
7497
o_ptrs = Out + dest_index * stride_o_bs + stride_o_h * offs_h[:, None] + stride_o_d * offs_d[None, :]
7598
os_ptrs = Out_scale + dest_index * stride_os_bs + stride_os_h * offs_h[:, None]
@@ -88,10 +111,19 @@ def destindex_copy_quantize_kv(K, DestLoc, Out, Out_scale):
88111
num_warps = 1
89112

90113
_fwd_kernel_destindex_copy_quantize_kv[grid](
91-
K, DestLoc, Out, Out_scale,
92-
K.stride(0), K.stride(1), K.stride(2),
93-
Out.stride(0), Out.stride(1), Out.stride(2),
94-
Out_scale.stride(0), Out_scale.stride(1), Out_scale.stride(2),
114+
K,
115+
DestLoc,
116+
Out,
117+
Out_scale,
118+
K.stride(0),
119+
K.stride(1),
120+
K.stride(2),
121+
Out.stride(0),
122+
Out.stride(1),
123+
Out.stride(2),
124+
Out_scale.stride(0),
125+
Out_scale.stride(1),
126+
Out_scale.stride(2),
95127
head_num,
96128
BLOCK_DMODEL=head_dim,
97129
BLOCK_HEAD=BLOCK_HEAD,
@@ -149,6 +181,6 @@ def test2():
149181
print("cos ", cos(src.flatten().to(torch.float32), (value_dest * scale_dest).flatten().to(torch.float32)))
150182

151183

152-
if __name__ == '__main__':
184+
if __name__ == "__main__":
153185
test1()
154186
test2()

test/benchmark_qps.py

Lines changed: 11 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -105,12 +105,6 @@ async def async_post_stream_openai(url, prompt, max_new_tokens, session):
105105
async for line in response.content:
106106
line = line.strip()
107107
if line:
108-
line = line.decode("utf-8")[6:] # remove "data: "
109-
if line == "[DONE]":
110-
continue
111-
data = json.loads(line)
112-
if not data["choices"][0]["text"]:
113-
continue
114108
current_time = time.time()
115109
elapsed_time = current_time - last_time
116110
used_time.append(elapsed_time)
@@ -249,7 +243,17 @@ async def run_continuous_benchmark(
249243
end_time = [0.0]
250244
pending_tasks = []
251245

252-
async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(limit=10 * reqs_num)) as session:
246+
timeout = aiohttp.ClientTimeout(
247+
total=3600, # total timeout: 1 hour
248+
connect=300, # connection timeout: 5 minutes
249+
sock_connect=300,
250+
sock_read=3600,
251+
)
252+
253+
async with aiohttp.ClientSession(
254+
connector=aiohttp.TCPConnector(limit=10 * reqs_num),
255+
timeout=timeout,
256+
) as session:
253257
sender_task = asyncio.create_task(
254258
continuous_sender(
255259
session,

0 commit comments

Comments
 (0)