Commit 6acd2af

Author: niushengxiao (committed)

feat: add req param cache for decode

1 parent c843243 commit 6acd2af

File tree

9 files changed: +366 -22 lines changed

lightllm/common/basemodel/triton_kernel/apply_penalty.py

Lines changed: 15 additions & 13 deletions
@@ -19,9 +19,9 @@ def _fwd_kernel_apply_penalty(
     eos_ids,
     mask_eos_reqs,
     stride_logit_b,
-    stride_logit_s,
     BLOCK_P: tl.constexpr,
     EOS_ID_NUM: tl.constexpr,
+    IS_EOS_PENALTY: tl.constexpr,
 ):
     cur_batch = tl.program_id(0)
     cur_freqency = tl.load(freqency_penalty + cur_batch)
@@ -46,18 +46,19 @@
     output_ptr = Logits + cur_batch * stride_logit_b + batch_ids
     tl.store(output_ptr, pre_logits, mask=cur_batch_id_offset < cur_batch_end_index)

-    mask_eos = tl.load(mask_eos_reqs + cur_batch)
-    exponential_decay_length_penalty = tl.load(exponential_decay_length_penalties + cur_batch)
-    length_penalty = tl.load(length_penalty_idx + cur_batch)
-    penalty_scale = tl.exp2(tl.log2(exponential_decay_length_penalty) * length_penalty) - 1
+    if IS_EOS_PENALTY:
+        mask_eos = tl.load(mask_eos_reqs + cur_batch)
+        exponential_decay_length_penalty = tl.load(exponential_decay_length_penalties + cur_batch)
+        length_penalty = tl.load(length_penalty_idx + cur_batch)
+        penalty_scale = tl.exp2(tl.log2(exponential_decay_length_penalty) * length_penalty) - 1

-    for eos_index in range(EOS_ID_NUM):
-        eos_id = tl.load(eos_ids + eos_index)
-        cur_eos_logit_ptr = Logits + cur_batch * stride_logit_b + eos_id
-        cur_eos_logit = tl.load(cur_eos_logit_ptr)
-        cur_eos_logit = cur_eos_logit + tl.abs(cur_eos_logit) * penalty_scale
-        cur_eos_logit = tl.where(mask_eos, -10000000.0, cur_eos_logit)
-        tl.store(cur_eos_logit_ptr, cur_eos_logit)
+        for eos_index in range(EOS_ID_NUM):
+            eos_id = tl.load(eos_ids + eos_index)
+            cur_eos_logit_ptr = Logits + cur_batch * stride_logit_b + eos_id
+            cur_eos_logit = tl.load(cur_eos_logit_ptr)
+            cur_eos_logit = cur_eos_logit + tl.abs(cur_eos_logit) * penalty_scale
+            cur_eos_logit = tl.where(mask_eos, -10000000.0, cur_eos_logit)
+            tl.store(cur_eos_logit_ptr, cur_eos_logit)
     return


@@ -74,6 +75,7 @@ def apply_penalty(
     length_penalty_idx,
     eos_ids,
     mask_eos_reqs,
+    is_eos_penalty=False,
 ):
     assert Logits.is_contiguous()
     BLOCK_P = 1024
@@ -91,9 +93,9 @@ def apply_penalty(
         eos_ids,
         mask_eos_reqs,
         Logits.stride(0),
-        Logits.stride(1),
         num_warps=num_warps,
         BLOCK_P=BLOCK_P,
         EOS_ID_NUM=eos_ids.shape[0],
+        IS_EOS_PENALTY=is_eos_penalty,
     )
     return
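This change gates the exponential-decay EOS handling behind the IS_EOS_PENALTY constexpr, so batches that do not use an EOS penalty skip the per-EOS-token loop entirely. A minimal host-side sketch of what that branch computes for a single request (function and argument names here are illustrative, not part of the repo's API):

# Host-side sketch of the IS_EOS_PENALTY branch for one request (illustrative names).
import torch

def eos_penalty_reference(logits_row, eos_ids, decay, length_penalty, mask_eos):
    # exp2(log2(decay) * n) - 1 == decay ** n - 1, matching the Triton kernel
    penalty_scale = decay ** length_penalty - 1.0
    out = logits_row.clone()
    for eos_id in eos_ids.tolist():
        val = out[eos_id]
        val = val + val.abs() * penalty_scale           # scale the EOS logit by its magnitude
        out[eos_id] = -10000000.0 if mask_eos else val  # hard-mask EOS while min_new_tokens not reached
    return out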
(new file)

Lines changed: 198 additions & 0 deletions

@@ -0,0 +1,198 @@
+import torch
+
+import triton
+import triton.language as tl
+import torch.nn.functional as F
+import numpy as np
+
+
+@triton.jit
+def _fwd_kernel_apply_penalty_cache(
+    Logits,
+    req_idxs,
+    presence_penalty,
+    freqency_penalty,
+    repetition_penalty,
+    p_token_vocabs,
+    stride_logit_b,
+    stride_p_token_vocabs_b,
+    BLOCK_P: tl.constexpr,
+):
+    cur_batch = tl.program_id(0)
+    block_idx = tl.program_id(1)
+    token_idx = tl.load(req_idxs + cur_batch)
+    cur_freqency = tl.load(freqency_penalty + token_idx)
+    cur_presence = tl.load(presence_penalty + token_idx)
+    cur_repetition = tl.load(repetition_penalty + token_idx)
+
+    batch_ids = BLOCK_P * block_idx + tl.arange(0, BLOCK_P)
+    batch_ids_count = tl.load(
+        p_token_vocabs + token_idx * stride_p_token_vocabs_b + batch_ids,
+        mask=batch_ids < stride_p_token_vocabs_b,
+        other=0,
+    )
+    row_start_ptr = Logits + cur_batch * stride_logit_b
+    cur_offset = row_start_ptr + batch_ids
+    cur_logits = tl.load(cur_offset, mask=batch_ids_count > 0, other=0.0)
+    rep_logits = tl.where(cur_logits > 0, cur_logits / cur_repetition, cur_logits * cur_repetition)
+    freq_logits = rep_logits - batch_ids_count * cur_freqency
+    pre_logits = freq_logits - cur_presence
+    output_ptr = Logits + cur_batch * stride_logit_b + batch_ids
+    tl.store(output_ptr, pre_logits, mask=batch_ids_count > 0)
+    return
+
+
+@triton.jit
+def _eos_penalty(
+    Logits,
+    req_idxs,
+    p_token_lens,
+    exponential_decay_length_penalties,
+    length_penalty_idx,
+    eos_ids,
+    mask_eos_reqs,
+    stride_logit_b,
+    EOS_ID_NUM: tl.constexpr,
+):
+    cur_batch = tl.program_id(0)
+    token_idx = tl.load(req_idxs + cur_batch)
+    exponential_decay_length_penalty = tl.load(exponential_decay_length_penalties + token_idx)
+    token_lens = tl.load(p_token_lens + cur_batch)
+    length_penalty = tl.maximum(token_lens - tl.load(length_penalty_idx + token_idx), 0)
+    penalty_scale = tl.exp2(tl.log2(exponential_decay_length_penalty) * length_penalty) - 1
+    mask_eos = tl.load(mask_eos_reqs + token_idx)
+    for eos_index in range(EOS_ID_NUM):
+        eos_id = tl.load(eos_ids + eos_index)
+        cur_eos_logit_ptr = Logits + cur_batch * stride_logit_b + eos_id
+        cur_eos_logit = tl.load(cur_eos_logit_ptr)
+        cur_eos_logit = cur_eos_logit + tl.abs(cur_eos_logit) * penalty_scale
+        cur_eos_logit = tl.where(token_lens < mask_eos, -10000000.0, cur_eos_logit)
+        tl.store(cur_eos_logit_ptr, cur_eos_logit)
+    return
+
+
+@torch.no_grad()
+def apply_penalty_cache(
+    Logits,
+    req_idxs,
+    presence_penalty,
+    freqency_penalty,
+    repetition_penalty,
+    p_token_vocabs,
+    p_token_lens,
+    exponential_decay_length_penalties,
+    length_penalty_idx,
+    eos_ids,
+    mask_eos_reqs,
+    is_eos_penalty=False,
+):
+    assert Logits.is_contiguous()
+    BLOCK_P = 1024
+    num_warps = 8
+    vocab_size = p_token_vocabs.shape[1]
+    _fwd_kernel_apply_penalty_cache[(Logits.shape[0], triton.cdiv(vocab_size, BLOCK_P))](
+        Logits,
+        req_idxs,
+        presence_penalty,
+        freqency_penalty,
+        repetition_penalty,
+        p_token_vocabs,
+        Logits.stride(0),
+        p_token_vocabs.stride(0),
+        num_warps=num_warps,
+        BLOCK_P=BLOCK_P,
+    )
+    if is_eos_penalty:
+        p_token_lens = p_token_vocabs[req_idxs].sum(dim=1).cuda() if p_token_lens is None else p_token_lens
+        _eos_penalty[(Logits.shape[0],)](
+            Logits,
+            req_idxs,
+            p_token_lens,
+            exponential_decay_length_penalties,
+            length_penalty_idx,
+            eos_ids,
+            mask_eos_reqs,
+            Logits.stride(0),
+            num_warps=num_warps,
+            EOS_ID_NUM=eos_ids.shape[0],
+        )
+    return
+
+
+if __name__ == "__main__":
+    from .apply_penalty import apply_penalty
+
+    bs = 200
+    vocab_size = 150000
+    p_tokens = 2000
+    repseats = 1000
+    req_idxs = torch.arange(bs).cuda()
+    logits = torch.randn((bs, vocab_size), dtype=torch.float32).cuda()
+    logits2 = logits.clone()
+
+    presence_penalty = torch.randn((bs,), dtype=torch.float32).cuda() + 1e-5
+    freqency_penalty = torch.randn((bs,), dtype=torch.float32).cuda()
+    repetition_penalty = torch.randn((bs,), dtype=torch.float32).cuda()
+    exponential_decay_length_penalties = torch.rand(bs).cuda()
+    eos_ids = torch.tensor([999]).cuda()
+
+    p_seq_len = torch.cat([torch.tensor([0]), torch.randint(1, p_tokens, (bs,))]).cuda()
+    p_token_ids = torch.randint(0, vocab_size, (p_seq_len.sum(),)).cuda()
+    i = 0
+    for s_l in p_seq_len[1:]:
+        p_token_ids[i : i + s_l] = torch.randperm(vocab_size)[:s_l]
+        i += s_l
+    p_token_counts = torch.randint(1, repseats, (p_seq_len.sum(),)).cuda()
+    p_cumsum_seq_len = p_seq_len.cumsum(dim=0).cuda()
+    p_token_vocabs = torch.zeros((bs, vocab_size), dtype=torch.int16).cuda()
+    i = 0
+    b = 0
+    for token_id, count in zip(p_token_ids, p_token_counts):
+        p_token_vocabs[b][token_id] = count
+        i += 1
+        if i == p_seq_len[b + 1]:
+            b += 1
+            i = 0
+
+    p_token_lens = p_token_vocabs.sum(dim=1).cuda()
+    length_penalty_idx = torch.randint(0, p_tokens, (bs,)).cuda()
+    len_idx = torch.tensor([max(p_token_lens[i] - length_penalty_idx[i], 0) for i in range(bs)]).cuda()
+    mask_eos_reqs = torch.randint(1, p_tokens, (bs,)).cuda()
+    mask_bool = torch.tensor([p_token_lens[i] < mask_eos_reqs[i] for i in range(bs)]).cuda()
+
+    fn1 = lambda: apply_penalty(
+        logits,
+        presence_penalty,
+        freqency_penalty,
+        repetition_penalty,
+        p_token_ids,
+        p_token_counts,
+        p_cumsum_seq_len,
+        exponential_decay_length_penalties,
+        len_idx,
+        eos_ids,
+        mask_bool,
+    )
+
+    fn2 = lambda: apply_penalty_cache(
+        logits2,
+        req_idxs,
+        presence_penalty,
+        freqency_penalty,
+        repetition_penalty,
+        p_token_vocabs,
+        p_token_lens,
+        exponential_decay_length_penalties,
+        length_penalty_idx,
+        eos_ids,
+        mask_eos_reqs,
+    )
+    fn1()
+    fn2()
+    cos = F.cosine_similarity(logits, logits2).mean()
+    print("cos =", cos)
+    assert torch.allclose(logits, logits2, atol=1e-2, rtol=0)
+
+    ms1 = triton.testing.do_bench(fn1)
+    ms2 = triton.testing.do_bench(fn2)
+    print("ms1 =", ms1, "ms2 =", ms2)

lightllm/common/req_manager.py

Lines changed: 3 additions & 0 deletions
@@ -58,6 +58,7 @@ def __init__(self, max_request_num, max_sequence_length, mem_manager: MemoryMana
             (max_request_num + 1, max_sequence_length), dtype=torch.int32, device="cuda"
         )
         self.mem_manager = mem_manager
+        self.req_sample_parms_manager = None
         self.max_request_num = max_request_num
         self.HOLD_REQUEST_ID = max_request_num

@@ -67,6 +68,8 @@ def alloc(self):
     def free(self, free_req_indexes: List[int], free_token_index):
         for req_index in free_req_indexes:
             self.req_list.free(req_index)
+        if self.req_sample_parms_manager is not None:
+            self.req_sample_parms_manager.p_token_vocabs[free_req_indexes] = 0

         if self.req_list.is_all_free():
             logger.debug(f"freed all request size {self.req_list.can_alloc_size}")
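Because request slots are reused, the freed rows of the count table must be cleared; otherwise a new request assigned the same req_idx would start from stale counts. A toy sketch of that row reset (shapes are illustrative):

# Sketch of the reset on free: index with the freed slots and zero whole rows (toy shapes).
import torch

p_token_vocabs = torch.ones((8, 10), dtype=torch.int16)
free_req_indexes = [2, 5]
p_token_vocabs[free_req_indexes] = 0
assert int(p_token_vocabs[2].sum()) == 0 and int(p_token_vocabs[5].sum()) == 0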

lightllm/server/router/model_infer/infer_batch.py

Lines changed: 58 additions & 1 deletion
@@ -21,6 +21,52 @@
 logger = init_logger(__name__)


+class ReqSampleParmsManager:
+    def __init__(self, max_request_num, vocab_size):
+        self.presence_penalties = torch.empty(max_request_num, dtype=torch.float, device="cpu", pin_memory=True).cuda(
+            non_blocking=True
+        )
+        self.frequency_penalties = torch.empty(max_request_num, dtype=torch.float, device="cpu", pin_memory=True).cuda(
+            non_blocking=True
+        )
+        self.repetition_penalties = torch.empty(max_request_num, dtype=torch.float, device="cpu", pin_memory=True).cuda(
+            non_blocking=True
+        )
+        self.exponential_decay_length_penalties = torch.empty(
+            max_request_num, dtype=torch.float, device="cpu", pin_memory=True
+        ).cuda(non_blocking=True)
+        self.temperatures = torch.empty(max_request_num, dtype=torch.float, device="cpu", pin_memory=True).cuda(
+            non_blocking=True
+        )
+        self.top_ps = torch.empty(max_request_num, dtype=torch.float, device="cpu", pin_memory=True).cuda(
+            non_blocking=True
+        )
+        self.top_ks = torch.empty(max_request_num, dtype=torch.int32, device="cpu", pin_memory=True).cuda(
+            non_blocking=True
+        )
+        self.length_penalty_idx = torch.empty(max_request_num, dtype=torch.int32, device="cpu", pin_memory=True).cuda(
+            non_blocking=True
+        )
+        self.mask_eos_reqs = torch.empty(max_request_num, dtype=torch.int32, device="cpu", pin_memory=True).cuda(
+            non_blocking=True
+        )
+        self.p_token_vocabs = torch.zeros(
+            (max_request_num, vocab_size), dtype=torch.int16, device="cpu", pin_memory=True
+        ).cuda(non_blocking=True)
+
+    def set_sample_params(self, req_idx, sampling_param):
+        self.presence_penalties[req_idx] = sampling_param.shm_param.presence_penalty
+        self.frequency_penalties[req_idx] = sampling_param.shm_param.frequency_penalty
+        self.repetition_penalties[req_idx] = sampling_param.shm_param.repetition_penalty
+        tpl = sampling_param.shm_param.exponential_decay_length_penalty.to_tuple()
+        self.exponential_decay_length_penalties[req_idx] = tpl[1]
+        self.temperatures[req_idx] = sampling_param.shm_param.temperature
+        self.top_ps[req_idx] = sampling_param.shm_param.top_p
+        self.top_ks[req_idx] = sampling_param.shm_param.top_k
+        self.length_penalty_idx[req_idx] = tpl[0]
+        self.mask_eos_reqs[req_idx] = sampling_param.shm_param.min_new_tokens - 1
+
+
 @dataclass
 class InferenceContext:
     req_manager: ReqManager = None  # gpu request management
@@ -36,6 +82,8 @@ class InferenceContext:
     def register(
         self, req_manager: ReqManager, radix_cache: RadixCache, shm_req_manager: ShmReqManager, vocab_size: int
     ):
+        if os.getenv("ENABLE_REQ_PARAM_CACHE", False):
+            req_manager.req_sample_parms_manager = ReqSampleParmsManager(req_manager.max_request_num, vocab_size)
         self.req_manager = req_manager
         self.radix_cache = radix_cache
         self.shm_req_manager = shm_req_manager
@@ -55,7 +103,6 @@ def get_overlap_stream(self) -> torch.cuda.Stream:
     def add_reqs(self, requests: List[Tuple[int, int, Any, int]], init_req_obj=True):
         request_ids = []
         for r in requests:
-
             r_id, r_index, multimodal_params, _ = r
             if r_id not in self.requests_mapping.keys():
                 r_obj = InferReq(
@@ -264,6 +311,16 @@ def init_all(self):
         self.shm_req.link_prompt_ids_shm_array()
         self.shm_req.link_logprobs_shm_array()
         self.sampling_param: InferSamplingParams = InferSamplingParams(self.shm_req, self.vocab_size)
+
+        if os.getenv("ENABLE_REQ_PARAM_CACHE", False):
+            g_infer_context.req_manager.req_sample_parms_manager.set_sample_params(
+                self.req_idx, self.sampling_param
+            )
+            if self.sampling_param.shm_param.input_penalty:
+                dct = collections.Counter(self.shm_req.get_prompt_ids())
+                for idx, count in dct.items():
+                    g_infer_context.req_manager.req_sample_parms_manager.p_token_vocabs[self.req_idx][idx] = count
+
         if self.sampling_param.shm_param.input_penalty:
            self.out_token_id_count = collections.Counter(self.shm_req.get_prompt_ids())
         else:
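The cache is opt-in through the ENABLE_REQ_PARAM_CACHE environment variable; note that os.getenv returns a string, so with this pattern any non-empty value (even "0") enables it. Once enabled, per-request sampling params are written into GPU tensors indexed by req_idx at request init, and a decode step can gather them by slot instead of rebuilding them from shared memory each step. A minimal sketch of that gather (toy shapes; names are stand-ins for the manager's tensors):

# Sketch: gathering cached per-request sampling params by slot during decode (toy shapes).
import os
import torch

enabled = bool(os.getenv("ENABLE_REQ_PARAM_CACHE", ""))  # any non-empty string enables it

max_request_num = 8
temperatures = torch.ones(max_request_num)  # written once by set_sample_params per slot
temperatures[3] = 0.7

req_idxs = torch.tensor([3, 0])             # request slots in this decode batch
batch_temps = temperatures[req_idxs]        # index-gather, no per-step rebuild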

lightllm/server/router/model_infer/mode_backend/base_backend.py

Lines changed: 1 addition & 1 deletion
@@ -226,7 +226,7 @@ def _get_classed_reqs(self, req_ids: List[int], no_decode: bool = False, strict_
                prefill_reqs.append(req_obj)
                continue

-            is_decode = req_obj.cur_kv_len + 1 == req_obj.get_cur_total_len()
+            is_decode = req_obj.get_output_len() > 0

            if not is_decode:
                prefill_reqs.append(req_obj)
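The decode classification switches from a position check (kv length one short of total length) to simply asking whether the request has produced any output tokens yet. A toy comparison of the two predicates (a stand-in object, not the repo's InferReq):

# Toy comparison of the old and new decode checks (stand-in request object).
class FakeReq:
    def __init__(self, cur_kv_len, total_len, output_len):
        self.cur_kv_len = cur_kv_len
        self._total_len = total_len
        self._output_len = output_len

    def get_cur_total_len(self):
        return self._total_len

    def get_output_len(self):
        return self._output_len

req = FakeReq(cur_kv_len=10, total_len=11, output_len=1)
old_is_decode = req.cur_kv_len + 1 == req.get_cur_total_len()  # position-based check
new_is_decode = req.get_output_len() > 0                       # "has produced output" check
assert old_is_decode and new_is_decode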
