@@ -455,15 +455,15 @@ def __init__(
         self,
         model="bes-dev/stable-diffusion-v1-4-openvino",
         tokenizer="openai/clip-vit-large-patch14",
-        device=["CPU", "CPU", "CPU", "CPU"]):
+        device=["CPU", "CPU", "CPU", "CPU"], model_name="fp16"):
 
         self.core = Core()
         self.core.set_property({'CACHE_DIR': os.path.join(model, 'cache')})
 
         batch_size = 2 if device[1] == device[2] and device[1] == "GPU" else 1
 
         # if 'int8' is in the model name, we are using the unet_int8a16 model, which always uses batch size 1.
-        if "int8" in model:
+        if "int8" in model_name:
             batch_size = 1
 
         self.batch_size = batch_size
@@ -477,21 +477,24 @@ def __init__(
         self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer)
         self.tokenizer.save_pretrained(model)
 
-        print("Loading models... ")
+
 
         with concurrent.futures.ThreadPoolExecutor(max_workers=8) as executor:
             text_future = executor.submit(self.load_model, model, "text_encoder", device[0])
             vae_de_future = executor.submit(self.load_model, model, "vae_decoder", device[3])
             vae_en_future = executor.submit(self.load_model, model, "vae_encoder", device[3])
 
             if self.batch_size == 1:
-                if "int8" not in model:
-                    unet_future = executor.submit(self.load_model, model, "unet_bs1", device[1])
-                    unet_neg_future = executor.submit(self.load_model, model, "unet_bs1", device[2]) if device[1] != device[2] else None
-                else:
-                    unet_future = executor.submit(self.load_model, model, "unet_int8a16", device[1])
+                if "int8a16" in model_name:
+                    print("Loading models ... int8a16")
+                    unet_future = executor.submit(self.load_model, model, "unet_int8a16", device[1])
                     unet_neg_future = executor.submit(self.load_model, model, "unet_int8a16", device[2]) if device[1] != device[2] else None
+                else:
+                    print("Loading models ... fp16 bs1")
+                    unet_future = executor.submit(self.load_model, model, "unet_bs1", device[1])
+                    unet_neg_future = executor.submit(self.load_model, model, "unet_bs1", device[2]) if device[1] != device[2] else None
             else:
+                print("Loading models ... fp16")
                 unet_future = executor.submit(self.load_model, model, "unet", device[1])
                 unet_neg_future = None
 
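Note on usage: after this change the UNet variant is selected through the new model_name argument rather than by matching on the model path. A minimal call-site sketch, assuming the enclosing class is named StableDiffusionEngine (the class name is not visible in these hunks) and that the four entries of device map to the text encoder, UNet, negative-prompt UNet, and VAE, as the executor.submit calls above indicate:

    # Hypothetical call site; StableDiffusionEngine is an assumed class name.
    engine = StableDiffusionEngine(
        model="bes-dev/stable-diffusion-v1-4-openvino",
        tokenizer="openai/clip-vit-large-patch14",
        device=["CPU", "GPU", "GPU", "CPU"],  # text encoder, UNet, negative UNet, VAE
        model_name="int8a16",                 # loads unet_int8a16 and forces batch size 1
    )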