@@ -17,16 +17,16 @@
 
 import torch
 from transformers import (
+    BaseImageProcessor,
     CLIPTextModelWithProjection,
     CLIPTokenizer,
+    PreTrainedModel,
     T5EncoderModel,
     T5TokenizerFast,
-    PreTrainedModel,
-    BaseImageProcessor,
 )
 
-from ...image_processor import VaeImageProcessor, PipelineImageInput
-from ...loaders import FromSingleFileMixin, SD3LoraLoaderMixin, SD3IPAdapterMixin
+from ...image_processor import PipelineImageInput, VaeImageProcessor
+from ...loaders import FromSingleFileMixin, SD3IPAdapterMixin, SD3LoraLoaderMixin
 from ...models.autoencoders import AutoencoderKL
 from ...models.transformers import SD3Transformer2DModel
 from ...schedulers import FlowMatchEulerDiscreteScheduler
@@ -184,7 +184,7 @@ def __init__(
         text_encoder_3: T5EncoderModel,
         tokenizer_3: T5TokenizerFast,
         image_encoder: PreTrainedModel = None,
-        feature_extractor: BaseImageProcessor = None
+        feature_extractor: BaseImageProcessor = None,
     ):
         super().__init__()
 
@@ -199,7 +199,7 @@ def __init__(
             transformer=transformer,
             scheduler=scheduler,
             image_encoder=image_encoder,
-            feature_extractor=feature_extractor
+            feature_extractor=feature_extractor,
         )
         self.vae_scale_factor = (
             2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
@@ -678,7 +678,7 @@ def num_timesteps(self):
     @property
     def interrupt(self):
         return self._interrupt
-    
+
     # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_image
     def encode_image(self, image):
         if not isinstance(image, torch.Tensor):
@@ -687,16 +687,18 @@ def encode_image(self, image):
         image = image.to(device=self.device, dtype=self.dtype)
 
         image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
-        uncond_image_enc_hidden_states = self.image_encoder(torch.zeros_like(image), output_hidden_states=True).hidden_states[-2]
-        
+        uncond_image_enc_hidden_states = self.image_encoder(
+            torch.zeros_like(image), output_hidden_states=True
+        ).hidden_states[-2]
+
         return image_enc_hidden_states, uncond_image_enc_hidden_states
 
     # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.prepare_ip_adapter_image_embeds
     def prepare_ip_adapter_image_embeds(
         self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
     ):
         if ip_adapter_image_embeds is None:
-                single_image_embeds, single_negative_image_embeds = self.encode_image(ip_adapter_image)
+            single_image_embeds, single_negative_image_embeds = self.encode_image(ip_adapter_image)
         else:
             for single_image_embeds in ip_adapter_image_embeds:
                 if do_classifier_free_guidance:
@@ -705,13 +707,13 @@ def prepare_ip_adapter_image_embeds(
                     single_image_embeds = ip_adapter_image_embeds
 
         single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
-        
+
         if do_classifier_free_guidance:
             single_negative_image_embeds = torch.cat([single_negative_image_embeds] * num_images_per_prompt, dim=0)
             single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)
 
         return single_image_embeds.to(device=device)
-        
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
@@ -979,15 +981,12 @@ def __call__(
                         need_temb=True,
                     )
 
-                    image_prompt_embeds = dict(
-                        ip_hidden_states=ip_hidden_states,
-                        temb=temb
-                    )
+                    image_prompt_embeds = {"ip_hidden_states": ip_hidden_states, "temb": temb}
 
                     if self.joint_attention_kwargs is None:
                         self._joint_attention_kwargs = image_prompt_embeds
                     else:
-                        self._joint_attention_kwargs.update(**image_prompt_embeds)                        
+                        self._joint_attention_kwargs.update(**image_prompt_embeds)
 
                 noise_pred = self.transformer(
                     hidden_states=latent_model_input,
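
The `prepare_ip_adapter_image_embeds` hunk above batches the image embeddings for classifier-free guidance: positive and negative (zeroed-image) embeddings are each repeated per prompt, then concatenated with the negatives in front, so a single transformer forward serves both guidance branches. A standalone sketch of that batching, with a hypothetical embedding shape:

```python
import torch

# Hypothetical CLIP penultimate hidden-state shape; not taken from the diff.
embeds = torch.randn(1, 257, 1280)
negative = torch.zeros_like(embeds)  # stands in for encode_image()'s zeroed-image branch
num_images_per_prompt = 2

# Repeat per requested image, then stack negatives first (CFG ordering).
embeds = torch.cat([embeds] * num_images_per_prompt, dim=0)
negative = torch.cat([negative] * num_images_per_prompt, dim=0)
batched = torch.cat([negative, embeds], dim=0)
print(batched.shape)  # (2 * num_images_per_prompt, 257, 1280)
```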
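
Taken together, the diff threads an optional `image_encoder`/`feature_extractor` pair through the pipeline constructor, adds `encode_image` and `prepare_ip_adapter_image_embeds`, and injects the resulting embeddings into the transformer via `joint_attention_kwargs`. A minimal usage sketch of that path follows; it assumes `SD3IPAdapterMixin` exposes `load_ip_adapter` and `set_ip_adapter_scale` analogous to the existing SDXL IP-Adapter mixin, and the repo names and URL are placeholders, not pinned by this diff.

```python
import torch
from diffusers import StableDiffusion3Pipeline
from diffusers.utils import load_image

# Placeholder checkpoint; the diff does not pin any model repo.
pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
).to("cuda")

# Assumed mixin API, mirroring the SDXL IP-Adapter loader: fetch adapter
# weights and set the image-conditioning strength.
pipe.load_ip_adapter("<ip-adapter-repo>")
pipe.set_ip_adapter_scale(0.6)

reference = load_image("<reference-image-url>")

# ip_adapter_image is encoded by encode_image(); under classifier-free
# guidance the zeroed-image embeddings are concatenated in front, as in
# prepare_ip_adapter_image_embeds() above.
image = pipe(
    prompt="a photo of a corgi riding a skateboard",
    ip_adapter_image=reference,
    num_inference_steps=28,
    guidance_scale=7.0,
).images[0]
```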