Commit a0486d2

Author: niushengxiao

opt: reduce gpu memory alloc for req_param_cache
1 parent 6acd2af commit a0486d2

File tree: 4 files changed (+77, -26 lines)

lightllm/common/basemodel/triton_kernel/apply_penalty_cache.py

Lines changed: 54 additions & 15 deletions
@@ -7,7 +7,7 @@
 
 
 @triton.jit
-def _fwd_kernel_apply_penalty_cache(
+def _kernel_apply_penalty_cache(
     Logits,
     req_idxs,
     presence_penalty,
@@ -27,7 +27,7 @@ def _fwd_kernel_apply_penalty_cache(
 
     batch_ids = BLOCK_P * block_idx + tl.arange(0, BLOCK_P)
     batch_ids_count = tl.load(
-        p_token_vocabs + token_idx * stride_p_token_vocabs_b + batch_ids,
+        p_token_vocabs + cur_batch * stride_p_token_vocabs_b + batch_ids,
         mask=batch_ids < stride_p_token_vocabs_b,
         other=0,
     )
@@ -43,7 +43,7 @@ def _fwd_kernel_apply_penalty_cache(
 
 
 @triton.jit
-def _eos_penalty(
+def _kernel_eos_penalty(
     Logits,
     req_idxs,
     p_token_lens,
@@ -71,26 +71,58 @@ def _eos_penalty(
     return
 
 
+@triton.jit
+def _kernel_bincount(
+    req_idxs,
+    input,
+    output,
+    input_lens,
+    stride_input_b,
+    stride_output_b,
+    BLOCK_SIZE: tl.constexpr,
+):
+    cur_batch = tl.program_id(0)
+    req_idx = tl.load(req_idxs + cur_batch)
+    block_idx = tl.program_id(1)
+    input_ptr = input + req_idx * stride_input_b + block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    input_len = tl.load(input_lens + req_idx)
+    mask = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) < input_len
+    token_id = tl.load(input_ptr, mask=mask, other=0)
+    tl.atomic_add(output + cur_batch * stride_output_b + token_id, 1, mask=mask)
+    return
+
+
 @torch.no_grad()
 def apply_penalty_cache(
     Logits,
     req_idxs,
     presence_penalty,
     freqency_penalty,
     repetition_penalty,
-    p_token_vocabs,
+    p_token_ids,
     p_token_lens,
     exponential_decay_length_penalties,
     length_penalty_idx,
     eos_ids,
     mask_eos_reqs,
-    is_eos_penalty=False,
+    vocab_size: tl.constexpr,
+    is_eos_penalty: tl.constexpr = False,
 ):
     assert Logits.is_contiguous()
     BLOCK_P = 1024
-    num_warps = 8
-    vocab_size = p_token_vocabs.shape[1]
-    _fwd_kernel_apply_penalty_cache[(Logits.shape[0], triton.cdiv(vocab_size, BLOCK_P))](
+    num_warps = 4
+    p_token_vocabs = torch.zeros((Logits.shape[0], vocab_size), dtype=torch.int32, device="cuda")
+    _kernel_bincount[(Logits.shape[0], triton.cdiv(p_token_ids.stride(0), BLOCK_P))](
+        req_idxs,
+        p_token_ids,
+        p_token_vocabs,
+        p_token_lens,
+        p_token_ids.stride(0),
+        p_token_vocabs.stride(0),
+        num_warps=num_warps,
+        BLOCK_SIZE=BLOCK_P,
+    )
+    _kernel_apply_penalty_cache[(Logits.shape[0], triton.cdiv(vocab_size, BLOCK_P))](
         Logits,
         req_idxs,
         presence_penalty,
@@ -103,8 +135,7 @@ def apply_penalty_cache(
         BLOCK_P=BLOCK_P,
     )
     if is_eos_penalty:
-        p_token_lens = p_token_vocabs[req_idxs].sum(dim=1).cuda() if p_token_lens is None else p_token_lens
-        _eos_penalty[(Logits.shape[0],)](
+        _kernel_eos_penalty[(Logits.shape[0],)](
             Logits,
             req_idxs,
             p_token_lens,
@@ -121,11 +152,13 @@ def apply_penalty_cache(
 
 if __name__ == "__main__":
     from .apply_penalty import apply_penalty
+    from torch.nn.utils.rnn import pad_sequence
 
     bs = 200
     vocab_size = 150000
-    p_tokens = 2000
-    repseats = 1000
+    p_tokens = 3000
+    max_token_len = 16384
+    repseats = max_token_len // p_tokens
     req_idxs = torch.arange(bs).cuda()
     logits = torch.randn((bs, vocab_size), dtype=torch.float32).cuda()
     logits2 = logits.clone()
@@ -144,7 +177,7 @@ def apply_penalty_cache(
         i += s_l
     p_token_counts = torch.randint(1, repseats, (p_seq_len.sum(),)).cuda()
     p_cumsum_seq_len = p_seq_len.cumsum(dim=0).cuda()
-    p_token_vocabs = torch.zeros((bs, vocab_size), dtype=torch.int16).cuda()
+    p_token_vocabs = torch.zeros((bs, vocab_size), dtype=torch.int32).cuda()
     i = 0
     b = 0
     for token_id, count in zip(p_token_ids, p_token_counts):
@@ -154,7 +187,12 @@ def apply_penalty_cache(
         b += 1
         i = 0
 
-    p_token_lens = p_token_vocabs.sum(dim=1).cuda()
+    p_token_lens = p_token_vocabs.cuda().sum(dim=1)
+    assert p_token_lens.max() < max_token_len
+    token_idx = torch.arange(vocab_size).cuda()
+    sequences = [token_idx.repeat_interleave(p_token_vocabs[b]) for b in range(bs)]  # shape = [sum(rep_nums)]
+    p_token_mat = pad_sequence(sequences, batch_first=True, padding_value=0).to(torch.int32).cuda()
+
     length_penalty_idx = torch.randint(0, p_tokens, (bs,)).cuda()
     len_idx = torch.tensor([max(p_token_lens[i] - length_penalty_idx[i], 0) for i in range(bs)]).cuda()
     mask_eos_reqs = torch.randint(1, p_tokens, (bs,)).cuda()
@@ -180,12 +218,13 @@ def apply_penalty_cache(
         presence_penalty,
         freqency_penalty,
         repetition_penalty,
-        p_token_vocabs,
+        p_token_mat,
         p_token_lens,
         exponential_decay_length_penalties,
         length_penalty_idx,
         eos_ids,
         mask_eos_reqs,
+        vocab_size,
     )
     fn1()
     fn2()
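
The net effect of this file: the dense per-request (num_reqs, vocab_size) count table is no longer kept resident; apply_penalty_cache now receives padded token ids plus per-request lengths and rebuilds the counts for just the current batch with _kernel_bincount. A minimal PyTorch sketch of what that kernel computes, for intuition only (the helper and its loop are mine, not code from the commit):

import torch

def bincount_reference(req_idxs, p_token_ids, p_token_lens, vocab_size):
    # One count row per request in the batch: how often each vocab id appears
    # among the first p_token_lens[r] entries of p_token_ids[r].
    out = torch.zeros((req_idxs.numel(), vocab_size), dtype=torch.int32, device=p_token_ids.device)
    for b, r in enumerate(req_idxs.tolist()):
        n = int(p_token_lens[r])
        out[b] += torch.bincount(p_token_ids[r, :n].long(), minlength=vocab_size).to(torch.int32)
    return out

The Triton kernel parallelizes the same counting with tl.atomic_add, one program per (request, block of token ids).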

lightllm/common/req_manager.py

Lines changed: 2 additions & 1 deletion
@@ -60,6 +60,7 @@ def __init__(self, max_request_num, max_sequence_length, mem_manager: MemoryMana
         self.mem_manager = mem_manager
         self.req_sample_parms_manager = None
         self.max_request_num = max_request_num
+        self.max_sequence_length = max_sequence_length
         self.HOLD_REQUEST_ID = max_request_num
 
     def alloc(self):
@@ -69,7 +70,7 @@ def free(self, free_req_indexes: List[int], free_token_index):
         for req_index in free_req_indexes:
             self.req_list.free(req_index)
             if self.req_sample_parms_manager is not None:
-                self.req_sample_parms_manager.p_token_vocabs[free_req_indexes] = 0
+                self.req_sample_parms_manager.p_token_lens[free_req_indexes] = 0
 
         if self.req_list.is_all_free():
             logger.debug(f"freed all request size {self.req_list.can_alloc_size}")
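
Freeing a request now only resets its length counter. A short sketch of the invariant this relies on (my phrasing of the intent, not code from the repo): stale ids left in p_token_ids are harmless because every later read is masked by p_token_lens.

def free_req_params(params, free_req_indexes):
    # _kernel_bincount masks all loads beyond p_token_lens[r], so old token ids
    # in p_token_ids[r] can never be counted again; no O(vocab_size) row clear
    # per freed request is needed.
    params.p_token_lens[free_req_indexes] = 0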

lightllm/server/router/model_infer/infer_batch.py

Lines changed: 15 additions & 7 deletions
@@ -22,7 +22,7 @@
 
 
 class ReqSampleParmsManager:
-    def __init__(self, max_request_num, vocab_size):
+    def __init__(self, max_request_num, max_seq_len, vocab_size):
         self.presence_penalties = torch.empty(max_request_num, dtype=torch.float, device="cpu", pin_memory=True).cuda(
             non_blocking=True
         )
@@ -50,9 +50,13 @@ def __init__(self, max_request_num, vocab_size):
         self.mask_eos_reqs = torch.empty(max_request_num, dtype=torch.int32, device="cpu", pin_memory=True).cuda(
             non_blocking=True
         )
-        self.p_token_vocabs = torch.zeros(
-            (max_request_num, vocab_size), dtype=torch.int16, device="cpu", pin_memory=True
+        self.p_token_ids = torch.zeros(
+            (max_request_num, max_seq_len), dtype=torch.int32, device="cpu", pin_memory=True
         ).cuda(non_blocking=True)
+        self.p_token_lens = torch.zeros((max_request_num,), dtype=torch.int32, device="cpu", pin_memory=True).cuda(
+            non_blocking=True
+        )
+        self.vocab_size = vocab_size
 
     def set_sample_params(self, req_idx, sampling_param):
         self.presence_penalties[req_idx] = sampling_param.shm_param.presence_penalty
@@ -83,7 +87,9 @@ def register(
         self, req_manager: ReqManager, radix_cache: RadixCache, shm_req_manager: ShmReqManager, vocab_size: int
     ):
         if os.getenv("ENABLE_REQ_PARAM_CACHE", False):
-            req_manager.req_sample_parms_manager = ReqSampleParmsManager(req_manager.max_request_num, vocab_size)
+            req_manager.req_sample_parms_manager = ReqSampleParmsManager(
+                req_manager.max_request_num, req_manager.max_sequence_length, vocab_size
+            )
         self.req_manager = req_manager
         self.radix_cache = radix_cache
         self.shm_req_manager = shm_req_manager
@@ -317,9 +323,11 @@ def init_all(self):
                 self.req_idx, self.sampling_param
             )
             if self.sampling_param.shm_param.input_penalty:
-                dct = collections.Counter(self.shm_req.get_prompt_ids())
-                for idx, count in dct.items():
-                    g_infer_context.req_manager.req_sample_parms_manager.p_token_vocabs[self.req_idx][idx] = count
+                ids_len = len(self.shm_req.get_prompt_ids())
+                g_infer_context.req_manager.req_sample_parms_manager.p_token_ids[self.req_idx][
+                    :ids_len
+                ] = self.shm_req.get_prompt_ids()
+                g_infer_context.req_manager.req_sample_parms_manager.p_token_lens[self.req_idx] = ids_len
 
         if self.sampling_param.shm_param.input_penalty:
             self.out_token_id_count = collections.Counter(self.shm_req.get_prompt_ids())
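
The allocation swap in ReqSampleParmsManager is where the GPU memory saving comes from: a (max_request_num, vocab_size) int16 count table becomes a (max_request_num, max_seq_len) int32 id buffer plus per-request lengths. Back-of-the-envelope numbers, with max_request_num assumed at 1000 and the vocab_size / max length taken from the kernel benchmark above (the deployment values are assumptions, not part of the commit):

max_request_num = 1000        # assumption
vocab_size = 150_000          # as in the __main__ benchmark
max_sequence_length = 16_384  # as in max_token_len above

old_bytes = max_request_num * vocab_size * 2           # int16 count per vocab id
new_bytes = max_request_num * max_sequence_length * 4  # int32 token ids
new_bytes += max_request_num * 4                       # int32 per-request lengths

print(old_bytes / 2**20, new_bytes / 2**20)  # roughly 286 MiB -> 63 MiB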

lightllm/server/router/model_infer/mode_backend/generic_post_process.py

Lines changed: 6 additions & 3 deletions
@@ -161,12 +161,13 @@ def sample_in_cache(logits, reqs, eos_id: List[int] = [2]):
         params.presence_penalties,
         params.frequency_penalties,
         params.repetition_penalties,
-        params.p_token_vocabs,
-        None,
+        params.p_token_ids,
+        params.p_token_lens,
         params.exponential_decay_length_penalties,
         params.length_penalty_idx,
         eos_ids,
         params.mask_eos_reqs,
+        params.vocab_size,
     )
 
     logits.div_(params.temperatures[req_idxs].view((-1, 1)))
@@ -209,4 +210,6 @@ def _get_req_idxs(reqs: List[InferReq]):
 
 
 def _update_repeatition_tokens(req_idxs, token_ids):
-    g_infer_context.req_manager.req_sample_parms_manager.p_token_vocabs[req_idxs, token_ids] += 1
+    token_idxs = g_infer_context.req_manager.req_sample_parms_manager.p_token_lens[req_idxs]
+    g_infer_context.req_manager.req_sample_parms_manager.p_token_ids[req_idxs, token_idxs] = token_ids
+    g_infer_context.req_manager.req_sample_parms_manager.p_token_lens[req_idxs] += 1
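
With the dense table gone, recording a newly sampled token becomes an append rather than a count increment. A small standalone sketch of the new _update_repeatition_tokens semantics (tensor names follow ReqSampleParmsManager; the helper itself is mine):

import torch

def append_sampled_tokens(p_token_ids, p_token_lens, req_idxs, token_ids):
    # Write each request's new token at its current length cursor, then advance
    # the cursor; the bincount kernel picks it up at the next decode step.
    cursors = p_token_lens[req_idxs].long()
    p_token_ids[req_idxs, cursors] = token_ids
    p_token_lens[req_idxs] += 1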
