diff --git a/gpt_oss/generate.py b/gpt_oss/generate.py
index dfaaa6f..c075580 100644
--- a/gpt_oss/generate.py
+++ b/gpt_oss/generate.py
@@ -19,10 +19,10 @@ def main(args):
             from gpt_oss.torch.utils import init_distributed
             from gpt_oss.triton.model import TokenGenerator as TritonGenerator
             device = init_distributed()
-            generator = TritonGenerator(args.checkpoint, context=4096, device=device)
+            generator = TritonGenerator(args.checkpoint, context=args.context_length, device=device)
         case "vllm":
             from gpt_oss.vllm.token_generator import TokenGenerator as VLLMGenerator
-            generator = VLLMGenerator(args.checkpoint, tensor_parallel_size=2)
+            generator = VLLMGenerator(args.checkpoint, tensor_parallel_size=args.tensor_parallel_size)
         case _:
             raise ValueError(f"Invalid backend: {args.backend}")
 
@@ -31,9 +31,9 @@ def main(args):
     max_tokens = None if args.limit == 0 else args.limit
     for token, logprob in generator.generate(tokens, stop_tokens=[tokenizer.eot_token], temperature=args.temperature, max_tokens=max_tokens, return_logprobs=True):
         tokens.append(token)
-        decoded_token = tokenizer.decode([token])
+        token_text = tokenizer.decode([token])
         print(
-            f"Generated token: {repr(decoded_token)}, logprob: {logprob}"
+            f"Generated token: {repr(token_text)}, logprob: {logprob}"
         )
 
 
@@ -78,6 +78,18 @@ def main(args):
         choices=["triton", "torch", "vllm"],
         help="Inference backend",
     )
+    parser.add_argument(
+        "--tensor-parallel-size",
+        type=int,
+        default=2,
+        help="Tensor parallel size for vLLM backend",
+    )
+    parser.add_argument(
+        "--context-length",
+        type=int,
+        default=4096,
+        help="Context length for Triton backend",
+    )
     args = parser.parse_args()
 
     main(args)
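
Note: this change turns the previously hard-coded vLLM tensor-parallel size (2) and Triton context window (4096) into CLI flags with the same defaults, so existing invocations behave as before. A rough usage sketch follows; it assumes the script is run as a module, that the checkpoint path is the positional argument backing args.checkpoint, and that the existing backend selector is spelled --backend (none of which is shown in this diff):

    # Hypothetical invocations exercising the new flags; entry point and checkpoint argument are assumptions
    python -m gpt_oss.generate /path/to/checkpoint --backend vllm --tensor-parallel-size 4
    python -m gpt_oss.generate /path/to/checkpoint --backend triton --context-length 8192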