@@ -604,8 +604,7 @@ def get_coreml_pipe(pytorch_pipe,
604
604
"tokenizer" : pytorch_pipe .tokenizer ,
605
605
'tokenizer_2' : pytorch_pipe .tokenizer_2 ,
606
606
"scheduler" : pytorch_pipe .scheduler if scheduler_override is None else scheduler_override ,
607
- "force_zeros_for_empty_prompt" : force_zeros_for_empty_prompt ,
608
- 'xl' : True
607
+ 'xl' : True ,
609
608
}
610
609
611
610
model_packages_to_load = ["text_encoder" , "text_encoder_2" , "unet" , "vae_decoder" ]
@@ -618,6 +617,8 @@ def get_coreml_pipe(pytorch_pipe,
618
617
}
619
618
model_packages_to_load = ["text_encoder" , "unet" , "vae_decoder" ]
620
619
620
+ coreml_pipe_kwargs ["force_zeros_for_empty_prompt" ] = force_zeros_for_empty_prompt
621
+
621
622
if getattr (pytorch_pipe , "safety_checker" , None ) is not None :
622
623
model_packages_to_load .append ("safety_checker" )
623
624
else :
@@ -713,7 +714,7 @@ def main(args):
713
714
714
715
# Get Force Zeros Config if it exists
715
716
force_zeros_for_empty_prompt : bool = False
716
- if 'force_zeros_for_empty_prompt' in pytorch_pipe .config :
717
+ if 'xl' in args.model_version and 'force_zeros_for_empty_prompt' in pytorch_pipe.config:
717
718
force_zeros_for_empty_prompt = pytorch_pipe .config ['force_zeros_for_empty_prompt' ]
718
719
719
720
coreml_pipe = get_coreml_pipe (
0 commit comments