@@ -3,6 +3,7 @@
 from lightllm.common.basemodel.triton_kernel.apply_penalty import apply_penalty
 from dataclasses import dataclass
 from lightllm.server.router.model_infer.infer_batch import InferReq, g_infer_context
+from lightllm.utils.envs_utils import get_env_start_args


 def sample(logits, reqs, eos_id: List[int] = [2]):
@@ -41,13 +42,32 @@ def sample(logits, reqs, eos_id: List[int] = [2]): |
     logits[mask_eos_reqs, eos_id] = -1000000.0
     logits.div_(temperatures.view((-1, 1)))
     probs = torch.softmax(logits, dim=-1)
-    probs_sort, probs_idx = _top_p_top_k(probs, top_ps, top_ks)
-    sampled_index = torch.multinomial(probs_sort, num_samples=1, replacement=True)

-    batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index)
-    batch_next_token_probs = torch.gather(probs_sort, dim=1, index=sampled_index)
-
-    return batch_next_token_ids.view(-1), batch_next_token_probs.view(-1)
+    if get_env_start_args().sampling_backend == "triton":
+        probs_sort, probs_idx = _top_p_top_k(probs, top_ps, top_ks)
+        sampled_index = torch.multinomial(probs_sort, num_samples=1, replacement=True)
+
+        batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index)
+        batch_next_token_probs = torch.gather(probs_sort, dim=1, index=sampled_index)
+
+        return batch_next_token_ids.view(-1), batch_next_token_probs.view(-1)
+
+    elif get_env_start_args().sampling_backend == "sglang_kernel":
+        from sgl_kernel import top_k_top_p_sampling_from_probs
+
+        batch_next_token_ids = top_k_top_p_sampling_from_probs(
+            probs,
+            top_ks,
+            top_ps,
+            filter_apply_order="joint",
+            check_nan=True,
+        )
+        int64_batch_next_token_ids = torch.empty_like(batch_next_token_ids, dtype=torch.int64)
+        int64_batch_next_token_ids[:] = batch_next_token_ids
+        batch_next_token_probs = torch.gather(probs, dim=1, index=int64_batch_next_token_ids.view(-1, 1))
+        return batch_next_token_ids.view(-1), batch_next_token_probs.view(-1)
+    else:
+        assert False, "dead path"


 def _top_p_top_k(probs: torch.Tensor, top_ps: torch.Tensor, top_ks: torch.Tensor):
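Note on the change: the new branch dispatches on `get_env_start_args().sampling_backend`. The existing "triton" path filters the distribution with `_top_p_top_k` and draws a token with `torch.multinomial`, while the "sglang_kernel" path fuses joint top-k/top-p filtering and sampling in `sgl_kernel.top_k_top_p_sampling_from_probs`; the copy into an int64 tensor before `torch.gather` is presumably needed because the kernel returns token ids in a narrower integer dtype and `gather` requires int64 indices. The sketch below is a minimal plain-PyTorch illustration of joint top-k/top-p sampling under those assumptions; the function name is hypothetical, and it is not the repository's `_top_p_top_k` kernel or the sgl_kernel implementation.

# Minimal sketch (assumptions as noted above): joint top-k/top-p filtering
# followed by multinomial sampling over a (batch, vocab) probability matrix.
import torch

def reference_top_k_top_p_sample(probs: torch.Tensor, top_ks: torch.Tensor, top_ps: torch.Tensor):
    # Sort probabilities descending so rank 0 is each row's most likely token.
    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
    ranks = torch.arange(probs.shape[-1], device=probs.device).expand_as(probs_sort)
    # top-k: zero out everything past the k-th ranked token in each row.
    k_mask = ranks >= top_ks.view(-1, 1)
    # top-p: zero out tokens whose preceding cumulative mass already exceeds
    # top_p (the token that crosses the threshold is kept).
    exclusive_cumsum = probs_sort.cumsum(dim=-1) - probs_sort
    p_mask = exclusive_cumsum > top_ps.view(-1, 1)
    probs_sort = probs_sort.masked_fill(k_mask | p_mask, 0.0)
    # Sample from the filtered (unnormalized) rows and map back to vocab ids;
    # torch.multinomial normalizes non-negative rows internally.
    sampled = torch.multinomial(probs_sort, num_samples=1)
    token_ids = torch.gather(probs_idx, dim=1, index=sampled)
    token_probs = torch.gather(probs_sort, dim=1, index=sampled)
    return token_ids.view(-1), token_probs.view(-1)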