 # limitations under the License.

 import inspect
+import math
 from typing import Any, Callable, Dict, List, Optional, Union

 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from einops import rearrange
 from transformers import (
     CLIPTextModelWithProjection,
     CLIPTokenizer,
     T5EncoderModel,
     T5TokenizerFast,
 )

 from diffusers.image_processor import VaeImageProcessor
 from diffusers.loaders import FromSingleFileMixin, SD3LoraLoaderMixin
 from diffusers.models.autoencoders import AutoencoderKL
+from diffusers.models.embeddings import TimestepEmbedding, Timesteps
+from diffusers.models.normalization import RMSNorm
+from diffusers.models.transformers import SD3Transformer2DModel
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.pipelines.stable_diffusion_3.pipeline_output import StableDiffusion3PipelineOutput
 from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
 from diffusers.utils import (
     USE_PEFT_BACKEND,
     is_torch_xla_available,
     logging,
     replace_example_docstring,
     scale_lora_layers,
     unscale_lora_layers,
 )
 from diffusers.utils.torch_utils import randn_tensor
-from diffusers.pipelines.pipeline_utils import DiffusionPipeline
-from diffusers.pipelines.stable_diffusion_3.pipeline_output import StableDiffusion3PipelineOutput
-
-from diffusers.models.transformers import SD3Transformer2DModel
-from diffusers.models.normalization import RMSNorm
-from einops import rearrange
-import math
-
-from diffusers.models.embeddings import Timesteps, TimestepEmbedding


 if is_torch_xla_available():
@@ -86,10 +84,10 @@ def FeedForward(dim, mult=4):
         nn.Linear(inner_dim, dim, bias=False),
     )

-
+
 def reshape_tensor(x, heads):
     bs, length, width = x.shape
-    #(bs, length, width) --> (bs, length, n_heads, dim_per_head)
+    # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
     x = x.view(bs, length, heads, -1)
     # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
     x = x.transpose(1, 2)
@@ -113,7 +111,6 @@ def __init__(self, *, dim, dim_head=64, heads=8):
         self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
         self.to_out = nn.Linear(inner_dim, dim, bias=False)

-
     def forward(self, x, latents, shift=None, scale=None):
         """
         Args:
@@ -127,23 +124,23 @@ def forward(self, x, latents, shift=None, scale=None):

         if shift is not None and scale is not None:
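             # AdaLN-style modulation: shift/scale (e.g. from a timestep embedding) rescale the latent queries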
             latents = latents * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
-
+
         b, l, _ = latents.shape

         q = self.to_q(latents)
         kv_input = torch.cat((x, latents), dim=-2)
         k, v = self.to_kv(kv_input).chunk(2, dim=-1)
-
+
         q = reshape_tensor(q, self.heads)
         k = reshape_tensor(k, self.heads)
         v = reshape_tensor(v, self.heads)

         # attention
         scale = 1 / math.sqrt(math.sqrt(self.dim_head))
-        weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
+        weight = (q * scale) @ (k * scale).transpose(-2, -1)  # More stable with f16 than dividing afterwards
         weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
         out = weight @ v
-
+

         out = out.permute(0, 2, 1, 3).reshape(b, l, -1)

         return self.to_out(out)
@@ -166,14 +163,14 @@ def __init__(
         timestep_freq_shift=0,
     ):
         super().__init__()
-
+
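         # learnable latent query tokens, initialized with std 1/sqrt(dim)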
         self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
-
+
         self.proj_in = nn.Linear(embedding_dim, dim)

         self.proj_out = nn.Linear(dim, output_dim)
         self.norm_out = nn.LayerNorm(output_dim)
-
+
         self.layers = nn.ModuleList([])
         for _ in range(depth):
             self.layers.append(
@@ -184,7 +181,7 @@ def __init__(
                         # ff
                         FeedForward(dim=dim, mult=ff_mult),
                         # adaLN
-                        nn.Sequential(nn.SiLU(), nn.Linear(dim, 4 * dim, bias=True))
+                        nn.Sequential(nn.SiLU(), nn.Linear(dim, 4 * dim, bias=True)),
                     ]
                 )
             )
@@ -199,12 +196,11 @@ def __init__(
         #     nn.Linear(timestep_out_dim, 6 * timestep_out_dim, bias=True)
         # )

-
     def forward(self, x, timestep, need_temb=False):
         timestep_emb = self.embedding_time(x, timestep)  # bs, dim

         latents = self.latents.repeat(x.size(0), 1, 1)
-
+
         x = self.proj_in(x)
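         # timestep_emb is (bs, dim); broadcasting below adds it to every projected image token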
         x = x + timestep_emb[:, None]

@@ -221,7 +217,7 @@ def forward(self, x, timestep, need_temb=False):
             latents = latents + res

             # latents = ff(latents) + latents
-
+
         latents = self.proj_out(latents)
         latents = self.norm_out(latents)

@@ -230,10 +226,7 @@ def forward(self, x, timestep, need_temb=False):
         else:
             return latents

-
-
     def embedding_time(self, sample, timestep):
-
         # 1. time
         timesteps = timestep
         if not torch.is_tensor(timesteps):
@@ -271,32 +264,29 @@ class AdaLayerNorm(nn.Module):
         num_embeddings (`int`): The size of the embeddings dictionary.
     """

-    def __init__(self, embedding_dim: int, time_embedding_dim=None, mode='normal'):
+    def __init__(self, embedding_dim: int, time_embedding_dim=None, mode="normal"):
         super().__init__()

         self.silu = nn.SiLU()
-        num_params_dict = dict(
-            zero=6,
-            normal=2,
-        )
-        num_params = num_params_dict[mode]
+
+        num_params = 2 if mode == "normal" else 6
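+        # "normal" predicts shift/scale only; "zero" is adaLN-Zero (shift/scale/gate for both MSA and MLP)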
         self.linear = nn.Linear(time_embedding_dim or embedding_dim, num_params * embedding_dim, bias=True)
         self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
         self.mode = mode

     def forward(
         self,
         x,
-        hidden_dtype = None,
-        emb = None,
+        hidden_dtype=None,
+        emb=None,
     ):
         emb = self.linear(self.silu(emb))
-        if self.mode == 'normal':
+        if self.mode == "normal":
             shift_msa, scale_msa = emb.chunk(2, dim=1)
             x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
             return x

-        elif self.mode == 'zero':
+        elif self.mode == "zero":
             shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1)
             x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
             return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
@@ -323,7 +313,6 @@ def __init__(
         self.norm_k = RMSNorm(head_dim, 1e-6)
         self.norm_ip_k = RMSNorm(head_dim, 1e-6)
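         # QK-norm: RMSNorm on queries and keys of both the image and IP branches keeps attention logits stable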

-
     def __call__(
         self,
         attn,
@@ -396,9 +385,8 @@ def __call__(
             if not attn.context_pre_only:
                 encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

-
         # IPadapter
-        ip_hidden_states = emb_dict.get('ip_hidden_states', None)
+        ip_hidden_states = emb_dict.get("ip_hidden_states", None)
         ip_hidden_states = self.get_ip_hidden_states(
             attn,
             img_query,
@@ -407,11 +395,10 @@ def __call__(
             img_value,
             None,
             None,
-            emb_dict['temb'],
+            emb_dict["temb"],
         )
         if ip_hidden_states is not None:
-            hidden_states = hidden_states + ip_hidden_states * emb_dict.get('scale', 1.0)
-
+            hidden_states = hidden_states + ip_hidden_states * emb_dict.get("scale", 1.0)
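+            # residual image-prompt contribution, weighted by the user-supplied IP-Adapter scale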

         # linear proj
         hidden_states = attn.to_out[0](hidden_states)
@@ -423,12 +410,13 @@ def __call__(
         else:
             return hidden_states

-
-    def get_ip_hidden_states(self, attn, query, ip_hidden_states, img_key=None, img_value=None, text_key=None, text_value=None, temb=None):
+    def get_ip_hidden_states(
+        self, attn, query, ip_hidden_states, img_key=None, img_value=None, text_key=None, text_value=None, temb=None
+    ):
         if ip_hidden_states is None:
             return None
-
-        if not hasattr(self, 'to_k_ip') or not hasattr(self, 'to_v_ip'):
+
+        if not hasattr(self, "to_k_ip") or not hasattr(self, "to_v_ip"):
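+            # processors built without IP projection layers silently skip the image prompt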
             return None

         # norm ip input
@@ -439,11 +427,11 @@ def get_ip_hidden_states(self, attn, query, ip_hidden_states, img_key=None, img_
         ip_value = self.to_v_ip(norm_ip_hidden_states)

         # reshape
-        query = rearrange(query, 'b l (h d) -> b h l d', h=attn.heads)
-        img_key = rearrange(img_key, 'b l (h d) -> b h l d', h=attn.heads)
-        img_value = rearrange(img_value, 'b l (h d) -> b h l d', h=attn.heads)
-        ip_key = rearrange(ip_key, 'b l (h d) -> b h l d', h=attn.heads)
-        ip_value = rearrange(ip_value, 'b l (h d) -> b h l d', h=attn.heads)
+        query = rearrange(query, "b l (h d) -> b h l d", h=attn.heads)
+        img_key = rearrange(img_key, "b l (h d) -> b h l d", h=attn.heads)
+        img_value = rearrange(img_value, "b l (h d) -> b h l d", h=attn.heads)
+        ip_key = rearrange(ip_key, "b l (h d) -> b h l d", h=attn.heads)
+        ip_value = rearrange(ip_value, "b l (h d) -> b h l d", h=attn.heads)

         # norm
         query = self.norm_q(query)
@@ -454,9 +442,9 @@ def get_ip_hidden_states(self, attn, query, ip_hidden_states, img_key=None, img_
         key = torch.cat([img_key, ip_key], dim=2)
         value = torch.cat([img_value, ip_value], dim=2)

-        #
+        # attend over the concatenated image and IP-Adapter keys/values in a single pass
         ip_hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
-        ip_hidden_states = rearrange(ip_hidden_states, 'b h l d -> b l (h d)')
+        ip_hidden_states = rearrange(ip_hidden_states, "b h l d -> b l (h d)")
         ip_hidden_states = ip_hidden_states.to(query.dtype)
         return ip_hidden_states

@@ -1049,10 +1037,10 @@ def num_timesteps(self):
     def interrupt(self):
         return self._interrupt

-
     @torch.inference_mode()
     def init_ipadapter(self, ip_adapter_path, image_encoder_path, nb_token, output_dim=2432):
-        from transformers import SiglipVisionModel, SiglipImageProcessor
+        from transformers import SiglipImageProcessor, SiglipVisionModel
+
         state_dict = torch.load(ip_adapter_path, map_location="cpu")
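         # the checkpoint's "ip_adapter" sub-dict is loaded into the attention processors below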

         device, dtype = self.transformer.device, self.transformer.dtype
@@ -1084,14 +1072,13 @@ def init_ipadapter(self, ip_adapter_path, image_encoder_path, nb_token, output_d

         self.image_proj_model = image_proj_model

-
         attn_procs = {}
         transformer = self.transformer
         for idx_name, name in enumerate(transformer.attn_processors.keys()):
             hidden_size = transformer.config.attention_head_dim * transformer.config.num_attention_heads
             ip_hidden_states_dim = transformer.config.attention_head_dim * transformer.config.num_attention_heads
             ip_encoder_hidden_states_dim = transformer.config.caption_projection_dim
-
+
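+            # install an IP-aware joint-attention processor on every transformer block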
             attn_procs[name] = JointIPAttnProcessor(
                 hidden_size=hidden_size,
                 cross_attention_dim=transformer.config.caption_projection_dim,
@@ -1107,10 +1094,8 @@ def init_ipadapter(self, ip_adapter_path, image_encoder_path, nb_token, output_d
         key_name = tmp_ip_layers.load_state_dict(state_dict["ip_adapter"], strict=False)
         print(f"=> loading ip_adapter: {key_name}")

-
     @torch.inference_mode()
     def encode_clip_image_emb(self, clip_image, device, dtype):
-
         # clip
         clip_image_tensor = self.clip_image_processor(images=clip_image, return_tensors="pt").pixel_values
         clip_image_tensor = clip_image_tensor.to(device, dtype=dtype)
@@ -1119,8 +1104,6 @@ def encode_clip_image_emb(self, clip_image, device, dtype):

         return clip_image_embeds

-
-
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
@@ -1150,7 +1133,6 @@ def __call__(
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
         max_sequence_length: int = 256,
-
         # ipa
         clip_image=None,
         ipadapter_scale=1.0,
@@ -1349,18 +1331,16 @@ def __call__(
                 timestep = t.expand(latent_model_input.shape[0])

                 image_prompt_embeds, timestep_emb = self.image_proj_model(
-                    clip_image_embeds,
-                    timestep.to(dtype=latents.dtype),
-                    need_temb=True
+                    clip_image_embeds, timestep.to(dtype=latents.dtype), need_temb=True
                 )

-                joint_attention_kwargs = dict(
-                    emb_dict=dict(
-                        ip_hidden_states=image_prompt_embeds,
-                        temb=timestep_emb,
-                        scale=ipadapter_scale,
-                    )
-                )
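+                # hand the image tokens, timestep embedding, and IP scale to every attention processor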
+                joint_attention_kwargs = {
+                    "emb_dict": {
+                        "ip_hidden_states": image_prompt_embeds,
+                        "temb": timestep_emb,
+                        "scale": ipadapter_scale,
+                    }
+                }

                 noise_pred = self.transformer(
                     hidden_states=latent_model_input,
@@ -1420,4 +1400,4 @@ def __call__(
         if not return_dict:
             return (image,)

-        return StableDiffusion3PipelineOutput(images=image)
+        return StableDiffusion3PipelineOutput(images=image)