@@ -275,32 +275,34 @@ def load_transformer(vae, args):
275
275
device = torch .device ('cuda' if torch .cuda .is_available () else 'cpu' )
276
276
model_path = args .model_path
277
277
if args .checkpoint_type == 'torch' :
278
- assert ('ar-' in model_path ) or ('slim-' in model_path )
279
- # copy large model to local, save slim to local, and copy slim to nas, and load local slim model
278
+ # copy large model to local; save slim to local; and copy slim to nas; load local slim model
280
279
if osp .exists (args .cache_dir ):
281
280
local_model_path = osp .join (args .cache_dir , 'tmp' , model_path .replace ('/' , '_' ))
282
281
else :
283
282
local_model_path = model_path
284
- slim_model_path = model_path .replace ('ar-' , 'slim-' )
285
- local_slim_model_path = local_model_path .replace ('ar-' , 'slim-' )
286
- os .makedirs (osp .dirname (local_slim_model_path ), exist_ok = True )
287
- print (f'model_path: { model_path } , slim_model_path: { slim_model_path } ' )
288
- print (f'local_model_path: { local_model_path } , local_slim_model_path: { local_slim_model_path } ' )
289
- if not osp .exists (local_slim_model_path ):
290
- if osp .exists (slim_model_path ):
291
- print (f'copy { slim_model_path } to { local_slim_model_path } ' )
292
- shutil .copyfile (slim_model_path , local_slim_model_path )
293
- else :
294
- if not osp .exists (local_model_path ):
295
- print (f'copy { model_path } to { local_model_path } ' )
296
- shutil .copyfile (model_path , local_model_path )
297
- save_slim_model (local_model_path , save_file = local_slim_model_path , device = device )
298
- print (f'copy { local_slim_model_path } to { slim_model_path } ' )
299
- if not osp .exists (slim_model_path ):
300
- shutil .copyfile (local_slim_model_path , slim_model_path )
301
- os .remove (local_model_path )
302
- os .remove (model_path )
303
- slim_model_path = local_slim_model_path
283
+ if args .enable_model_cache :
284
+ slim_model_path = model_path .replace ('ar-' , 'slim-' )
285
+ local_slim_model_path = local_model_path .replace ('ar-' , 'slim-' )
286
+ os .makedirs (osp .dirname (local_slim_model_path ), exist_ok = True )
287
+ print (f'model_path: { model_path } , slim_model_path: { slim_model_path } ' )
288
+ print (f'local_model_path: { local_model_path } , local_slim_model_path: { local_slim_model_path } ' )
289
+ if not osp .exists (local_slim_model_path ):
290
+ if osp .exists (slim_model_path ):
291
+ print (f'copy { slim_model_path } to { local_slim_model_path } ' )
292
+ shutil .copyfile (slim_model_path , local_slim_model_path )
293
+ else :
294
+ if not osp .exists (local_model_path ):
295
+ print (f'copy { model_path } to { local_model_path } ' )
296
+ shutil .copyfile (model_path , local_model_path )
297
+ save_slim_model (local_model_path , save_file = local_slim_model_path , device = device )
298
+ print (f'copy { local_slim_model_path } to { slim_model_path } ' )
299
+ if not osp .exists (slim_model_path ):
300
+ shutil .copyfile (local_slim_model_path , slim_model_path )
301
+ os .remove (local_model_path )
302
+ os .remove (model_path )
303
+ slim_model_path = local_slim_model_path
304
+ else :
305
+ slim_model_path = model_path
304
306
print (f'load checkpoint from { slim_model_path } ' )
305
307
306
308
if args .model_type == 'infinity_2b' :
@@ -358,9 +360,11 @@ def add_common_arguments(parser):
358
360
parser .add_argument ('--use_flex_attn' , type = int , default = 0 , choices = [0 ,1 ])
359
361
parser .add_argument ('--enable_positive_prompt' , type = int , default = 0 , choices = [0 ,1 ])
360
362
parser .add_argument ('--cache_dir' , type = str , default = '/dev/shm' )
363
+ parser .add_argument ('--enable_model_cache' , type = int , default = 0 , choices = [0 ,1 ])
361
364
parser .add_argument ('--checkpoint_type' , type = str , default = 'torch' )
362
365
parser .add_argument ('--seed' , type = int , default = 0 )
363
366
parser .add_argument ('--bf16' , type = int , default = 1 , choices = [0 ,1 ])
367
+
364
368
365
369
366
370
if __name__ == '__main__' :
0 commit comments