@@ -19,10 +19,10 @@ def main(args):
             from gpt_oss.torch.utils import init_distributed
             from gpt_oss.triton.model import TokenGenerator as TritonGenerator
             device = init_distributed()
-            generator = TritonGenerator(args.checkpoint, context=4096, device=device)
+            generator = TritonGenerator(args.checkpoint, context=args.context_length, device=device)
         case "vllm":
             from gpt_oss.vllm.token_generator import TokenGenerator as VLLMGenerator
-            generator = VLLMGenerator(args.checkpoint, tensor_parallel_size=2)
+            generator = VLLMGenerator(args.checkpoint, tensor_parallel_size=args.tensor_parallel_size)
         case _:
             raise ValueError(f"Invalid backend: {args.backend}")

@@ -31,9 +31,9 @@ def main(args):
     max_tokens = None if args.limit == 0 else args.limit
     for token, logprob in generator.generate(tokens, stop_tokens=[tokenizer.eot_token], temperature=args.temperature, max_tokens=max_tokens, return_logprobs=True):
         tokens.append(token)
-        decoded_token = tokenizer.decode([token])
+        token_text = tokenizer.decode([token])
         print(
-            f"Generated token: {repr(decoded_token)}, logprob: {logprob}"
+            f"Generated token: {repr(token_text)}, logprob: {logprob}"
         )


@@ -78,6 +78,18 @@ def main(args):
         choices=["triton", "torch", "vllm"],
         help="Inference backend",
     )
+    parser.add_argument(
+        "--tensor-parallel-size",
+        type=int,
+        default=2,
+        help="Tensor parallel size for vLLM backend",
+    )
+    parser.add_argument(
+        "--context-length",
+        type=int,
+        default=4096,
+        help="Context length for Triton backend",
+    )
     args = parser.parse_args()

     main(args)
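
Note on the new flags: argparse converts the dashes in --tensor-parallel-size and --context-length to underscores on the parsed namespace, which is why the code above reads args.tensor_parallel_size and args.context_length. A minimal, self-contained sketch of that mapping (not part of this PR; the checkpoint path and override value below are hypothetical):

import argparse

# Mirror only the flags this PR adds, plus the existing positional
# checkpoint argument, to show how they surface on the parsed namespace.
parser = argparse.ArgumentParser()
parser.add_argument("checkpoint")
parser.add_argument("--tensor-parallel-size", type=int, default=2)
parser.add_argument("--context-length", type=int, default=4096)

# Hypothetical invocation; any checkpoint path would do here.
args = parser.parse_args(["weights/model", "--context-length", "8192"])

print(args.context_length)        # 8192 (overridden on the command line)
print(args.tensor_parallel_size)  # 2 (default, since the flag was omitted)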