
Commit 5d98a56

committed: ran make pre-commit for formatting fixes
1 parent de83267

File tree

1 file changed: +58 -66 lines changed


supporting-blog-content/using-openelm-models/OpenELM/generate_openelm.py

Lines changed: 58 additions & 66 deletions
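The changes below are consistent with what the black formatter produces, which `make pre-commit` targets commonly invoke: single-quoted strings are rewritten with double quotes (except strings that themselves contain double quotes, such as the --hf_access_token help text, which stay single-quoted), calls that fit within the line limit are collapsed onto one line, and a second blank line is inserted before top-level definitions. The hook configuration itself is not part of this diff, so black is an assumption; as a minimal sketch, its library API reproduces the quote rewrite seen throughout:

import black  # assumption: the pre-commit hook runs black (`pip install black`)

# The old single-quoted form of the tokenizer default seen in the first hunk.
src = "tokenizer = 'meta-llama/Llama-2-7b-hf'\n"

# black.format_str applies the same normalization this commit shows.
print(black.format_str(src, mode=black.Mode()))
# -> tokenizer = "meta-llama/Llama-2-7b-hf"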
@@ -13,17 +13,18 @@
 
 from transformers import AutoTokenizer, AutoModelForCausalLM
 
+
 def generate(
     prompt: str,
     model: Union[str, AutoModelForCausalLM],
     hf_access_token: str = None,
-    tokenizer: Union[str, AutoTokenizer] = 'meta-llama/Llama-2-7b-hf',
+    tokenizer: Union[str, AutoTokenizer] = "meta-llama/Llama-2-7b-hf",
     device: Optional[str] = None,
     max_length: int = 1024,
     assistant_model: Optional[Union[str, AutoModelForCausalLM]] = None,
     generate_kwargs: Optional[dict] = None,
 ) -> str:
-    """ Generates output given a prompt.
+    """Generates output given a prompt.
 
     Args:
         prompt: The string prompt.
@@ -53,43 +54,40 @@ def generate(
     if torch.cuda.is_available() and torch.cuda.device_count():
         device = "cuda:0"
         logging.warning(
-            'inference device is not set, using cuda:0, %s',
-            torch.cuda.get_device_name(0)
+            "inference device is not set, using cuda:0, %s",
+            torch.cuda.get_device_name(0),
         )
     else:
-        device = 'cpu'
+        device = "cpu"
         logging.warning(
-            (
-                'No CUDA device detected, using cpu, '
-                'expect slower speeds.'
-            )
+            ("No CUDA device detected, using cpu, " "expect slower speeds.")
         )
 
-    if 'cuda' in device and not torch.cuda.is_available():
-        raise ValueError('CUDA device requested but no CUDA device detected.')
+    if "cuda" in device and not torch.cuda.is_available():
+        raise ValueError("CUDA device requested but no CUDA device detected.")
 
     if not tokenizer:
-        raise ValueError('Tokenizer is not set in the generate function.')
+        raise ValueError("Tokenizer is not set in the generate function.")
 
     if not hf_access_token:
-        raise ValueError((
-            'Hugging face access token needs to be specified. '
-            'Please refer to https://huggingface.co/docs/hub/security-tokens'
-            ' to obtain one.'
+        raise ValueError(
+            (
+                "Hugging face access token needs to be specified. "
+                "Please refer to https://huggingface.co/docs/hub/security-tokens"
+                " to obtain one."
             )
         )
 
     if isinstance(model, str):
         checkpoint_path = model
         model = AutoModelForCausalLM.from_pretrained(
-            checkpoint_path,
-            trust_remote_code=True
+            checkpoint_path, trust_remote_code=True
         )
         model.to(device).eval()
     if isinstance(tokenizer, str):
         tokenizer = AutoTokenizer.from_pretrained(
-            tokenizer,
-            token=hf_access_token,
+            tokenizer,
+            token=hf_access_token,
         )
 
     # Speculative mode
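One rewrite in the hunk above deserves a note: the wrapped warning message was folded onto a single line as two adjacent string literals. Python concatenates adjacent literals at parse time, so the logged message is unchanged:

# Adjacent string literals are concatenated by the parser, so these are equal:
joined = "No CUDA device detected, using cpu, " "expect slower speeds."
assert joined == "No CUDA device detected, using cpu, expect slower speeds."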
@@ -98,17 +96,13 @@ def generate(
         draft_model = assistant_model
         if isinstance(assistant_model, str):
             draft_model = AutoModelForCausalLM.from_pretrained(
-                assistant_model,
-                trust_remote_code=True
+                assistant_model, trust_remote_code=True
             )
         draft_model.to(device).eval()
 
     # Prepare the prompt
     tokenized_prompt = tokenizer(prompt)
-    tokenized_prompt = torch.tensor(
-        tokenized_prompt['input_ids'],
-        device=device
-    )
+    tokenized_prompt = torch.tensor(tokenized_prompt["input_ids"], device=device)
 
     tokenized_prompt = tokenized_prompt.unsqueeze(0)
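The hunk above only loads and prepares the draft model; the model.generate(...) call that consumes it sits outside the diff context. For orientation, a minimal sketch of Hugging Face assisted (speculative) generation, assuming, as the --assistant_model help text suggests, that the script forwards the draft model via the assistant_model argument supported by recent transformers releases (checkpoints and token below are placeholders, not from this commit):

from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholders: any large/small pair sharing a tokenizer works;
# OpenELM reuses the Llama-2 tokenizer, matching the script's default.
tok = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", token="hf_...")
model = AutoModelForCausalLM.from_pretrained("apple/OpenELM-450M", trust_remote_code=True)
draft = AutoModelForCausalLM.from_pretrained("apple/OpenELM-270M", trust_remote_code=True)

inputs = tok("Once upon a time", return_tensors="pt")
# The draft model proposes tokens; the main model verifies them in one pass.
output_ids = model.generate(**inputs, assistant_model=draft, max_length=64)
print(tok.decode(output_ids[0], skip_special_tokens=True))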

@@ -123,10 +117,7 @@
     )
     generation_time = time.time() - stime
 
-    output_text = tokenizer.decode(
-        output_ids[0].tolist(),
-        skip_special_tokens=True
-    )
+    output_text = tokenizer.decode(output_ids[0].tolist(), skip_special_tokens=True)
 
     return output_text, generation_time
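The four hunks so far cover the whole generate function: it takes a model (checkpoint id or instance), a Hugging Face token, an optional tokenizer, device, draft model, and extra generation kwargs, and returns the tuple (output_text, generation_time); note the -> str annotation does not match the tuple actually returned, and this formatting-only commit leaves that as-is. A minimal sketch of calling it directly (checkpoint and token are placeholders):

from generate_openelm import generate

# Placeholders: substitute a real OpenELM checkpoint and your own token.
output_text, generation_time = generate(
    prompt="Once upon a time there was",
    model="apple/OpenELM-450M",
    hf_access_token="hf_...",
    max_length=256,
)
print(output_text)
print(f"took {generation_time:.2f}s")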

@@ -136,83 +127,84 @@ def openelm_generate_parser():
 
     class KwargsParser(argparse.Action):
         """Parser action class to parse kwargs of form key=value"""
+
         def __call__(self, parser, namespace, values, option_string=None):
             setattr(namespace, self.dest, dict())
             for val in values:
-                if '=' not in val:
+                if "=" not in val:
                     raise ValueError(
                         (
-                            'Argument parsing error, kwargs are expected in'
-                            ' the form of key=value.'
+                            "Argument parsing error, kwargs are expected in"
+                            " the form of key=value."
                         )
                     )
-                kwarg_k, kwarg_v = val.split('=')
+                kwarg_k, kwarg_v = val.split("=")
                 try:
                     converted_v = int(kwarg_v)
                 except ValueError:
                     try:
                         converted_v = float(kwarg_v)
                     except ValueError:
-                        converted_v = kwarg_v
+                        converted_v = kwarg_v
                 getattr(namespace, self.dest)[kwarg_k] = converted_v
 
-    parser = argparse.ArgumentParser('OpenELM Generate Module')
+    parser = argparse.ArgumentParser("OpenELM Generate Module")
     parser.add_argument(
-        '--model',
-        dest='model',
-        help='Path to the hf converted model.',
+        "--model",
+        dest="model",
+        help="Path to the hf converted model.",
         required=True,
         type=str,
     )
     parser.add_argument(
-        '--hf_access_token',
-        dest='hf_access_token',
+        "--hf_access_token",
+        dest="hf_access_token",
         help='Hugging face access token, starting with "hf_".',
         type=str,
     )
     parser.add_argument(
-        '--prompt',
-        dest='prompt',
-        help='Prompt for LLM call.',
-        default='',
-        type=str,
+        "--prompt",
+        dest="prompt",
+        help="Prompt for LLM call.",
+        default="",
+        type=str,
     )
     parser.add_argument(
-        '--device',
-        dest='device',
-        help='Device used for inference.',
+        "--device",
+        dest="device",
+        help="Device used for inference.",
         type=str,
     )
     parser.add_argument(
-        '--max_length',
-        dest='max_length',
-        help='Maximum length of tokens.',
+        "--max_length",
+        dest="max_length",
+        help="Maximum length of tokens.",
         default=256,
         type=int,
    )
     parser.add_argument(
-        '--assistant_model',
-        dest='assistant_model',
+        "--assistant_model",
+        dest="assistant_model",
         help=(
             (
-                'If set, this is used as a draft model '
-                'for assisted speculative generation.'
+                "If set, this is used as a draft model "
+                "for assisted speculative generation."
             )
         ),
         type=str,
     )
     parser.add_argument(
-        '--generate_kwargs',
-        dest='generate_kwargs',
-        help='Additional kwargs passed to the HF generate function.',
+        "--generate_kwargs",
+        dest="generate_kwargs",
+        help="Additional kwargs passed to the HF generate function.",
         type=str,
-        nargs='*',
+        nargs="*",
         action=KwargsParser,
     )
     return parser.parse_args()
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     args = openelm_generate_parser()
     prompt = args.prompt
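KwargsParser above turns each key=value token into a typed entry by trying int, then float, then falling back to the raw string. The same cascade in isolation (the helper name is illustrative, not from the file):

def convert_kwarg(value: str):
    # Mirror KwargsParser's int -> float -> str fallback.
    for cast in (int, float):
        try:
            return cast(value)
        except ValueError:
            continue
    return value

print(convert_kwarg("40"), convert_kwarg("0.9"), convert_kwarg("greedy"))
# -> 40 0.9 greedy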

@@ -228,12 +220,12 @@ def __call__(self, parser, namespace, values, option_string=None):
 
     print_txt = (
         f'\r\n{"=" * os.get_terminal_size().columns}\r\n'
-        '\033[1m Prompt + Generated Output\033[0m\r\n'
+        "\033[1m Prompt + Generated Output\033[0m\r\n"
         f'{"-" * os.get_terminal_size().columns}\r\n'
-        f'{output_text}\r\n'
+        f"{output_text}\r\n"
         f'{"-" * os.get_terminal_size().columns}\r\n'
-        '\r\nGeneration took'
-        f'\033[1m\033[92m {round(genertaion_time, 2)} \033[0m'
-        'seconds.\r\n'
+        "\r\nGeneration took"
+        f"\033[1m\033[92m {round(genertaion_time, 2)} \033[0m"
+        "seconds.\r\n"
     )
     print(print_txt)
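(The timing variable is spelled genertaion_time at the call site; this formatting-only commit leaves that spelling as it found it.) Putting the flags together, a hypothetical invocation of the script, with checkpoint, token, and generation kwargs as illustrative placeholders:

python generate_openelm.py \
    --model apple/OpenELM-450M \
    --hf_access_token hf_... \
    --prompt "Once upon a time there was" \
    --max_length 256 \
    --generate_kwargs repetition_penalty=1.2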
