
Commit d2e272c

niushengxiao committed:

opt: add token vocab cache in the post processing

1 parent 058eb80 commit d2e272c

File tree

9 files changed: +333 -27 lines

lightllm/common/basemodel/triton_kernel/apply_penalty.py

Lines changed: 15 additions & 13 deletions

@@ -19,9 +19,9 @@ def _fwd_kernel_apply_penalty(
     eos_ids,
     mask_eos_reqs,
     stride_logit_b,
-    stride_logit_s,
     BLOCK_P: tl.constexpr,
     EOS_ID_NUM: tl.constexpr,
+    IS_EOS_PENALTY: tl.constexpr,
 ):
     cur_batch = tl.program_id(0)
     cur_freqency = tl.load(freqency_penalty + cur_batch)
@@ -46,18 +46,19 @@
     output_ptr = Logits + cur_batch * stride_logit_b + batch_ids
     tl.store(output_ptr, pre_logits, mask=cur_batch_id_offset < cur_batch_end_index)

-    mask_eos = tl.load(mask_eos_reqs + cur_batch)
-    exponential_decay_length_penalty = tl.load(exponential_decay_length_penalties + cur_batch)
-    length_penalty = tl.load(length_penalty_idx + cur_batch)
-    penalty_scale = tl.exp2(tl.log2(exponential_decay_length_penalty) * length_penalty) - 1
+    if IS_EOS_PENALTY:
+        mask_eos = tl.load(mask_eos_reqs + cur_batch)
+        exponential_decay_length_penalty = tl.load(exponential_decay_length_penalties + cur_batch)
+        length_penalty = tl.load(length_penalty_idx + cur_batch)
+        penalty_scale = tl.exp2(tl.log2(exponential_decay_length_penalty) * length_penalty) - 1

-    for eos_index in range(EOS_ID_NUM):
-        eos_id = tl.load(eos_ids + eos_index)
-        cur_eos_logit_ptr = Logits + cur_batch * stride_logit_b + eos_id
-        cur_eos_logit = tl.load(cur_eos_logit_ptr)
-        cur_eos_logit = cur_eos_logit + tl.abs(cur_eos_logit) * penalty_scale
-        cur_eos_logit = tl.where(mask_eos, -10000000.0, cur_eos_logit)
-        tl.store(cur_eos_logit_ptr, cur_eos_logit)
+        for eos_index in range(EOS_ID_NUM):
+            eos_id = tl.load(eos_ids + eos_index)
+            cur_eos_logit_ptr = Logits + cur_batch * stride_logit_b + eos_id
+            cur_eos_logit = tl.load(cur_eos_logit_ptr)
+            cur_eos_logit = cur_eos_logit + tl.abs(cur_eos_logit) * penalty_scale
+            cur_eos_logit = tl.where(mask_eos, -10000000.0, cur_eos_logit)
+            tl.store(cur_eos_logit_ptr, cur_eos_logit)
     return


@@ -74,6 +75,7 @@ def apply_penalty(
     length_penalty_idx,
     eos_ids,
     mask_eos_reqs,
+    is_eos_penalty=False,
 ):
     assert Logits.is_contiguous()
     BLOCK_P = 1024
@@ -91,9 +93,9 @@ def apply_penalty(
         eos_ids,
         mask_eos_reqs,
         Logits.stride(0),
-        Logits.stride(1),
         num_warps=num_warps,
         BLOCK_P=BLOCK_P,
         EOS_ID_NUM=eos_ids.shape[0],
+        IS_EOS_PENALTY=is_eos_penalty,
     )
     return
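
The change above makes the EOS length-penalty path opt-in: the kernel takes a compile-time IS_EOS_PENALTY flag, the wrapper exposes it as is_eos_penalty=False, and the unused stride_logit_s argument is dropped. A minimal call sketch follows, using the positional order of the fn1 benchmark call in the new cache kernel file below; the shapes, dtypes, and values here are illustrative assumptions, not part of the commit.

    import torch
    from lightllm.common.basemodel.triton_kernel.apply_penalty import apply_penalty

    bs, vocab_size = 4, 32000
    logits = torch.randn(bs, vocab_size, device="cuda")
    presence_penalty = torch.full((bs,), 0.1, device="cuda")
    freqency_penalty = torch.full((bs,), 0.1, device="cuda")
    repetition_penalty = torch.full((bs,), 1.1, device="cuda")
    # one previously seen token per request, each counted once
    p_token_ids = torch.arange(bs, device="cuda")
    p_token_counts = torch.ones(bs, dtype=torch.int32, device="cuda")
    p_cumsum_seq_len = torch.arange(bs + 1, dtype=torch.int32, device="cuda")
    exponential_decay_length_penalties = torch.ones(bs, device="cuda")
    length_penalty_idx = torch.zeros(bs, dtype=torch.int32, device="cuda")
    eos_ids = torch.tensor([2], device="cuda")
    mask_eos_reqs = torch.zeros(bs, dtype=torch.bool, device="cuda")

    # The default (False) keeps only the repetition/frequency/presence pass;
    # True additionally compiles in the EOS length-penalty branch.
    apply_penalty(
        logits, presence_penalty, freqency_penalty, repetition_penalty,
        p_token_ids, p_token_counts, p_cumsum_seq_len,
        exponential_decay_length_penalties, length_penalty_idx,
        eos_ids, mask_eos_reqs, is_eos_penalty=True,
    )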
New file

Lines changed: 195 additions & 0 deletions

@@ -0,0 +1,195 @@
+import torch
+
+import triton
+import triton.language as tl
+import torch.nn.functional as F
+import numpy as np
+
+
+@triton.jit
+def _fwd_kernel_apply_penalty_cache(
+    Logits,
+    req_idxs,
+    presence_penalty,
+    freqency_penalty,
+    repetition_penalty,
+    p_token_vocabs,
+    stride_logit_b,
+    stride_p_token_vocabs_b,
+    BLOCK_P: tl.constexpr,
+):
+    cur_batch = tl.program_id(0)
+    block_idx = tl.program_id(1)
+    token_idx = tl.load(req_idxs + cur_batch)
+    cur_freqency = tl.load(freqency_penalty + cur_batch)
+    cur_presence = tl.load(presence_penalty + cur_batch)
+    cur_repetition = tl.load(repetition_penalty + cur_batch)
+
+    batch_ids = BLOCK_P * block_idx + tl.arange(0, BLOCK_P)
+    batch_ids_count = tl.load(
+        p_token_vocabs + token_idx * stride_p_token_vocabs_b + batch_ids,
+        mask=batch_ids < stride_p_token_vocabs_b,
+        other=0,
+    )
+    row_start_ptr = Logits + cur_batch * stride_logit_b
+    cur_offset = row_start_ptr + batch_ids
+    cur_logits = tl.load(cur_offset, mask=batch_ids_count > 0, other=0.0)
+    rep_logits = tl.where(cur_logits > 0, cur_logits / cur_repetition, cur_logits * cur_repetition)
+    freq_logits = rep_logits - batch_ids_count * cur_freqency
+    pre_logits = freq_logits - cur_presence
+    output_ptr = Logits + cur_batch * stride_logit_b + batch_ids
+    tl.store(output_ptr, pre_logits, mask=batch_ids_count > 0)
+    return
+
+
+@triton.jit
+def _eos_penalty(
+    Logits,
+    p_token_lens,
+    exponential_decay_length_penalties,
+    length_penalty_idx,
+    eos_ids,
+    mask_eos_reqs,
+    stride_logit_b,
+    EOS_ID_NUM: tl.constexpr,
+):
+    cur_batch = tl.program_id(0)
+    exponential_decay_length_penalty = tl.load(exponential_decay_length_penalties + cur_batch)
+    token_lens = tl.load(p_token_lens + cur_batch)
+    length_penalty = tl.maximum(token_lens - tl.load(length_penalty_idx + cur_batch), 0)
+    penalty_scale = tl.exp2(tl.log2(exponential_decay_length_penalty) * length_penalty) - 1
+    mask_eos = tl.load(mask_eos_reqs + cur_batch)
+    for eos_index in range(EOS_ID_NUM):
+        eos_id = tl.load(eos_ids + eos_index)
+        cur_eos_logit_ptr = Logits + cur_batch * stride_logit_b + eos_id
+        cur_eos_logit = tl.load(cur_eos_logit_ptr)
+        cur_eos_logit = cur_eos_logit + tl.abs(cur_eos_logit) * penalty_scale
+        cur_eos_logit = tl.where(token_lens < mask_eos, -10000000.0, cur_eos_logit)
+        tl.store(cur_eos_logit_ptr, cur_eos_logit)
+    return
+
+
+@torch.no_grad()
+def apply_penalty_cache(
+    Logits,
+    req_idxs,
+    presence_penalty,
+    freqency_penalty,
+    repetition_penalty,
+    p_token_vocabs,
+    p_token_lens,
+    exponential_decay_length_penalties,
+    length_penalty_idx,
+    eos_ids,
+    mask_eos_reqs,
+    is_eos_penalty=False,
+):
+    assert Logits.is_contiguous()
+    BLOCK_P = 1024
+    num_warps = 8
+    vocab_size = p_token_vocabs.shape[1]
+    _fwd_kernel_apply_penalty_cache[(Logits.shape[0], triton.cdiv(vocab_size, BLOCK_P))](
+        Logits,
+        req_idxs,
+        presence_penalty,
+        freqency_penalty,
+        repetition_penalty,
+        p_token_vocabs,
+        Logits.stride(0),
+        p_token_vocabs.stride(0),
+        num_warps=num_warps,
+        BLOCK_P=BLOCK_P,
+    )
+    if is_eos_penalty:
+        p_token_lens = p_token_vocabs[req_idxs].count_nonzero(dim=1) if p_token_lens is None else p_token_lens
+        _eos_penalty[(Logits.shape[0],)](
+            Logits,
+            p_token_lens,
+            exponential_decay_length_penalties,
+            length_penalty_idx,
+            eos_ids,
+            mask_eos_reqs,
+            Logits.stride(0),
+            num_warps=num_warps,
+            EOS_ID_NUM=eos_ids.shape[0],
+        )
+    return
+
+
+if __name__ == "__main__":
+    from .apply_penalty import apply_penalty
+
+    bs = 200
+    vocab_size = 150000
+    p_tokens = 2000
+    repseats = 1000
+    req_idxs = torch.arange(bs).cuda()
+    logits = torch.randn((bs, vocab_size), dtype=torch.float32).cuda()
+    logits2 = logits.clone()
+
+    presence_penalty = torch.randn((bs,), dtype=torch.float32).cuda() + 1e-5
+    freqency_penalty = torch.randn((bs,), dtype=torch.float32).cuda()
+    repetition_penalty = torch.randn((bs,), dtype=torch.float32).cuda()
+    exponential_decay_length_penalties = torch.rand(bs).cuda()
+    eos_ids = torch.tensor([999]).cuda()
+
+    p_seq_len = torch.cat([torch.tensor([0]), torch.randint(1, p_tokens, (bs,))]).cuda()
+    p_token_ids = torch.randint(0, vocab_size, (p_seq_len.sum(),)).cuda()
+    i = 0
+    for s_l in p_seq_len[1:]:
+        p_token_ids[i : i + s_l] = torch.randperm(vocab_size)[:s_l]
+        i += s_l
+    p_token_counts = torch.randint(1, repseats, (p_seq_len.sum(),)).cuda()
+    p_cumsum_seq_len = p_seq_len.cumsum(dim=0).cuda()
+    p_token_vocabs = torch.zeros((bs, vocab_size), dtype=torch.int16).cuda()
+    i = 0
+    b = 0
+    for token_id, count in zip(p_token_ids, p_token_counts):
+        p_token_vocabs[b][token_id] = count
+        i += 1
+        if i == p_seq_len[b + 1]:
+            b += 1
+            i = 0
+
+    p_token_lens = p_token_vocabs.sum(dim=1).cuda()
+    length_penalty_idx = torch.randint(0, p_tokens, (bs,)).cuda()
+    len_idx = torch.tensor([max(p_token_lens[i] - length_penalty_idx[i], 0) for i in range(bs)]).cuda()
+    mask_eos_reqs = torch.randint(1, p_tokens, (bs,)).cuda()
+    mask_bool = torch.tensor([p_token_lens[i] < mask_eos_reqs[i] for i in range(bs)]).cuda()
+
+    fn1 = lambda: apply_penalty(
+        logits,
+        presence_penalty,
+        freqency_penalty,
+        repetition_penalty,
+        p_token_ids,
+        p_token_counts,
+        p_cumsum_seq_len,
+        exponential_decay_length_penalties,
+        len_idx,
+        eos_ids,
+        mask_bool,
+    )
+
+    fn2 = lambda: apply_penalty_cache(
+        logits2,
+        req_idxs,
+        presence_penalty,
+        freqency_penalty,
+        repetition_penalty,
+        p_token_vocabs,
+        p_token_lens,
+        exponential_decay_length_penalties,
+        length_penalty_idx,
+        eos_ids,
+        mask_eos_reqs,
+    )
+    fn1()
+    fn2()
+    cos = F.cosine_similarity(logits, logits2).mean()
+    print("cos =", cos)
+    assert torch.allclose(logits, logits2, atol=1e-2, rtol=0)
+
+    ms1 = triton.testing.do_bench(fn1)
+    ms2 = triton.testing.do_bench(fn2)
+    print("ms1 =", ms1, "ms2 =", ms2)

lightllm/common/req_manager.py

Lines changed: 3 additions & 0 deletions

@@ -58,6 +58,7 @@ def __init__(self, max_request_num, max_sequence_length, mem_manager: MemoryMana
             (max_request_num + 1, max_sequence_length), dtype=torch.int32, device="cuda"
         )
         self.mem_manager = mem_manager
+        self.req_sample_parms_manager = None
         self.max_request_num = max_request_num
         self.HOLD_REQUEST_ID = max_request_num

@@ -67,6 +68,8 @@ def alloc(self):
     def free(self, free_req_indexes: List[int], free_token_index):
         for req_index in free_req_indexes:
            self.req_list.free(req_index)
+        if self.req_sample_parms_manager is not None:
+            self.req_sample_parms_manager.p_token_vocabs[free_req_indexes] = 0

         if self.req_list.is_all_free():
            logger.debug(f"freed all request size {self.req_list.can_alloc_size}")
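
Freeing request slots also clears the matching rows of the vocab-count cache, so a recycled req_index starts from zero counts instead of inheriting the previous request's penalty state. A toy illustration of the reset (the shapes here are made up):

    import torch

    p_token_vocabs = torch.zeros((8, 16), dtype=torch.int16)  # toy (max_request_num, vocab_size)
    p_token_vocabs[3, 5] = 7           # counts accumulated while slot 3 was in use
    p_token_vocabs[[3, 6]] = 0         # what free() does for the freed indexes
    assert int(p_token_vocabs[3].sum()) == 0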

lightllm/server/router/model_infer/infer_batch.py

Lines changed: 20 additions & 4 deletions

@@ -17,10 +17,16 @@
 from lightllm.common.basemodel.infer_lock import g_infer_state_lock
 from lightllm.server.multimodal_params import MultimodalParams
 from lightllm.utils.custom_kernel_utis import custom_cat
+from lightllm.utils.envs_utils import enable_env_vars

 logger = init_logger(__name__)


+class ReqSampleParmsManager:
+    def __init__(self, max_request_num, vocab_size):
+        self.p_token_vocabs = torch.zeros((max_request_num, vocab_size), dtype=torch.int16, device="cuda")
+
+
 @dataclass
 class InferenceContext:
     req_manager: ReqManager = None  # gpu request management
@@ -36,6 +42,8 @@ class InferenceContext:
     def register(
         self, req_manager: ReqManager, radix_cache: RadixCache, shm_req_manager: ShmReqManager, vocab_size: int
     ):
+        if enable_env_vars("ENABLE_REQ_PARAM_CACHE"):
+            req_manager.req_sample_parms_manager = ReqSampleParmsManager(req_manager.max_request_num, vocab_size)
         self.req_manager = req_manager
         self.radix_cache = radix_cache
         self.shm_req_manager = shm_req_manager
@@ -55,7 +63,6 @@ def get_overlap_stream(self) -> torch.cuda.Stream:
     def add_reqs(self, requests: List[Tuple[int, int, Any, int]], init_req_obj=True):
         request_ids = []
         for r in requests:
-
             r_id, r_index, multimodal_params, _ = r
             if r_id not in self.requests_mapping.keys():
                 r_obj = InferReq(
@@ -264,10 +271,19 @@ def init_all(self):
         self.shm_req.link_prompt_ids_shm_array()
         self.shm_req.link_logprobs_shm_array()
         self.sampling_param: InferSamplingParams = InferSamplingParams(self.shm_req, self.vocab_size)
-        if self.sampling_param.shm_param.input_penalty:
-            self.out_token_id_count = collections.Counter(self.shm_req.get_prompt_ids())
+
+        if enable_env_vars("ENABLE_REQ_PARAM_CACHE"):
+            if self.sampling_param.shm_param.input_penalty:
+                idxs = torch.bincount(self.shm_req.get_prompt_ids())
+                g_infer_context.req_manager.req_sample_parms_manager.p_token_vocabs[self.req_idx][
+                    : len(idxs)
+                ] = idxs
+            self.out_token_id_count = None
         else:
-            self.out_token_id_count = collections.defaultdict(int)
+            if self.sampling_param.shm_param.input_penalty:
+                self.out_token_id_count = collections.Counter(self.shm_req.get_prompt_ids())
+            else:
+                self.out_token_id_count = collections.defaultdict(int)

         self.stop_sequences = self.sampling_param.shm_param.stop_sequences.to_list()
         # management object used only in token healing mode
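
With the ENABLE_REQ_PARAM_CACHE environment flag set, register() attaches a ReqSampleParmsManager to the ReqManager, init_all() seeds the request's row of p_token_vocabs from the prompt with torch.bincount, and out_token_id_count is set to None so the Python-side counter is bypassed. A sketch of that bookkeeping; the decode-side increment is not shown in this excerpt of the commit, so count_new_token below is a hypothetical stand-in:

    import torch

    def seed_prompt_counts(p_token_vocabs: torch.Tensor, req_idx: int, prompt_ids) -> None:
        # Same idea as init_all(): bincount the prompt once and write it into the
        # request's row; positions past len(idxs) keep their zero-initialized value.
        idxs = torch.bincount(torch.as_tensor(prompt_ids, device=p_token_vocabs.device))
        p_token_vocabs[req_idx, : len(idxs)] = idxs

    def count_new_token(p_token_vocabs: torch.Tensor, req_idx: int, token_id: int) -> None:
        # Hypothetical replacement for `out_token_id_count[next_token_id] += 1`
        # when the cache is enabled.
        p_token_vocabs[req_idx, token_id] += 1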

lightllm/server/router/model_infer/mode_backend/base_backend.py

Lines changed: 2 additions & 1 deletion

@@ -282,7 +282,8 @@ def _post_handle(
             req_obj.set_next_gen_token_id(next_token_id, next_token_logprob)
             req_obj.cur_output_len += 1

-            req_obj.out_token_id_count[next_token_id] += 1
+            if req_obj.out_token_id_count is not None:
+                req_obj.out_token_id_count[next_token_id] += 1
             req_obj.update_finish_status(self.eos_id)

             if extra_post_req_handle_func is not None:
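
The same None guard appears in all three backends touched by this commit: when the vocab cache is active, out_token_id_count is None and the per-token dictionary update is skipped. The pattern, extracted for clarity (a sketch, not the actual _post_handle body):

    def record_output_token(req_obj, next_token_id: int) -> None:
        # No-op when the token vocab cache owns the counts (out_token_id_count is None).
        if req_obj.out_token_id_count is not None:
            req_obj.out_token_id_count[next_token_id] += 1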

lightllm/server/router/model_infer/mode_backend/continues_batch/impl_for_return_all_prompt_logprobs.py

Lines changed: 2 additions & 1 deletion

@@ -55,7 +55,8 @@ def prefill(self, run_reqs: List[Tuple]):
             for i in range(req_obj.shm_req.input_len - 1):
                 req_obj.shm_req.shm_logprobs.arr[i + 1] = cur_logprobs[i]

-            req_obj.out_token_id_count[next_token_id] += 1
+            if req_obj.out_token_id_count is not None:
+                req_obj.out_token_id_count[next_token_id] += 1
             req_obj.update_finish_status(self.eos_id)

             if req_obj.finish_status.is_finished() or req_obj.shm_req.router_aborted:

lightllm/server/router/model_infer/mode_backend/continues_batch/impl_for_reward_model.py

Lines changed: 2 additions & 1 deletion

@@ -32,7 +32,8 @@ def prefill(self, reqs: List[Tuple]):
             req_obj.set_next_gen_token_id(next_token_id, next_token_logprob)
             req_obj.cur_output_len += 1

-            req_obj.out_token_id_count[next_token_id] += 1
+            if req_obj.out_token_id_count is not None:
+                req_obj.out_token_id_count[next_token_id] += 1
             req_obj.update_finish_status(self.eos_id)

             if req_obj.finish_status.is_finished() or req_obj.shm_req.router_aborted:
