
Commit 171ee4e

sampling_backend support sglang_kernel (#855)
1 parent 7ea2597 commit 171ee4e

File tree

3 files changed: +35 -6 lines changed

  lightllm/server/api_cli.py
  lightllm/server/core/objs/start_args_type.py
  lightllm/server/router/model_infer/mode_backend/generic_post_process.py


lightllm/server/api_cli.py

Lines changed: 8 additions & 0 deletions
@@ -352,4 +352,12 @@ def make_argument_parser() -> argparse.ArgumentParser:
         help="""Path of quantization config. It can be used for mixed quantization.
         Examples can be found in lightllm/common/quantization/configs.""",
     )
+    parser.add_argument(
+        "--sampling_backend",
+        type=str,
+        choices=["triton", "sglang_kernel"],
+        default="triton",
+        help="""Sampling implementation to use. 'triton' uses the torch and triton kernel
+        implementation; 'sglang_kernel' uses the sgl_kernel implementation.""",
+    )
     return parser
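
With this change the backend is chosen at server start. A minimal example launch, assuming the usual lightllm.server.api_server entrypoint and a placeholder model path:

    python -m lightllm.server.api_server \
        --model_dir /path/to/model \
        --sampling_backend sglang_kernel

Omitting the flag (or passing --sampling_backend triton) keeps the previous torch/triton sampling path, since the default is "triton".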

lightllm/server/core/objs/start_args_type.py

Lines changed: 1 addition & 0 deletions
@@ -79,3 +79,4 @@ class StartArgs:
     vit_quant_cfg: Optional[str] = field(default=None)
     enable_flashinfer_prefill: bool = field(default=False)
     enable_flashinfer_decode: bool = field(default=False)
+    sampling_backend: str = field(default="triton", metadata={"choices": ["triton", "sglang_kernel"]})
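
This field is what the inference processes read back at sampling time via get_env_start_args(), as the next file shows. A minimal sketch of that read side; get_env_start_args and the sampling_backend field are the ones used by this commit, while the helper function here is hypothetical:

    from lightllm.utils.envs_utils import get_env_start_args

    def use_sglang_sampling() -> bool:
        # True when the server was started with --sampling_backend sglang_kernel.
        return get_env_start_args().sampling_backend == "sglang_kernel"

Note that the dataclass metadata only mirrors the argparse choices; the value is actually validated when the CLI is parsed.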

lightllm/server/router/model_infer/mode_backend/generic_post_process.py

Lines changed: 26 additions & 6 deletions
@@ -3,6 +3,7 @@
 from lightllm.common.basemodel.triton_kernel.apply_penalty import apply_penalty
 from dataclasses import dataclass
 from lightllm.server.router.model_infer.infer_batch import InferReq, g_infer_context
+from lightllm.utils.envs_utils import get_env_start_args


 def sample(logits, reqs, eos_id: List[int] = [2]):
@@ -41,13 +42,32 @@ def sample(logits, reqs, eos_id: List[int] = [2]):
     logits[mask_eos_reqs, eos_id] = -1000000.0
     logits.div_(temperatures.view((-1, 1)))
     probs = torch.softmax(logits, dim=-1)
-    probs_sort, probs_idx = _top_p_top_k(probs, top_ps, top_ks)
-    sampled_index = torch.multinomial(probs_sort, num_samples=1, replacement=True)

-    batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index)
-    batch_next_token_probs = torch.gather(probs_sort, dim=1, index=sampled_index)
-
-    return batch_next_token_ids.view(-1), batch_next_token_probs.view(-1)
+    if get_env_start_args().sampling_backend == "triton":
+        probs_sort, probs_idx = _top_p_top_k(probs, top_ps, top_ks)
+        sampled_index = torch.multinomial(probs_sort, num_samples=1, replacement=True)
+
+        batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index)
+        batch_next_token_probs = torch.gather(probs_sort, dim=1, index=sampled_index)
+
+        return batch_next_token_ids.view(-1), batch_next_token_probs.view(-1)
+
+    elif get_env_start_args().sampling_backend == "sglang_kernel":
+        from sgl_kernel import top_k_top_p_sampling_from_probs
+
+        batch_next_token_ids = top_k_top_p_sampling_from_probs(
+            probs,
+            top_ks,
+            top_ps,
+            filter_apply_order="joint",
+            check_nan=True,
+        )
+        int64_batch_next_token_ids = torch.empty_like(batch_next_token_ids, dtype=torch.int64)
+        int64_batch_next_token_ids[:] = batch_next_token_ids
+        batch_next_token_probs = torch.gather(probs, dim=1, index=int64_batch_next_token_ids.view(-1, 1))
+        return batch_next_token_ids.view(-1), batch_next_token_probs.view(-1)
+    else:
+        assert False, "dead path"


 def _top_p_top_k(probs: torch.Tensor, top_ps: torch.Tensor, top_ks: torch.Tensor):
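
Two implementation notes. On the sglang_kernel path, the extra int64 copy exists because torch.gather requires int64 indices, while the kernel presumably returns the sampled ids in a narrower integer dtype. On the triton path, the code depends on _top_p_top_k, whose body is outside this diff; for reference, below is a rough pure-PyTorch sketch of the same joint top-k/top-p filtering (a hypothetical top_p_top_k_ref, not lightllm's actual helper, and not performance-tuned):

    import torch

    def top_p_top_k_ref(probs: torch.Tensor, top_ps: torch.Tensor, top_ks: torch.Tensor):
        # probs: (batch, vocab); top_ps, top_ks: per-request thresholds of shape (batch,).
        probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
        cumsum = torch.cumsum(probs_sort, dim=-1)
        # Drop tokens whose preceding cumulative mass already exceeds top_p (always keeps the top-1 token).
        top_p_mask = (cumsum - probs_sort) > top_ps.view(-1, 1)
        # Drop tokens ranked at or beyond each request's top_k.
        ranks = torch.arange(probs.shape[-1], device=probs.device).expand_as(probs_sort)
        top_k_mask = ranks >= top_ks.view(-1, 1)
        probs_sort = probs_sort.masked_fill(top_p_mask | top_k_mask, 0.0)
        # Renormalize so torch.multinomial samples from a proper distribution.
        probs_sort = probs_sort / probs_sort.sum(dim=-1, keepdim=True)
        return probs_sort, probs_idx

    # Usage mirrors the triton branch above:
    # sampled_index = torch.multinomial(probs_sort, num_samples=1, replacement=True)
    # next_ids = torch.gather(probs_idx, dim=1, index=sampled_index)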
