@@ -82,13 +82,26 @@ def __call__(
         self,
         attn: "FluxAttention",
         hidden_states: torch.Tensor,
+        other_hidden_states: torch.Tensor,
         encoder_hidden_states: torch.Tensor = None,
         attention_mask: Optional[torch.Tensor] = None,
         image_rotary_emb: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
-            attn, hidden_states, encoder_hidden_states
-        )
+        if other_hidden_states is not None:
+            # Query (and the text-stream projections) come from this stream.
+            query, _, _, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
+                attn, hidden_states, encoder_hidden_states
+            )
+            # Key/value are projected from the other stream so the two streams cross-attend.
+            _, key, value, _, _, _ = _get_qkv_projections(
+                attn, other_hidden_states, encoder_hidden_states
+            )
+        else:
+            query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
+                attn, hidden_states, encoder_hidden_states
+            )
 
         query = query.unflatten(-1, (attn.heads, -1))
         key = key.unflatten(-1, (attn.heads, -1))
@@ -176,6 +189,7 @@ def __call__(
         self,
         attn: "FluxAttention",
         hidden_states: torch.Tensor,
+        other_hidden_states: torch.Tensor,
         encoder_hidden_states: torch.Tensor = None,
         attention_mask: Optional[torch.Tensor] = None,
         image_rotary_emb: Optional[torch.Tensor] = None,
@@ -184,9 +198,19 @@ def __call__(
     ) -> torch.Tensor:
         batch_size = hidden_states.shape[0]
 
-        query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
-            attn, hidden_states, encoder_hidden_states
-        )
+        if other_hidden_states is not None:
+            # Query (and the text-stream projections) come from this stream.
+            query, _, _, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
+                attn, hidden_states, encoder_hidden_states
+            )
+            # Key/value are projected from the other stream so the two streams cross-attend.
+            _, key, value, _, _, _ = _get_qkv_projections(
+                attn, other_hidden_states, encoder_hidden_states
+            )
+        else:
+            query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
+                attn, hidden_states, encoder_hidden_states
+            )
 
         query = query.unflatten(-1, (attn.heads, -1))
         key = key.unflatten(-1, (attn.heads, -1))
@@ -326,6 +350,7 @@ def __init__(
     def forward(
         self,
         hidden_states: torch.Tensor,
+        other_hidden_states: Optional[torch.Tensor] = None,
         encoder_hidden_states: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         image_rotary_emb: Optional[torch.Tensor] = None,
@@ -339,7 +364,7 @@ def forward(
339364 f"joint_attention_kwargs { unused_kwargs } are not expected by { self .processor .__class__ .__name__ } and will be ignored."
340365 )
341366 kwargs = {k : w for k , w in kwargs .items () if k in attn_parameters }
342- return self .processor (self , hidden_states , encoder_hidden_states , attention_mask , image_rotary_emb , ** kwargs )
367+ return self .processor (self , hidden_states , other_hidden_states , encoder_hidden_states , attention_mask , image_rotary_emb , ** kwargs )
343368
344369
345370@maybe_allow_in_graph
@@ -367,8 +392,9 @@ def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int,
     def forward(
         self,
         hidden_states: torch.Tensor,
-        encoder_hidden_states: torch.Tensor,
-        temb: torch.Tensor,
+        other_hidden_states: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        temb: Optional[torch.Tensor] = None,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -381,6 +407,7 @@ def forward(
         joint_attention_kwargs = joint_attention_kwargs or {}
         attn_output = self.attn(
             hidden_states=norm_hidden_states,
+            other_hidden_states=other_hidden_states,
             image_rotary_emb=image_rotary_emb,
             **joint_attention_kwargs,
         )
@@ -427,6 +454,7 @@ def __init__(
     def forward(
         self,
         hidden_states: torch.Tensor,
+        other_hidden_states: torch.Tensor,
         encoder_hidden_states: torch.Tensor,
         temb: torch.Tensor,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
@@ -442,6 +470,7 @@ def forward(
         # Attention.
         attention_outputs = self.attn(
             hidden_states=norm_hidden_states,
+            other_hidden_states=other_hidden_states,
             encoder_hidden_states=norm_encoder_hidden_states,
             image_rotary_emb=image_rotary_emb,
             **joint_attention_kwargs,
@@ -521,36 +550,6 @@ class FluxTransformer2DModel(
     CacheMixin,
     AttentionMixin,
 ):
-    """
-    The Transformer model introduced in Flux.
-
-    Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
-
-    Args:
-        patch_size (`int`, defaults to `1`):
-            Patch size to turn the input data into small patches.
-        in_channels (`int`, defaults to `64`):
-            The number of channels in the input.
-        out_channels (`int`, *optional*, defaults to `None`):
-            The number of channels in the output. If not specified, it defaults to `in_channels`.
-        num_layers (`int`, defaults to `19`):
-            The number of layers of dual stream DiT blocks to use.
-        num_single_layers (`int`, defaults to `38`):
-            The number of layers of single stream DiT blocks to use.
-        attention_head_dim (`int`, defaults to `128`):
-            The number of dimensions to use for each attention head.
-        num_attention_heads (`int`, defaults to `24`):
-            The number of attention heads to use.
-        joint_attention_dim (`int`, defaults to `4096`):
-            The number of dimensions to use for the joint attention (embedding/channel dimension of
-            `encoder_hidden_states`).
-        pooled_projection_dim (`int`, defaults to `768`):
-            The number of dimensions to use for the pooled projection.
-        guidance_embeds (`bool`, defaults to `False`):
-            Whether to use guidance embeddings for guidance-distilled variant of the model.
-        axes_dims_rope (`Tuple[int]`, defaults to `(16, 56, 56)`):
-            The dimensions to use for the rotary positional embeddings.
-    """
 
     _supports_gradient_checkpointing = True
     _no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
@@ -571,10 +570,12 @@ def __init__(
         pooled_projection_dim: int = 768,
         guidance_embeds: bool = False,
         axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
+        use_2nd_guider: bool = True,
     ):
         super().__init__()
         self.out_channels = out_channels or in_channels
         self.inner_dim = num_attention_heads * attention_head_dim
+        self.use_2nd_guider = use_2nd_guider
 
         self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
 
@@ -599,6 +600,23 @@ def __init__(
             ]
         )
 
+        if use_2nd_guider:
+            # Second ("guider") stream: a parallel copy of the dual-stream blocks.
+            self.transformer_blocks2 = nn.ModuleList(
+                [
+                    FluxTransformerBlock(
+                        dim=self.inner_dim,
+                        num_attention_heads=num_attention_heads,
+                        attention_head_dim=attention_head_dim,
+                    )
+                    for _ in range(num_layers)
+                ]
+            )
+        else:
+            self.transformer_blocks2 = [None] * len(self.transformer_blocks)
+
         self.single_transformer_blocks = nn.ModuleList(
             [
                 FluxSingleTransformerBlock(
@@ -610,6 +628,24 @@ def __init__(
             ]
         )
 
+        if use_2nd_guider:
+            # Parallel copy of the single-stream blocks for the second ("guider") stream.
+            self.single_transformer_blocks2 = nn.ModuleList(
+                [
+                    FluxSingleTransformerBlock(
+                        dim=self.inner_dim,
+                        num_attention_heads=num_attention_heads,
+                        attention_head_dim=attention_head_dim,
+                    )
+                    for _ in range(num_single_layers)
+                ]
+            )
+        else:
+            self.single_transformer_blocks2 = [None] * len(self.single_transformer_blocks)
+
         self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
         self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
 
@@ -618,6 +654,7 @@ def __init__(
     def forward(
         self,
         hidden_states: torch.Tensor,
+        other_hidden_states: torch.Tensor = None,
         encoder_hidden_states: torch.Tensor = None,
         pooled_projections: torch.Tensor = None,
         timestep: torch.LongTensor = None,
@@ -672,6 +709,10 @@ def forward(
             )
 
         hidden_states = self.x_embedder(hidden_states)
+
+        if other_hidden_states is not None:
+            # Embed the second ("guider") stream with the same patch embedder.
+            other_hidden_states = self.x_embedder(other_hidden_states)
 
         timestep = timestep.to(hidden_states.dtype) * 1000
         if guidance is not None:
@@ -705,26 +746,49 @@ def forward(
             ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
             joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})
 
-        for index_block, block in enumerate(self.transformer_blocks):
+        for index_block, (block, block2) in enumerate(zip(self.transformer_blocks, self.transformer_blocks2)):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
                 encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
                     block,
                     hidden_states,
+                    other_hidden_states,
                     encoder_hidden_states,
                     temb,
                     image_rotary_emb,
                     joint_attention_kwargs,
                 )
+                if other_hidden_states is not None:
+                    # Update the second ("guider") stream with its own block, cross-attending to the main stream.
+                    encoder_hidden_states, other_hidden_states = self._gradient_checkpointing_func(
+                        block2,
+                        other_hidden_states,
+                        hidden_states,
+                        encoder_hidden_states,
+                        temb,
+                        image_rotary_emb,
+                        joint_attention_kwargs,
+                    )
 
             else:
                 encoder_hidden_states, hidden_states = block(
                     hidden_states=hidden_states,
+                    other_hidden_states=other_hidden_states,
                     encoder_hidden_states=encoder_hidden_states,
                     temb=temb,
                     image_rotary_emb=image_rotary_emb,
                     joint_attention_kwargs=joint_attention_kwargs,
                 )
 
+                if other_hidden_states is not None:
+                    encoder_hidden_states, other_hidden_states = block2(
+                        hidden_states=other_hidden_states,
+                        other_hidden_states=hidden_states,
+                        encoder_hidden_states=encoder_hidden_states,
+                        temb=temb,
+                        image_rotary_emb=image_rotary_emb,
+                        joint_attention_kwargs=joint_attention_kwargs,
+                    )
+
             # controlnet residual
             if controlnet_block_samples is not None:
                 interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
@@ -737,8 +801,9 @@ def forward(
                 else:
                     hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
 
-        for index_block, block in enumerate(self.single_transformer_blocks):
+        for index_block, (block, block2) in enumerate(zip(self.single_transformer_blocks, self.single_transformer_blocks2)):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
                 encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
                     block,
                     hidden_states,
+                    other_hidden_states,  # keep positional args aligned with the updated block signature
@@ -748,15 +813,39 @@ def forward(
                     joint_attention_kwargs,
                 )
 
+                if other_hidden_states is not None:
+                    encoder_hidden_states, other_hidden_states = self._gradient_checkpointing_func(
+                        block2,
+                        other_hidden_states,
+                        hidden_states,
+                        encoder_hidden_states,
+                        temb,
+                        image_rotary_emb,
+                        joint_attention_kwargs,
+                    )
+
             else:
                 encoder_hidden_states, hidden_states = block(
                     hidden_states=hidden_states,
+                    other_hidden_states=other_hidden_states,
                     encoder_hidden_states=encoder_hidden_states,
                     temb=temb,
                     image_rotary_emb=image_rotary_emb,
                     joint_attention_kwargs=joint_attention_kwargs,
                 )
 
+                if other_hidden_states is not None:
+                    encoder_hidden_states, other_hidden_states = block2(
+                        hidden_states=other_hidden_states,
+                        other_hidden_states=hidden_states,
+                        encoder_hidden_states=encoder_hidden_states,
+                        temb=temb,
+                        image_rotary_emb=image_rotary_emb,
+                        joint_attention_kwargs=joint_attention_kwargs,
+                    )
+
             # controlnet residual
             if controlnet_single_block_samples is not None:
                 interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
@@ -766,11 +855,44 @@ def forward(
         hidden_states = self.norm_out(hidden_states, temb)
         output = self.proj_out(hidden_states)
 
+        if other_hidden_states is not None:
+            other_hidden_states = self.norm_out(other_hidden_states, temb)
+            other_output = self.proj_out(other_hidden_states)
+
         if USE_PEFT_BACKEND:
             # remove `lora_scale` from each PEFT layer
             unscale_lora_layers(self, lora_scale)
 
-        if not return_dict:
-            return (output,)
+        if other_hidden_states is not None:
+            if not return_dict:
+                return (output, other_output)
 
-        return Transformer2DModelOutput(sample=output)
+            return Transformer2DModelOutput(sample=(output, other_output))
+        else:
+            if not return_dict:
+                return (output,)
+
+            # Preserve the original single-tensor `sample` when no second stream is used.
+            return Transformer2DModelOutput(sample=output)
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, use_2nd_guider: bool = True, **kwargs):
+        # Step A: load the pretrained model as usual.
+        model = super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
+
+        # Step B: initialize the second ("guider") blocks from the pretrained ones.
+        if use_2nd_guider:
+            if hasattr(model, "transformer_blocks2"):
+                with torch.no_grad():
+                    for b2, b1 in zip(model.transformer_blocks2, model.transformer_blocks):
+                        b2.load_state_dict(b1.state_dict())
+                print("Copied transformer_blocks weights into transformer_blocks2.")
+
+            if hasattr(model, "single_transformer_blocks2"):
+                with torch.no_grad():
+                    for b2, b1 in zip(model.single_transformer_blocks2, model.single_transformer_blocks):
+                        b2.load_state_dict(b1.state_dict())
+                print("Copied single_transformer_blocks weights into single_transformer_blocks2.")
+
+        return model
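
A minimal usage sketch of the two-stream forward pass, not part of the diff: the checkpoint id, tensor shapes, and the random stand-ins for the text and pooled embeddings are illustrative assumptions, and the keyword-only `use_2nd_guider` argument to `from_pretrained` reflects the signature as edited above.

```python
# Sketch: load the model, copy-initialize the guider blocks, and run both streams once.
import torch
from diffusers import FluxTransformer2DModel

model = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # assumed checkpoint id
    subfolder="transformer",
    use_2nd_guider=True,
    torch_dtype=torch.bfloat16,
)

batch, img_seq, txt_seq = 1, 1024, 512  # illustrative sizes for packed 2x2 latents
dtype = torch.bfloat16
hidden_states = torch.randn(batch, img_seq, 64, dtype=dtype)        # main packed latent stream
other_hidden_states = torch.randn(batch, img_seq, 64, dtype=dtype)  # second ("guider") latent stream
encoder_hidden_states = torch.randn(batch, txt_seq, 4096, dtype=dtype)  # stand-in for T5 embeddings
pooled_projections = torch.randn(batch, 768, dtype=dtype)               # stand-in for CLIP pooled embedding
img_ids = torch.zeros(img_seq, 3)
txt_ids = torch.zeros(txt_seq, 3)

# With return_dict=False and a second stream, the modified forward returns two outputs.
output, other_output = model(
    hidden_states=hidden_states,
    other_hidden_states=other_hidden_states,
    encoder_hidden_states=encoder_hidden_states,
    pooled_projections=pooled_projections,
    timestep=torch.tensor([1.0]),
    guidance=torch.tensor([3.5]),  # FLUX.1-dev is guidance-distilled
    img_ids=img_ids,
    txt_ids=txt_ids,
    return_dict=False,
)
print(output.shape, other_output.shape)  # both (batch, img_seq, 64)
```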