@@ -13,6 +13,8 @@ def get_args():
1313 parser .add_argument ("--greedy" , action = "store_true" )
1414 parser .add_argument ("--top-k" , type = int , default = 0 )
1515 parser .add_argument ("--offload_folder" , type = str , help = "offload folder for accelerate" , default = "./offload" )
16+ parser .add_argument ("--max_memory" , type = str , help = "max memory per GPU" , default = "30GB" )
17+ parser .add_argument ("--max_cpu_memory" , type = str , help = "max memory on CPU" , default = "300GB" )
1618
1719 return parser .parse_args ()
1820
@@ -40,9 +42,12 @@ def main():
4042 args .checkpoint ,
4143 device_map = "auto" if args .parallelize else None ,
4244 torch_dtype = torch .bfloat16 ,
43- revision = "gs{}" .format (args .global_step ) if args .global_step else None
44- offload_folder = args .offload_folder is args .parallelize else None ,
45+ revision = "gs{}" .format (args .global_step ) if args .global_step else None ,
46+ max_memory = args .max_memory if args .parallelize else None ,
47+ max_cpu_memory = args .max_cpu_memory if args .parallelize else None ,
48+ offload_folder = args .offload_folder if args .parallelize else None ,
4549 )
50+
4651 print (f"Loaded model in { datetime .datetime .now () - start } " )
4752
4853 text = ''
0 commit comments