Commit 2b2d30f

improve post_process. (#857)
1 parent 171ee4e commit 2b2d30f

File tree

2 files changed: +82 -40 lines changed

lightllm/common/basemodel/triton_kernel/apply_penalty.py

73 additions & 27 deletions
```diff
@@ -4,12 +4,24 @@
 import triton.language as tl
 import numpy as np
 
+
 @triton.jit
 def _fwd_kernel_apply_penalty(
-    Logits, presence_penalty, freqency_penalty, repetition_penalty,
-    p_token_ids, p_token_counts, p_cumsum_seq_len,
-    stride_logit_b, stride_logit_s,
-    BLOCK_P: tl.constexpr
+    Logits,
+    presence_penalty,
+    freqency_penalty,
+    repetition_penalty,
+    p_token_ids,
+    p_token_counts,
+    p_cumsum_seq_len,
+    exponential_decay_length_penalties,
+    length_penalty_idx,
+    eos_ids,
+    mask_eos_reqs,
+    stride_logit_b,
+    stride_logit_s,
+    BLOCK_P: tl.constexpr,
+    EOS_ID_NUM: tl.constexpr,
 ):
     cur_batch = tl.program_id(0)
     cur_freqency = tl.load(freqency_penalty + cur_batch)
```
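The new arguments keep the flattened penalty layout the kernel already used: token ids and counts for every request are concatenated into single arrays, and p_cumsum_seq_len marks each request's slice with a leading 0. A minimal sketch of that layout (the helper name and inputs are illustrative, not from the repo):

```python
# Illustrative sketch of the flattened (token id, count) layout the kernel
# consumes; program b reads the slice [p_cumsum_seq_len[b], p_cumsum_seq_len[b+1]).
from collections import Counter
import torch

def build_penalty_tensors(output_ids_per_req):
    p_token_ids, p_token_counts, p_seq_len = [], [], [0]
    for ids in output_ids_per_req:
        id_to_count = Counter(ids)  # unique token -> occurrence count
        p_token_ids.extend(id_to_count.keys())
        p_token_counts.extend(id_to_count.values())
        p_seq_len.append(len(id_to_count))
    p_cumsum_seq_len = torch.tensor(p_seq_len, dtype=torch.int32).cumsum(dim=0)
    return (
        torch.tensor(p_token_ids, dtype=torch.int32),
        torch.tensor(p_token_counts, dtype=torch.int32),
        p_cumsum_seq_len,
    )

# Two requests: program 0 reads slice [0, 3), program 1 reads slice [3, 5).
ids, counts, cumsum = build_penalty_tensors([[5, 7, 5, 9], [7, 2]])
```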
```diff
@@ -18,36 +30,70 @@ def _fwd_kernel_apply_penalty(
 
     cur_batch_start_index = tl.load(p_cumsum_seq_len + cur_batch)
     cur_batch_end_index = tl.load(p_cumsum_seq_len + cur_batch + 1)
+    for block_start_index in range(cur_batch_start_index, cur_batch_end_index, BLOCK_P):
+        cur_batch_id_offset = block_start_index + tl.arange(0, BLOCK_P)
+        batch_ids = tl.load(p_token_ids + cur_batch_id_offset, mask=cur_batch_id_offset < cur_batch_end_index, other=0)
+        batch_ids_count = tl.load(
+            p_token_counts + cur_batch_id_offset, mask=cur_batch_id_offset < cur_batch_end_index, other=0
+        )
+
+        row_start_ptr = Logits + cur_batch * stride_logit_b
+        cur_offset = row_start_ptr + batch_ids
+        cur_logits = tl.load(cur_offset, mask=cur_batch_id_offset < cur_batch_end_index, other=0.0)
+        rep_logits = tl.where(cur_logits > 0, cur_logits / cur_repetition, cur_logits * cur_repetition)
+        freq_logits = rep_logits - batch_ids_count * cur_freqency
+        pre_logits = freq_logits - cur_presence
+        output_ptr = Logits + cur_batch * stride_logit_b + batch_ids
+        tl.store(output_ptr, pre_logits, mask=cur_batch_id_offset < cur_batch_end_index)
 
-    cur_batch_id_offset = cur_batch_start_index + tl.arange(0, BLOCK_P)
-    batch_ids = tl.load(p_token_ids + cur_batch_id_offset, mask=cur_batch_id_offset<cur_batch_end_index, other=0)
-    batch_ids_count = tl.load(p_token_counts + cur_batch_id_offset, mask=cur_batch_id_offset<cur_batch_end_index, other=0)
-
-    row_start_ptr = Logits + cur_batch * stride_logit_b
-    cur_offset = row_start_ptr + batch_ids
-    cur_logits = tl.load(cur_offset, mask=cur_batch_id_offset<cur_batch_end_index, other=0.0)
-    rep_logits = tl.where(cur_logits > 0, cur_logits / cur_repetition, cur_logits * cur_repetition)
-    freq_logits = rep_logits - batch_ids_count * cur_freqency
-    pre_logits = freq_logits - cur_presence
-    output_ptr = Logits + cur_batch * stride_logit_b + batch_ids
-    tl.store(output_ptr, pre_logits, mask=cur_batch_id_offset<cur_batch_end_index)
+    mask_eos = tl.load(mask_eos_reqs + cur_batch)
+    exponential_decay_length_penalty = tl.load(exponential_decay_length_penalties + cur_batch)
+    length_penalty = tl.load(length_penalty_idx + cur_batch)
+    penalty_scale = tl.exp2(tl.log2(exponential_decay_length_penalty) * length_penalty) - 1
 
+    for eos_index in range(EOS_ID_NUM):
+        eos_id = tl.load(eos_ids + eos_index)
+        cur_eos_logit_ptr = Logits + cur_batch * stride_logit_b + eos_id
+        cur_eos_logit = tl.load(cur_eos_logit_ptr)
+        cur_eos_logit = cur_eos_logit + tl.abs(cur_eos_logit) * penalty_scale
+        cur_eos_logit = tl.where(mask_eos, -10000000.0, cur_eos_logit)
+        tl.store(cur_eos_logit_ptr, cur_eos_logit)
     return
 
+
 @torch.no_grad()
-def apply_penalty(Logits, presence_penalty, freqency_penalty, repetition_penalty, p_token_ids, p_token_counts, p_cumsum_seq_len, p_max_len_in_batch):
+def apply_penalty(
+    Logits,
+    presence_penalty,
+    freqency_penalty,
+    repetition_penalty,
+    p_token_ids,
+    p_token_counts,
+    p_cumsum_seq_len,
+    exponential_decay_length_penalties,
+    length_penalty_idx,
+    eos_ids,
+    mask_eos_reqs,
+):
     assert Logits.is_contiguous()
-    BLOCK = triton.next_power_of_2(p_max_len_in_batch)
-    if BLOCK <= 512:
-        BLOCK = 512
-    elif BLOCK <= 1024:
-        BLOCK = 1024
+    BLOCK_P = 1024
     num_warps = 8
-    _fwd_kernel_apply_penalty[(Logits.shape[0], )](
-        Logits, presence_penalty, freqency_penalty, repetition_penalty,
-        p_token_ids, p_token_counts, p_cumsum_seq_len,
-        Logits.stride(0), Logits.stride(1),
+    _fwd_kernel_apply_penalty[(Logits.shape[0],)](
+        Logits,
+        presence_penalty,
+        freqency_penalty,
+        repetition_penalty,
+        p_token_ids,
+        p_token_counts,
+        p_cumsum_seq_len,
+        exponential_decay_length_penalties,
+        length_penalty_idx,
+        eos_ids,
+        mask_eos_reqs,
+        Logits.stride(0),
+        Logits.stride(1),
         num_warps=num_warps,
-        BLOCK_P=BLOCK
+        BLOCK_P=BLOCK_P,
+        EOS_ID_NUM=eos_ids.shape[0],
    )
    return
```
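Two changes here: the per-request token slice is now walked in fixed BLOCK_P=1024 chunks instead of one next_power_of_2 block sized by p_max_len_in_batch, and the EOS handling (exponential decay length penalty plus EOS masking) is fused into the same kernel. For readers reconstructing the math, below is an untested dense PyTorch sketch written for this note, not part of the commit. Since tl.exp2(tl.log2(p) * n) equals p**n for p > 0, penalty_scale reproduces the torch.pow(...) - 1 factor that previously lived in Python (second file below).

```python
# Hedged reference (not from the repo): dense PyTorch equivalent of the fused
# kernel, usable as a correctness check against apply_penalty on small inputs.
import torch

def apply_penalty_reference(logits, presence_penalty, freqency_penalty,
                            repetition_penalty, p_token_ids, p_token_counts,
                            p_cumsum_seq_len, exponential_decay_length_penalties,
                            length_penalty_idx, eos_ids, mask_eos_reqs):
    for b in range(logits.shape[0]):
        s, e = int(p_cumsum_seq_len[b]), int(p_cumsum_seq_len[b + 1])
        ids = p_token_ids[s:e].long()
        counts = p_token_counts[s:e].to(logits.dtype)
        cur = logits[b, ids]
        # repetition penalty: shrink positive logits, amplify negative ones
        cur = torch.where(cur > 0, cur / repetition_penalty[b], cur * repetition_penalty[b])
        # frequency penalty scales with the token's count; presence is flat
        cur = cur - counts * freqency_penalty[b] - presence_penalty[b]
        logits[b, ids] = cur
        # EOS boost: exp2(log2(p) * n) - 1 == p**n - 1 for p > 0
        scale = exponential_decay_length_penalties[b] ** length_penalty_idx[b] - 1
        eos = eos_ids.long()
        logits[b, eos] += logits[b, eos].abs() * scale
        # requests not yet allowed to stop get their EOS logits masked out
        if mask_eos_reqs[b]:
            logits[b, eos] = -10000000.0
    return logits
```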

lightllm/server/router/model_infer/mode_backend/generic_post_process.py

9 additions & 13 deletions
```diff
@@ -18,11 +18,12 @@ def sample(logits, reqs, eos_id: List[int] = [2]):
         p_token_ids,
         p_token_counts,
         p_cumsum_seq_len,
-        p_max_len_in_batch,
         length_penalty_idx,
         mask_eos_reqs,
     ) = _get_post_sample_tensors(reqs)
 
+    eos_ids = torch.tensor(eos_id, dtype=torch.int32, device="cpu", pin_memory=True).cuda(non_blocking=True)
+
     logits = logits.contiguous()
 
     apply_penalty(
```
```diff
@@ -33,13 +34,12 @@ def sample(logits, reqs, eos_id: List[int] = [2]):
         p_token_ids,
         p_token_counts,
         p_cumsum_seq_len,
-        p_max_len_in_batch,
-    )
-    logits[:, eos_id] = logits[:, eos_id] + torch.abs(logits[:, eos_id]) * (
-        torch.pow(exponential_decay_length_penalties, length_penalty_idx).view((-1, 1)) - 1
+        exponential_decay_length_penalties,
+        length_penalty_idx,
+        eos_ids,
+        mask_eos_reqs,
     )
-    if mask_eos_reqs.any():
-        logits[mask_eos_reqs, eos_id] = -1000000.0
+
     logits.div_(temperatures.view((-1, 1)))
     probs = torch.softmax(logits, dim=-1)
```

```diff
@@ -94,7 +94,6 @@ def _get_post_sample_tensors(reqs: List[InferReq]):
     p_seq_len: List[int] = [
         0,
     ]
-    p_max_len_in_batch: int = 0
     length_penalty_idx: List[int] = []
     mask_eos_reqs: List[bool] = []
     for i, req_obj in enumerate(reqs):
```
```diff
@@ -113,11 +112,9 @@ def _get_post_sample_tensors(reqs: List[InferReq]):
         top_ps.append(sample_param.shm_param.top_p)
         top_ks.append(sample_param.shm_param.top_k)
 
-        for token_id, count in id_to_count.items():
-            p_token_ids.append(token_id)
-            p_token_counts.append(count)
+        p_token_ids.extend(list(id_to_count.keys()))
+        p_token_counts.extend(list(id_to_count.values()))
         p_seq_len.append(len(id_to_count))
-        p_max_len_in_batch = max(p_max_len_in_batch, len(id_to_count))
 
     presence_penalties_cpu = torch.tensor(presence_penalties, dtype=torch.float, device="cpu", pin_memory=True)
     frequency_penalties_cpu = torch.tensor(frequency_penalties, dtype=torch.float, device="cpu", pin_memory=True)
```
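Replacing the per-item append loop with two extend() calls is safe because keys() and values() of the same dict iterate in matching insertion order, a Python guarantee as long as the dict is not mutated in between:

```python
# keys() and values() iterate in the same (insertion) order, so the two
# extend() calls keep token ids aligned with their counts.
id_to_count = {11: 2, 3: 1, 7: 4}
assert list(zip(id_to_count.keys(), id_to_count.values())) == list(id_to_count.items())
```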
```diff
@@ -146,7 +143,6 @@ def _get_post_sample_tensors(reqs: List[InferReq]):
         p_token_ids_cpu.cuda(non_blocking=True),
         p_token_counts_cpu.cuda(non_blocking=True),
         p_cumsum_seq_len_cpu.cuda(non_blocking=True),
-        p_max_len_in_batch,
         length_penalty_idx_cpu.cuda(non_blocking=True),
         mask_eos_reqs_cpu.cuda(non_blocking=True),
     )
```
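With p_max_len_in_batch gone from both the tensor-building helper and the call site, the updated apply_penalty takes only the flattened penalty arrays plus the new EOS arguments. A hypothetical smoke test wiring the call together (values arbitrary, matching the two-request layout sketched above; requires a CUDA device; not from the repo):

```python
# Hypothetical smoke test for the new apply_penalty signature.
import torch
from lightllm.common.basemodel.triton_kernel.apply_penalty import apply_penalty

logits = torch.randn(2, 32000, device="cuda", dtype=torch.float32).contiguous()
apply_penalty(
    logits,
    presence_penalty=torch.tensor([0.2, 0.0], device="cuda"),
    freqency_penalty=torch.tensor([0.3, 0.0], device="cuda"),
    repetition_penalty=torch.tensor([1.2, 1.0], device="cuda"),
    p_token_ids=torch.tensor([5, 7, 9, 7, 2], dtype=torch.int32, device="cuda"),
    p_token_counts=torch.tensor([2, 1, 1, 1, 1], dtype=torch.int32, device="cuda"),
    p_cumsum_seq_len=torch.tensor([0, 3, 5], dtype=torch.int32, device="cuda"),
    exponential_decay_length_penalties=torch.tensor([1.05, 1.0], device="cuda"),
    length_penalty_idx=torch.tensor([4, 0], dtype=torch.int32, device="cuda"),
    eos_ids=torch.tensor([2], dtype=torch.int32, device="cuda"),
    mask_eos_reqs=torch.tensor([False, True], device="cuda"),
)
```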
