
from transformers import AutoTokenizer, AutoModelForCausalLM

+
def generate(
    prompt: str,
    model: Union[str, AutoModelForCausalLM],
    hf_access_token: str = None,
-    tokenizer: Union[str, AutoTokenizer] = 'meta-llama/Llama-2-7b-hf',
+    tokenizer: Union[str, AutoTokenizer] = "meta-llama/Llama-2-7b-hf",
    device: Optional[str] = None,
    max_length: int = 1024,
    assistant_model: Optional[Union[str, AutoModelForCausalLM]] = None,
    generate_kwargs: Optional[dict] = None,
) -> str:
-    """ Generates output given a prompt.
+    """Generates output given a prompt.

    Args:
        prompt: The string prompt.
@@ -53,43 +54,40 @@ def generate(
    if torch.cuda.is_available() and torch.cuda.device_count():
        device = "cuda:0"
        logging.warning(
-            'inference device is not set, using cuda:0, %s',
-            torch.cuda.get_device_name(0)
+            "inference device is not set, using cuda:0, %s",
+            torch.cuda.get_device_name(0),
        )
    else:
-        device = 'cpu'
+        device = "cpu"
        logging.warning(
-            (
-                'No CUDA device detected, using cpu, '
-                'expect slower speeds.'
-            )
+            ("No CUDA device detected, using cpu, " "expect slower speeds.")
        )

-    if 'cuda' in device and not torch.cuda.is_available():
-        raise ValueError('CUDA device requested but no CUDA device detected.')
+    if "cuda" in device and not torch.cuda.is_available():
+        raise ValueError("CUDA device requested but no CUDA device detected.")

    if not tokenizer:
-        raise ValueError('Tokenizer is not set in the generate function.')
+        raise ValueError("Tokenizer is not set in the generate function.")

    if not hf_access_token:
-        raise ValueError((
-            'Hugging face access token needs to be specified. '
-            'Please refer to https://huggingface.co/docs/hub/security-tokens'
-            ' to obtain one.'
+        raise ValueError(
+            (
+                "Hugging face access token needs to be specified. "
+                "Please refer to https://huggingface.co/docs/hub/security-tokens"
+                " to obtain one."
            )
        )

    if isinstance(model, str):
        checkpoint_path = model
        model = AutoModelForCausalLM.from_pretrained(
-            checkpoint_path,
-            trust_remote_code=True
+            checkpoint_path, trust_remote_code=True
        )
        model.to(device).eval()
    if isinstance(tokenizer, str):
        tokenizer = AutoTokenizer.from_pretrained(
-            tokenizer,
-            token=hf_access_token,
+            tokenizer,
+            token=hf_access_token,
        )

    # Speculative mode
@@ -98,17 +96,13 @@ def generate(
        draft_model = assistant_model
        if isinstance(assistant_model, str):
            draft_model = AutoModelForCausalLM.from_pretrained(
-                assistant_model,
-                trust_remote_code=True
+                assistant_model, trust_remote_code=True
            )
        draft_model.to(device).eval()

    # Prepare the prompt
    tokenized_prompt = tokenizer(prompt)
-    tokenized_prompt = torch.tensor(
-        tokenized_prompt['input_ids'],
-        device=device
-    )
+    tokenized_prompt = torch.tensor(tokenized_prompt["input_ids"], device=device)

    tokenized_prompt = tokenized_prompt.unsqueeze(0)

@@ -123,10 +117,7 @@ def generate(
    )
    generation_time = time.time() - stime

-    output_text = tokenizer.decode(
-        output_ids[0].tolist(),
-        skip_special_tokens=True
-    )
+    output_text = tokenizer.decode(output_ids[0].tolist(), skip_special_tokens=True)

    return output_text, generation_time

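Note that `generate` returns a `(output_text, generation_time)` tuple, so callers unpack two values. A minimal usage sketch, assuming the file is importable as `generate_openelm` and using an illustrative checkpoint id (neither name is established by this diff):

    # Sketch only: the module name and checkpoint id are assumptions.
    from generate_openelm import generate

    output_text, generation_time = generate(
        prompt="Once upon a time there was",
        model="apple/OpenELM-270M",   # local path or hub id of an HF causal LM
        hf_access_token="hf_...",     # needed to fetch the gated Llama-2 tokenizer
        generate_kwargs={"repetition_penalty": 1.2},
    )
    print(output_text, generation_time)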
@@ -136,83 +127,84 @@ def openelm_generate_parser():

    class KwargsParser(argparse.Action):
        """Parser action class to parse kwargs of form key=value"""
+
        def __call__(self, parser, namespace, values, option_string=None):
            setattr(namespace, self.dest, dict())
            for val in values:
-                if '=' not in val:
+                if "=" not in val:
                    raise ValueError(
                        (
-                            'Argument parsing error, kwargs are expected in'
-                            ' the form of key=value.'
+                            "Argument parsing error, kwargs are expected in"
+                            " the form of key=value."
                        )
                    )
-                kwarg_k, kwarg_v = val.split('=')
+                kwarg_k, kwarg_v = val.split("=")
                try:
                    converted_v = int(kwarg_v)
                except ValueError:
                    try:
                        converted_v = float(kwarg_v)
                    except ValueError:
-                        converted_v = kwarg_v
+                        converted_v = kwarg_v
                getattr(namespace, self.dest)[kwarg_k] = converted_v

-    parser = argparse.ArgumentParser('OpenELM Generate Module')
+    parser = argparse.ArgumentParser("OpenELM Generate Module")
    parser.add_argument(
-        '--model',
-        dest='model',
-        help='Path to the hf converted model.',
+        "--model",
+        dest="model",
+        help="Path to the hf converted model.",
        required=True,
        type=str,
    )
    parser.add_argument(
-        '--hf_access_token',
-        dest='hf_access_token',
+        "--hf_access_token",
+        dest="hf_access_token",
        help='Hugging face access token, starting with "hf_".',
        type=str,
    )
    parser.add_argument(
-        '--prompt',
-        dest='prompt',
-        help='Prompt for LLM call.',
-        default='',
-        type=str,
+        "--prompt",
+        dest="prompt",
+        help="Prompt for LLM call.",
+        default="",
+        type=str,
    )
    parser.add_argument(
-        '--device',
-        dest='device',
-        help='Device used for inference.',
+        "--device",
+        dest="device",
+        help="Device used for inference.",
        type=str,
    )
    parser.add_argument(
-        '--max_length',
-        dest='max_length',
-        help='Maximum length of tokens.',
+        "--max_length",
+        dest="max_length",
+        help="Maximum length of tokens.",
        default=256,
        type=int,
    )
    parser.add_argument(
-        '--assistant_model',
-        dest='assistant_model',
+        "--assistant_model",
+        dest="assistant_model",
        help=(
            (
-                'If set, this is used as a draft model '
-                'for assisted speculative generation.'
+                "If set, this is used as a draft model "
+                "for assisted speculative generation."
            )
        ),
        type=str,
    )
    parser.add_argument(
-        '--generate_kwargs',
-        dest='generate_kwargs',
-        help='Additional kwargs passed to the HF generate function.',
+        "--generate_kwargs",
+        dest="generate_kwargs",
+        help="Additional kwargs passed to the HF generate function.",
        type=str,
-        nargs='*',
+        nargs="*",
        action=KwargsParser,
    )
    return parser.parse_args()


-if __name__ == '__main__':
+if __name__ == "__main__":
    args = openelm_generate_parser()
    prompt = args.prompt

@@ -228,12 +220,12 @@ def __call__(self, parser, namespace, values, option_string=None):

    print_txt = (
        f'\r\n{"=" * os.get_terminal_size().columns}\r\n'
-        '\033[1m Prompt + Generated Output\033[0m\r\n'
+        "\033[1m Prompt + Generated Output\033[0m\r\n"
        f'{"-" * os.get_terminal_size().columns}\r\n'
-        f'{output_text}\r\n'
+        f"{output_text}\r\n"
        f'{"-" * os.get_terminal_size().columns}\r\n'
-        '\r\nGeneration took'
-        f'\033[1m\033[92m {round(genertaion_time, 2)} \033[0m'
-        'seconds.\r\n'
+        "\r\nGeneration took"
+        f"\033[1m\033[92m {round(genertaion_time, 2)} \033[0m"
+        "seconds.\r\n"
    )
    print(print_txt)
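The same flow is available from the command line through `openelm_generate_parser`. A possible invocation, assuming the file is saved as `generate_openelm.py` (the filename is not fixed by this diff; `repetition_penalty=1.2` is one example of a key=value pair that `KwargsParser` converts to a float and forwards to the HF `generate` call):

    # Filename is an assumption; the flags come from openelm_generate_parser().
    python generate_openelm.py \
        --model path/to/hf_converted_model \
        --hf_access_token hf_XXXX \
        --prompt "Once upon a time there was" \
        --max_length 256 \
        --generate_kwargs repetition_penalty=1.2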