Commit a64e653

Merge branch 'master' of github.com:bigscience-workshop/bigscience
2 parents: 893e075 + 9eec76b

File tree: 1 file changed (+59, −0 lines)

evaluation/generation/generate.py

Lines changed: 59 additions & 0 deletions

@@ -0,0 +1,59 @@
import argparse
import datetime

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", type=str, help="Checkpoint path", required=True)
    parser.add_argument("--parallelize", action="store_true")
    parser.add_argument("--global-step", type=str, default=None)
    parser.add_argument("--generate-max-length", type=int, default=50, help="max generation length")
    parser.add_argument("--greedy", action="store_true")
    parser.add_argument("--top-k", type=int, default=0)
    parser.add_argument("--offload_folder", type=str, help="offload folder for accelerate", default="./offload")

    return parser.parse_args()


def generate_from_text(model, text, tokenizer, max_length=200, greedy=False, top_k=0):
    input_ids = tokenizer.encode(text, return_tensors="pt").to("cuda:0")
    # Budget the total length as prompt length plus the requested new tokens.
    max_length = input_ids.size(-1) + max_length

    greedy_output = model.generate(
        input_ids,
        max_length=max_length,
        do_sample=not greedy,
        top_k=None if greedy else top_k,
    )
    return tokenizer.decode(greedy_output[0], skip_special_tokens=True)


def main():
    args = get_args()
    print("Loading model")

    tokenizer = AutoTokenizer.from_pretrained(args.checkpoint, padding_side="left")

    print("Loaded tokenizer!")
    start = datetime.datetime.now()
    model = AutoModelForCausalLM.from_pretrained(
        args.checkpoint,
        # With --parallelize, shard the model across available devices
        # (and offload to disk) via accelerate.
        device_map="auto" if args.parallelize else None,
        torch_dtype=torch.bfloat16,
        # Intermediate checkpoints are addressed as Hub revisions named "gs<step>".
        revision="gs{}".format(args.global_step) if args.global_step else None,
        offload_folder=args.offload_folder if args.parallelize else None,
    )
    print(f"Loaded model in {datetime.datetime.now() - start}")

    # Accumulate prompt lines until the user hits Ctrl-C, then generate.
    text = ""
    while True:
        try:
            text += input("Enter the paragraph (Enter for new line and Ctrl-c to end the prompt):") + "\n"
        except KeyboardInterrupt:
            output = generate_from_text(model, text, tokenizer, max_length=args.generate_max_length, greedy=args.greedy, top_k=args.top_k)
            print(output)
            text = ""


if __name__ == "__main__":
    main()
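For reference, a minimal self-contained sketch of the decoding switch this script uses (do_sample=not greedy, with top_k only active when sampling), runnable on CPU. The "gpt2" model and the prompt here are illustrative assumptions, not the checkpoint the script targets:

# Sketch of the greedy-vs-sampling switch, using a small public model
# ("gpt2" is an illustrative stand-in, not the script's intended checkpoint).
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

input_ids = tokenizer.encode("The BigScience workshop is", return_tensors="pt")

# greedy=True corresponds to the --greedy flag: do_sample=False and top_k unset.
greedy = False
output = model.generate(
    input_ids,
    max_length=input_ids.size(-1) + 20,
    do_sample=not greedy,
    top_k=None if greedy else 50,
)
print(tokenizer.decode(output[0], skip_special_tokens=True))

On the command line, the same choice is exposed through --greedy and --top-k, while --parallelize enables device_map="auto" with disk offload to --offload_folder.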
