import argparse
import datetime

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
6+
def get_args():
    """Parse command-line arguments for the interactive generation script.

    Returns:
        argparse.Namespace with: checkpoint (str, required), parallelize (bool),
        global_step (str | None), generate_max_length (int), greedy (bool),
        top_k (int), offload_folder (str).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", type=str, required=True, help="Checkpoint path")
    parser.add_argument("--parallelize", action="store_true")
    parser.add_argument("--global-step", type=str, default=None)
    parser.add_argument("--generate-max-length", type=int, default=50, help="max generation length")
    parser.add_argument("--greedy", action="store_true")
    parser.add_argument("--top-k", type=int, default=0)
    parser.add_argument("--offload_folder", type=str, default="./offload", help="offload folder for accelerate")

    return parser.parse_args()
18+
def generate_from_text(model, text, tokenizer, max_length=200, greedy=False, top_k=0):
    """Generate a continuation of `text` and return prompt + continuation.

    Args:
        model: causal LM whose inputs must live on cuda:0 (device is hard-coded;
            NOTE(review): consider `model.device` if multi-device use is needed).
        text: prompt string to continue.
        tokenizer: tokenizer matching `model`.
        max_length: number of NEW tokens to generate — it is added to the
            prompt length before being passed to `model.generate`.
        greedy: if True use greedy decoding; otherwise sample.
        top_k: top-k sampling cutoff (ignored when `greedy` is True).

    Returns:
        Decoded string (prompt included), special tokens stripped.
    """
    input_ids = tokenizer.encode(text, return_tensors="pt").to("cuda:0")
    # Interpret `max_length` as a budget of new tokens on top of the prompt.
    max_length = input_ids.size(-1) + max_length

    # `input_ids` is already on cuda:0 — the original called .to('cuda:0') a
    # second time here, which was a redundant no-op.
    output_ids = model.generate(
        input_ids,
        max_length=max_length,
        do_sample=not greedy,
        top_k=None if greedy else top_k,
    )
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
30+
def main():
    """Load the model, then interactively read prompts and print continuations.

    Enter adds a new line to the prompt; Ctrl-C ends the prompt and triggers
    generation. Loops forever (exit with Ctrl-D / process kill).
    """
    args = get_args()
    print("Loading model")

    tokenizer = AutoTokenizer.from_pretrained(args.checkpoint, padding_side="left")
    print("Loaded tokenizer!")

    start = datetime.datetime.now()
    model = AutoModelForCausalLM.from_pretrained(
        args.checkpoint,
        device_map="auto" if args.parallelize else None,
        torch_dtype=torch.bfloat16,
        # BUG FIX: the original dropped the comma after this kwarg and wrote
        # `is` instead of `if` on the next line — both were SyntaxErrors.
        revision="gs{}".format(args.global_step) if args.global_step else None,
        # accelerate's offload folder only applies when sharding across devices.
        offload_folder=args.offload_folder if args.parallelize else None,
    )
    print(f"Loaded model in {datetime.datetime.now() - start}")

    text = ''
    while True:
        try:
            text += input('''Enter the paragraph (Enter for new line and Ctrl-c to end the prompt):''') + '\n '
        except KeyboardInterrupt:
            # Ctrl-C finalizes the prompt: generate, print, reset for the next one.
            output = generate_from_text(
                model,
                text,
                tokenizer,
                max_length=args.generate_max_length,
                greedy=args.greedy,
                top_k=args.top_k,
            )
            print(output)
            text = ''


if __name__ == "__main__":
    main()
# (removed web-scrape artifact: "0 commit comments" — not part of the program)