diff --git a/README.md b/README.md
index 039922e..cb70ad0 100644
--- a/README.md
+++ b/README.md
@@ -76,6 +76,9 @@ vllm serve openai/gpt-oss-20b
 
 [Learn more about how to use gpt-oss with vLLM.](https://cookbook.openai.com/articles/gpt-oss/run-vllm)
 
+#### SGLang
+See [GPT OSS Usage](https://docs.sglang.ai/basic_usage/gpt_oss.html) for how to run gpt-oss with SGLang.
+
 #### PyTorch / Triton / Metal
 
 These implementations are largely reference implementations for educational purposes and are not expected to be run in production.
diff --git a/gpt_oss/chat.py b/gpt_oss/chat.py
index 5e40079..d667ba4 100644
--- a/gpt_oss/chat.py
+++ b/gpt_oss/chat.py
@@ -72,7 +72,10 @@ def main(args):
             generator = TorchGenerator(args.checkpoint, device)
         case "vllm":
             from gpt_oss.vllm.token_generator import TokenGenerator as VLLMGenerator
-            generator = VLLMGenerator(args.checkpoint, tensor_parallel_size=2)
+            generator = VLLMGenerator(args.checkpoint, tensor_parallel_size=args.tensor_parallel_size)
+        case "sglang":
+            from gpt_oss.sglang.token_generator import TokenGenerator as SGLangGenerator
+            generator = SGLangGenerator(args.checkpoint, tensor_parallel_size=args.tensor_parallel_size)
         case _:
             raise ValueError(f"Invalid backend: {args.backend}")
 
@@ -351,9 +354,17 @@ async def run_tool():
         "--backend",
         type=str,
         default="triton",
-        choices=["triton", "torch", "vllm"],
+        choices=["triton", "torch", "vllm", "sglang"],
         help="Inference backend",
     )
+    parser.add_argument(
+        "-tp",
+        "--tensor-parallel-size",
+        metavar="TP",
+        type=int,
+        default=1,
+        help="Inference backend tensor parallel size",
+    )
     args = parser.parse_args()
 
     if int(os.environ.get("WORLD_SIZE", 1)) == 1:
diff --git a/gpt_oss/generate.py b/gpt_oss/generate.py
index dfaaa6f..5b56370 100644
--- a/gpt_oss/generate.py
+++ b/gpt_oss/generate.py
@@ -22,7 +22,10 @@ def main(args):
             generator = TritonGenerator(args.checkpoint, context=4096, device=device)
         case "vllm":
             from gpt_oss.vllm.token_generator import TokenGenerator as VLLMGenerator
-            generator = VLLMGenerator(args.checkpoint, tensor_parallel_size=2)
+            generator = VLLMGenerator(args.checkpoint, tensor_parallel_size=args.tensor_parallel_size)
+        case "sglang":
+            from gpt_oss.sglang.token_generator import TokenGenerator as SGLangGenerator
+            generator = SGLangGenerator(args.checkpoint, tensor_parallel_size=args.tensor_parallel_size)
         case _:
             raise ValueError(f"Invalid backend: {args.backend}")
 
@@ -75,9 +78,17 @@ def main(args):
         metavar="BACKEND",
         type=str,
         default="torch",
-        choices=["triton", "torch", "vllm"],
+        choices=["triton", "torch", "vllm", "sglang"],
         help="Inference backend",
     )
+    parser.add_argument(
+        "-tp",
+        "--tensor-parallel-size",
+        metavar="TP",
+        type=int,
+        default=1,
+        help="Inference backend tensor parallel size",
+    )
     args = parser.parse_args()
 
     main(args)
diff --git a/gpt_oss/sglang/token_generator.py b/gpt_oss/sglang/token_generator.py
new file mode 100644
index 0000000..50e841e
--- /dev/null
+++ b/gpt_oss/sglang/token_generator.py
@@ -0,0 +1,106 @@
+import sglang
+
+class TokenGenerator:
+    def __init__(
+        self,
+        model_path: str,
+        tensor_parallel_size: int = 1,
+        mem_fraction_static: float = 0.7,
+    ):
+        # Create the SGLang engine; tokenizer init is skipped because callers pass raw token IDs.
+        self.engine = sglang.Engine(
+            model_path=model_path,
+            skip_tokenizer_init=True,
+            tp_size=tensor_parallel_size,
+            mem_fraction_static=mem_fraction_static,
+        )
+
+    def generate(
+        self,
+        prompt_tokens: list[int],
+        stop_tokens: list[int] | None = None,
+        temperature: float = 1.0,
+        top_p: float = 1.0,
+        max_tokens: int = 0,
+        return_logprobs: bool = False,
+    ):
+        # https://docs.sglang.ai/backend/sampling_params.html
+        sampling_params = {
+            "n": 1,  # number of samples to generate
+            "temperature": temperature,
+            "top_p": top_p,
+            "stop_token_ids": stop_tokens,
+        }
+        if max_tokens > 0:
+            sampling_params["max_new_tokens"] = max_tokens
+        pre_len = 0
+        gen_iter = self.engine.generate(
+            input_ids=prompt_tokens,
+            sampling_params=sampling_params,
+            stream=True,
+            return_logprob=return_logprobs,
+        )
+        for output in gen_iter:
+            token_ids = output["output_ids"]
+            logprobs_list = (
+                output["meta_info"]["output_token_logprobs"]
+                if "output_token_logprobs" in output["meta_info"]
+                else None
+            )
+            if return_logprobs:
+                new_logprobs = logprobs_list[pre_len:]
+            else:
+                new_logprobs = [(None, token_id, None) for token_id in token_ids[pre_len:]]
+            pre_len = len(token_ids)
+            for logprob_val, token_id, _ in new_logprobs:
+                if logprob_val is None:
+                    yield token_id
+                else:
+                    yield (token_id, logprob_val)
+                if stop_tokens is not None and token_id in stop_tokens:
+                    break
+
+    async def async_generate(
+        self,
+        prompt_tokens: list[int],
+        stop_tokens: list[int] | None = None,
+        temperature: float = 1.0,
+        top_p: float = 1.0,
+        max_tokens: int = 0,
+        return_logprobs: bool = False,
+    ):
+        # https://docs.sglang.ai/backend/sampling_params.html
+        sampling_params = {
+            "n": 1,  # number of samples to generate
+            "temperature": temperature,
+            "top_p": top_p,
+            "stop_token_ids": stop_tokens,
+        }
+        if max_tokens > 0:
+            sampling_params["max_new_tokens"] = max_tokens
+        pre_len = 0
+        gen_iter = await self.engine.async_generate(
+            input_ids=prompt_tokens,
+            sampling_params=sampling_params,
+            stream=True,
+            return_logprob=return_logprobs,
+        )
+        async for output in gen_iter:
+            token_ids = output["output_ids"]
+            logprobs_list = (
+                output["meta_info"]["output_token_logprobs"]
+                if "output_token_logprobs" in output["meta_info"]
+                else None
+            )
+            if return_logprobs:
+                new_logprobs = logprobs_list[pre_len:]
+            else:
+                new_logprobs = [(None, token_id, None) for token_id in token_ids[pre_len:]]
+            pre_len = len(token_ids)
+            for logprob_val, token_id, _ in new_logprobs:
+                if logprob_val is None:
+                    yield token_id
+                else:
+                    yield (token_id, logprob_val)
+                if stop_tokens is not None and token_id in stop_tokens:
+                    break
\ No newline at end of file
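For reference, here is a minimal sketch (not part of this diff) of driving the new `TokenGenerator` directly. It assumes SGLang is installed and a gpt-oss checkpoint is available locally; the checkpoint path and prompt token IDs below are placeholders, not values taken from this change.

```python
# Sketch only: exercises the TokenGenerator added in gpt_oss/sglang/token_generator.py.
from gpt_oss.sglang.token_generator import TokenGenerator

generator = TokenGenerator(
    "path/to/gpt-oss-20b",   # placeholder checkpoint path
    tensor_parallel_size=1,
)

prompt_tokens = [0, 1, 2]    # placeholder token IDs; encode a real prompt with the model's tokenizer
for item in generator.generate(prompt_tokens, temperature=0.7, max_tokens=32, return_logprobs=True):
    token_id, logprob = item  # with return_logprobs=True each item is a (token_id, logprob) pair
    print(token_id, logprob)
```

The same backend is reachable through the updated CLIs, e.g. `python -m gpt_oss.generate --backend sglang -tp 2 <checkpoint>` (checkpoint path elided).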