 from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
 from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
-from ..attention import Attention, AttentionMixin, FeedForward
+from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
 from ..attention_dispatch import dispatch_attention_fn
 from ..cache_utils import CacheMixin
-from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
+from ..embeddings import (
+    CombinedTimestepGuidanceTextProjEmbeddings,
+    CombinedTimestepTextProjEmbeddings,
+    apply_rotary_emb,
+    get_1d_rotary_pos_embed,
+)
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
@@ -73,7 +78,6 @@ def get_qkv_projections(self, attn, hidden_states, encoder_hidden_states=None):
         """Public method to get projections based on whether we're using fused mode or not."""
         if attn.is_fused and hasattr(attn, "to_qkv"):
             return self._get_fused_projections(attn, hidden_states, encoder_hidden_states)
-
         return self._get_projections(attn, hidden_states, encoder_hidden_states)

     def __call__(
@@ -117,17 +121,10 @@ def __call__(
            value = torch.cat([encoder_value, value], dim=2)

        if image_rotary_emb is not None:
-            from ..embeddings import apply_rotary_emb
-
            query = apply_rotary_emb(query, image_rotary_emb)
            key = apply_rotary_emb(key, image_rotary_emb)

-        hidden_states = dispatch_attention_fn(
-            query,
-            key,
-            value,
-            attn_mask=attention_mask,
-        )
+        hidden_states = dispatch_attention_fn(query, key, value, attn_mask=attention_mask)

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)
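For reference, a minimal sketch of the rotate-pairs RoPE that the `apply_rotary_emb` calls above are expected to perform. The helper name `rope_apply_sketch` and the interleaved cos/sin layout are assumptions for illustration, not the library's exact implementation:

```python
import torch

def rope_apply_sketch(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Hedged sketch of rotate-half RoPE on a [batch, heads, seq, dim] tensor.

    `cos`/`sin` are [seq, dim] tables with each frequency repeated twice,
    matching the real-valued (cos, sin) pair produced by the positional
    embedding; this mirrors what `apply_rotary_emb` is expected to do.
    """
    # Split the head dim into (real, imag) pairs and build the rotated counterpart.
    x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)    # [B, H, S, D/2] each
    x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)  # [B, H, S, D]
    # Broadcast the per-position tables over batch and heads.
    return (x.float() * cos[None, None] + x_rotated.float() * sin[None, None]).to(x.dtype)

# Example shapes: 2 images, 24 heads, 4608 tokens, head_dim 128 (illustrative only).
q = torch.randn(2, 24, 4608, 128, dtype=torch.bfloat16)
cos, sin = torch.ones(4608, 128), torch.zeros(4608, 128)
assert torch.allclose(rope_apply_sketch(q, cos, sin).float(), q.float())  # identity rotation
```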
@@ -242,12 +239,10 @@ def __call__(
            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)

        if image_rotary_emb is not None:
-            from .embeddings import apply_rotary_emb
-
            query = apply_rotary_emb(query, image_rotary_emb)
            key = apply_rotary_emb(key, image_rotary_emb)

-        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+        hidden_states = dispatch_attention_fn(query, key, value, dropout_p=0.0, is_causal=False)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

@@ -292,13 +287,76 @@ def __call__(


@maybe_allow_in_graph
-class FluxAttention(Attention):
+class FluxAttention(torch.nn.Module, AttentionModuleMixin):
    _default_processor_cls = FluxAttnProcessor
    _available_processors = [
        FluxAttnProcessor,
        FluxIPAdapterAttnProcessor,
    ]

+    def __init__(
+        self,
+        query_dim: int,
+        heads: int = 8,
+        dim_head: int = 64,
+        dropout: float = 0.0,
+        bias: bool = False,
+        qk_norm: Optional[str] = None,
+        added_kv_proj_dim: Optional[int] = None,
+        added_proj_bias: Optional[bool] = True,
+        out_bias: bool = True,
+        eps: float = 1e-5,
+        out_dim: int = None,
+        context_pre_only: Optional[bool] = None,
+        pre_only: bool = False,
+        elementwise_affine: bool = True,
+        processor=None,
+    ):
+        super().__init__()
+        assert qk_norm == "rms_norm", "Flux uses RMSNorm"
+
+        self.inner_dim = out_dim if out_dim is not None else dim_head * heads
+        self.query_dim = query_dim
+        self.use_bias = bias
+        self.dropout = dropout
+        self.out_dim = out_dim if out_dim is not None else query_dim
+        self.context_pre_only = context_pre_only
+        self.pre_only = pre_only
+        self.heads = out_dim // dim_head if out_dim is not None else heads
+        self.added_proj_bias = added_proj_bias
+
+        self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+        self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+        self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
+        self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
+        self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
+
+        if not self.pre_only:
+            self.to_out = torch.nn.ModuleList([])
+            self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
+
+        if added_kv_proj_dim is not None:
+            self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps)
+            self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps)
+            self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+            self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+            self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+            self.to_add_out = torch.nn.Linear(self.inner_dim, query_dim, bias=out_bias)
+
+        if processor is None:
+            processor = self._default_processor_cls()
+        self.set_processor(processor)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> torch.Tensor:
+        return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)
+

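A minimal usage sketch of the new `FluxAttention` module, based only on the constructor and `forward` shown above. The dimensions are illustrative assumptions, not Flux's real configuration, and the return values reflect how the block classes below consume the module:

```python
import torch

# Illustrative sizes (assumptions); the real Flux config is much larger.
attn = FluxAttention(
    query_dim=512,
    heads=8,
    dim_head=64,
    qk_norm="rms_norm",      # the constructor asserts this
    added_kv_proj_dim=512,   # enables the text-stream add_{q,k,v}_proj / to_add_out branch
)

image_tokens = torch.randn(1, 16, 512)
text_tokens = torch.randn(1, 4, 512)

# Single-stream call, as in FluxSingleTransformerBlock: one attention output tensor.
img_only = attn(hidden_states=image_tokens)

# Joint call, as in FluxTransformerBlock: the processor returns the image and
# text streams separately.
img_out, txt_out = attn(hidden_states=image_tokens, encoder_hidden_states=text_tokens)
```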
@maybe_allow_in_graph
class FluxSingleTransformerBlock(nn.Module):
@@ -330,20 +388,19 @@ def forward(
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    ) -> torch.Tensor:
-        residual = hidden_states
+        joint_attention_kwargs = joint_attention_kwargs or {}
+
        norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
        mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
-        joint_attention_kwargs = joint_attention_kwargs or {}
        attn_output = self.attn(
            hidden_states=norm_hidden_states,
            image_rotary_emb=image_rotary_emb,
            **joint_attention_kwargs,
        )
+        attn_mlp_hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
+        proj_out = self.proj_out(attn_mlp_hidden_states)
+        hidden_states = hidden_states + gate.unsqueeze(1) * proj_out

-        hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
-        gate = gate.unsqueeze(1)
-        hidden_states = gate * self.proj_out(hidden_states)
-        hidden_states = residual + hidden_states
        if hidden_states.dtype == torch.float16:
            hidden_states = hidden_states.clip(-65504, 65504)
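Shape walk-through of the rewritten single-block update, using hypothetical sizes; `proj_out` below is a stand-in `nn.Linear` for `self.proj_out`, and the `gate` shape follows the `gate.unsqueeze(1)` broadcast above:

```python
import torch

batch, seq, dim, mlp_ratio = 2, 32, 512, 4

hidden_states = torch.randn(batch, seq, dim)
attn_output = torch.randn(batch, seq, dim)                     # from self.attn(...)
mlp_hidden_states = torch.randn(batch, seq, dim * mlp_ratio)   # from self.act_mlp(self.proj_mlp(...))
gate = torch.randn(batch, dim)                                 # from self.norm(hidden_states, emb=temb)
proj_out = torch.nn.Linear(dim + dim * mlp_ratio, dim)         # stand-in for self.proj_out

# Concatenate along the feature dim, project back to `dim`, then apply the gated
# residual directly to the un-normalized input instead of stashing `residual`.
attn_mlp = torch.cat([attn_output, mlp_hidden_states], dim=2)  # [batch, seq, dim * (1 + mlp_ratio)]
hidden_states = hidden_states + gate.unsqueeze(1) * proj_out(attn_mlp)
```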
@@ -386,12 +443,13 @@ def forward(
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
+        joint_attention_kwargs = joint_attention_kwargs or {}

+        norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
        norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
            encoder_hidden_states, emb=temb
        )
-        joint_attention_kwargs = joint_attention_kwargs or {}
+
        # Attention.
        attention_outputs = self.attn(
            hidden_states=norm_hidden_states,
@@ -410,7 +468,7 @@ def forward(
        hidden_states = hidden_states + attn_output

        norm_hidden_states = self.norm2(hidden_states)
-        norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+        norm_hidden_states = norm_hidden_states * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)

        ff_output = self.ff(norm_hidden_states)
        ff_output = gate_mlp.unsqueeze(1) * ff_output
@@ -420,21 +478,54 @@ def forward(
            hidden_states = hidden_states + ip_attn_output

        # Process attention outputs for the `encoder_hidden_states`.
-
        context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
        encoder_hidden_states = encoder_hidden_states + context_attn_output

        norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
-        norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
+        norm_encoder_hidden_states = norm_encoder_hidden_states * (
+            1 + c_scale_mlp.unsqueeze(1)
+        ) + c_shift_mlp.unsqueeze(1)

        context_ff_output = self.ff_context(norm_encoder_hidden_states)
        encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
+
        if encoder_hidden_states.dtype == torch.float16:
            encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)

        return encoder_hidden_states, hidden_states


+class FluxPosEmbed(nn.Module):
+    # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
+    def __init__(self, theta: int, axes_dim: List[int]):
+        super().__init__()
+        self.theta = theta
+        self.axes_dim = axes_dim
+
+    def forward(self, ids: torch.Tensor) -> torch.Tensor:
+        n_axes = ids.shape[-1]
+        cos_out = []
+        sin_out = []
+        pos = ids.float()
+        is_mps = ids.device.type == "mps"
+        is_npu = ids.device.type == "npu"
+        freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
+        for i in range(n_axes):
+            cos, sin = get_1d_rotary_pos_embed(
+                self.axes_dim[i],
+                pos[:, i],
+                theta=self.theta,
+                repeat_interleave_real=True,
+                use_real=True,
+                freqs_dtype=freqs_dtype,
+            )
+            cos_out.append(cos)
+            sin_out.append(sin)
+        freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
+        freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
+        return freqs_cos, freqs_sin
+
+
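A small sketch of how position ids could be fed to the relocated `FluxPosEmbed`. The values `theta=10000` and `axes_dim=[16, 56, 56]` are assumed Flux-style settings (three rotary axes summing to the head dim), and the id layout (offset, row, column) is illustrative:

```python
import torch

pos_embed = FluxPosEmbed(theta=10000, axes_dim=[16, 56, 56])

# One id triplet per latent token: (offset axis, latent row, latent column).
height, width = 32, 32
img_ids = torch.zeros(height * width, 3)
img_ids[:, 1] = torch.arange(height).repeat_interleave(width)
img_ids[:, 2] = torch.arange(width).repeat(height)

freqs_cos, freqs_sin = pos_embed(img_ids)
print(freqs_cos.shape, freqs_sin.shape)  # torch.Size([1024, 128]) for both
```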
class FluxTransformer2DModel(
    ModelMixin,
    ConfigMixin,
@@ -537,10 +628,6 @@ def __init__(

        self.gradient_checkpointing = False

-    # Using inherited methods from AttentionMixin
-
-    # Using inherited methods from AttentionMixin
-
    def forward(
        self,
        hidden_states: torch.Tensor,
@@ -634,11 +721,7 @@ def forward(
        for index_block, block in enumerate(self.transformer_blocks):
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
-                    block,
-                    hidden_states,
-                    encoder_hidden_states,
-                    temb,
-                    image_rotary_emb,
+                    block, hidden_states, encoder_hidden_states, temb, image_rotary_emb
                )

            else:
@@ -665,12 +748,7 @@ def forward(

        for index_block, block in enumerate(self.single_transformer_blocks):
            if torch.is_grad_enabled() and self.gradient_checkpointing:
-                hidden_states = self._gradient_checkpointing_func(
-                    block,
-                    hidden_states,
-                    temb,
-                    image_rotary_emb,
-                )
+                hidden_states = self._gradient_checkpointing_func(block, hidden_states, temb, image_rotary_emb)

            else:
                hidden_states = block(