Skip to content

Commit 63ba5bd

Browse files
Authored: add greedy_sample (#1019)
1 parent f5f54fd commit 63ba5bd

File tree

2 files changed

+15
-3
lines changed

2 files changed

+15
-3
lines changed

lightllm/server/api_cli.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -423,7 +423,7 @@ def make_argument_parser() -> argparse.ArgumentParser:
423423
"--sampling_backend",
424424
type=str,
425425
choices=["triton", "sglang_kernel"],
426-
default="sglang_kernel",
426+
default="triton",
427427
help="""sampling used impl. 'triton' is use torch and triton kernel,
428428
sglang_kernel use sglang_kernel impl""",
429429
)

lightllm/server/router/model_infer/mode_backend/generic_post_process.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ def sample(logits: torch.Tensor, reqs: List[InferReq], eos_id: List[int] = [2]):
1414
b_top_ks,
1515
b_length_penalty_param,
1616
b_mask_eos_reqs,
17+
is_all_greedy,
1718
) = _get_post_sample_tensors(reqs)
1819
eos_ids = torch.tensor(eos_id, dtype=torch.int32, device="cpu", pin_memory=True).cuda(non_blocking=True)
1920

@@ -61,7 +62,12 @@ def sample(logits: torch.Tensor, reqs: List[InferReq], eos_id: List[int] = [2]):
6162
logits.div_(b_temperatures.view((-1, 1)))
6263
probs = torch.softmax(logits, dim=-1)
6364

64-
if get_env_start_args().sampling_backend == "triton":
65+
if is_all_greedy:
66+
batch_next_token_ids = torch.argmax(logits, -1)
67+
batch_next_token_probs = torch.gather(probs, dim=1, index=batch_next_token_ids.view(-1, 1))
68+
return batch_next_token_ids.view(-1), torch.log(batch_next_token_probs).view(-1)
69+
70+
elif get_env_start_args().sampling_backend == "triton":
6571
probs_sort, probs_idx = _top_p_top_k(probs, b_top_ps, b_top_ks)
6672
sampled_index = torch.multinomial(probs_sort, num_samples=1, replacement=True)
6773
next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index)
@@ -104,6 +110,8 @@ def _get_post_sample_tensors(reqs: List[InferReq]):
104110
top_ks: List[int] = []
105111
length_penalty_param: List[int] = []
106112
mask_eos_reqs: List[bool] = []
113+
is_all_greedy = True
114+
107115
for i, req_obj in enumerate(reqs):
108116
sample_param = req_obj.sampling_param
109117
shm_param = sample_param.shm_param
@@ -114,7 +122,10 @@ def _get_post_sample_tensors(reqs: List[InferReq]):
114122

115123
temperatures.append(shm_param.temperature)
116124
top_ps.append(shm_param.top_p)
117-
top_ks.append(shm_param.top_k)
125+
top_k_val = shm_param.top_k
126+
top_ks.append(top_k_val)
127+
if top_k_val > 1:
128+
is_all_greedy = False
118129
req_idxes.append(req_obj.req_idx)
119130

120131
req_idxes_cpu = torch.tensor(req_idxes, dtype=torch.int32, device="cpu", pin_memory=True)
@@ -131,4 +142,5 @@ def _get_post_sample_tensors(reqs: List[InferReq]):
131142
top_ks_cpu.cuda(non_blocking=True),
132143
length_penalty_param_cpu.cuda(non_blocking=True),
133144
mask_eos_reqs_cpu.cuda(non_blocking=True),
145+
is_all_greedy,
134146
)

0 commit comments

Comments (0)