Skip to content

Commit 8b2c770

Browse files
committed
first try gen script
1 parent 6632f23 commit 8b2c770

File tree

1 file changed

+57
-0
lines changed

1 file changed

+57
-0
lines changed

evaluation/generation/generate.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import argparse
2+
import json
3+
import datetime
4+
5+
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
6+
7+
def get_args():
8+
parser = argparse.ArgumentParser()
9+
parser.add_argument("--checkpoint", type=str, help="Checkpoint path", required=True)
10+
parser.add_argument("--parallelize", action="store_true")
11+
parser.add_argument("--global-step", type=str, default=None)
12+
parser.add_argument("--generate-max-length", type=int, default=50, help="max generation length")
13+
parser.add_argument("--greedy", action="store_true")
14+
parser.add_argument("--top-k", type=int, default=0)
15+
16+
return parser.parse_args()
17+
18+
def generate_from_text(model, text, tokenizer, max_length=200, greedy=False, top_k=0):
19+
input_ids = tokenizer.encode(text, return_tensors='pt').to("cuda:0")
20+
max_length = input_ids.size(-1) + max_length
21+
22+
greedy_output = model.generate(
23+
input_ids.to('cuda:0'),
24+
max_length=max_length,
25+
do_sample=not greedy,
26+
top_k=None if greedy else top_k,
27+
)
28+
return {
29+
"inputs": text,
30+
"outputs": tokenizer.decode(greedy_output, skip_special_tokens=True)
31+
}
32+
33+
def main(args):
34+
print(f"Loading model", flush=True)
35+
36+
tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom", padding_side="left")
37+
38+
print("Loaded tokenizer !")
39+
start = datetime.datetime.now()
40+
model = AutoModelForCausalLM.from_pretrained(
41+
args.checkpoint,
42+
device_map="auto" if args.parallelize else None,
43+
torch_dtype=torch.bfloat16,
44+
revision="gs{}".format(args.global_step) if args.global_step else None
45+
)
46+
model.eval()
47+
print(f"Loaded model in {datetime.datetime.now() - start}")
48+
49+
while True:
50+
text = ''
51+
while True:
52+
dummy = input('''Enter the paragraph :''')+'\n'
53+
if dummy=='\n':
54+
break
55+
text += dummy
56+
output = generate_from_text(model, text, tokenizer, max_length=args.generate_max_length, greedy=args.greedy, top_k=args.top_k)
57+
print(json.dumps(output, indent=2))

0 commit comments

Comments
 (0)