
Commit 246d715

Author: qyo9735
Commit message: added the changes to the flux transformer model
Parent: 4067d6c

File tree: 3 files changed (+188 / -44 lines)


flux-schnell.png (1.34 MB)

src/diffusers/models/transformers/transformer_flux.py (166 additions, 44 deletions)
@@ -82,13 +82,26 @@ def __call__(
         self,
         attn: "FluxAttention",
         hidden_states: torch.Tensor,
+        other_hidden_states: torch.Tensor,
         encoder_hidden_states: torch.Tensor = None,
         attention_mask: Optional[torch.Tensor] = None,
         image_rotary_emb: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
-            attn, hidden_states, encoder_hidden_states
-        )
+
+        if other_hidden_states is not None:
+            query, _, _, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
+                attn, hidden_states, encoder_hidden_states
+            )
+
+            _, key, value, _, _, _ = _get_qkv_projections(
+                attn, hidden_states, encoder_hidden_states
+            )
+        else:
+            query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
+                attn, hidden_states, encoder_hidden_states
+            )
+
+

         query = query.unflatten(-1, (attn.heads, -1))
         key = key.unflatten(-1, (attn.heads, -1))

@@ -176,6 +189,7 @@ def __call__(
         self,
         attn: "FluxAttention",
         hidden_states: torch.Tensor,
+        other_hidden_states: torch.Tensor,
         encoder_hidden_states: torch.Tensor = None,
         attention_mask: Optional[torch.Tensor] = None,
         image_rotary_emb: Optional[torch.Tensor] = None,

@@ -184,9 +198,19 @@ def __call__(
     ) -> torch.Tensor:
         batch_size = hidden_states.shape[0]

-        query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
-            attn, hidden_states, encoder_hidden_states
-        )
+
+        if other_hidden_states is not None:
+            query, _, _, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
+                attn, hidden_states, encoder_hidden_states
+            )
+
+            _, key, value, _, _, _ = _get_qkv_projections(
+                attn, hidden_states, encoder_hidden_states
+            )
+        else:
+            query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
+                attn, hidden_states, encoder_hidden_states
+            )

         query = query.unflatten(-1, (attn.heads, -1))
         key = key.unflatten(-1, (attn.heads, -1))

@@ -326,6 +350,7 @@ def __init__(
     def forward(
         self,
         hidden_states: torch.Tensor,
+        other_hidden_states: Optional[torch.Tensor] = None,
         encoder_hidden_states: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         image_rotary_emb: Optional[torch.Tensor] = None,

@@ -339,7 +364,7 @@ def forward(
                 f"joint_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
             )
         kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
-        return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)
+        return self.processor(self, hidden_states, other_hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)


 @maybe_allow_in_graph

@@ -367,8 +392,9 @@ def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int,
     def forward(
         self,
         hidden_states: torch.Tensor,
-        encoder_hidden_states: torch.Tensor,
-        temb: torch.Tensor,
+        other_hidden_states: Optional[torch.Tensor]=None,
+        encoder_hidden_states: Optional[torch.Tensor]=None,
+        temb: Optional[torch.Tensor]=None,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:

@@ -381,6 +407,7 @@ def forward(
         joint_attention_kwargs = joint_attention_kwargs or {}
         attn_output = self.attn(
             hidden_states=norm_hidden_states,
+            other_hidden_states=other_hidden_states if other_hidden_states is not None else None,
             image_rotary_emb=image_rotary_emb,
             **joint_attention_kwargs,
         )

@@ -427,6 +454,7 @@ def __init__(
     def forward(
         self,
         hidden_states: torch.Tensor,
+        other_hidden_states: torch.Tensor,
         encoder_hidden_states: torch.Tensor,
         temb: torch.Tensor,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,

@@ -442,6 +470,7 @@ def forward(
         # Attention.
         attention_outputs = self.attn(
             hidden_states=norm_hidden_states,
+            other_hidden_states=other_hidden_states if other_hidden_states is not None else None,
             encoder_hidden_states=norm_encoder_hidden_states,
             image_rotary_emb=image_rotary_emb,
             **joint_attention_kwargs,

@@ -521,36 +550,6 @@ class FluxTransformer2DModel(
     CacheMixin,
     AttentionMixin,
 ):
-    """
-    The Transformer model introduced in Flux.
-
-    Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
-
-    Args:
-        patch_size (`int`, defaults to `1`):
-            Patch size to turn the input data into small patches.
-        in_channels (`int`, defaults to `64`):
-            The number of channels in the input.
-        out_channels (`int`, *optional*, defaults to `None`):
-            The number of channels in the output. If not specified, it defaults to `in_channels`.
-        num_layers (`int`, defaults to `19`):
-            The number of layers of dual stream DiT blocks to use.
-        num_single_layers (`int`, defaults to `38`):
-            The number of layers of single stream DiT blocks to use.
-        attention_head_dim (`int`, defaults to `128`):
-            The number of dimensions to use for each attention head.
-        num_attention_heads (`int`, defaults to `24`):
-            The number of attention heads to use.
-        joint_attention_dim (`int`, defaults to `4096`):
-            The number of dimensions to use for the joint attention (embedding/channel dimension of
-            `encoder_hidden_states`).
-        pooled_projection_dim (`int`, defaults to `768`):
-            The number of dimensions to use for the pooled projection.
-        guidance_embeds (`bool`, defaults to `False`):
-            Whether to use guidance embeddings for guidance-distilled variant of the model.
-        axes_dims_rope (`Tuple[int]`, defaults to `(16, 56, 56)`):
-            The dimensions to use for the rotary positional embeddings.
-    """

     _supports_gradient_checkpointing = True
     _no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]

@@ -571,10 +570,12 @@ def __init__(
         pooled_projection_dim: int = 768,
         guidance_embeds: bool = False,
         axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
+        use_2nd_guider: bool = True
     ):
         super().__init__()
         self.out_channels = out_channels or in_channels
         self.inner_dim = num_attention_heads * attention_head_dim
+        self.use_2nd_guider = use_2nd_guider

         self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)


@@ -599,6 +600,23 @@ def __init__(
             ]
         )

+        if use_2nd_guider:
+            self.transformer_blocks2 = nn.ModuleList(
+                [
+                    FluxTransformerBlock(
+                        dim=self.inner_dim,
+                        num_attention_heads=num_attention_heads,
+                        attention_head_dim=attention_head_dim,
+                    )
+                    for _ in range(num_layers)
+                ]
+            )
+        else:
+            self.transformer_blocks2 = []
+            for i in range(len(self.transformer_blocks)):
+                self.transformer_blocks2.append(None)
+
+
         self.single_transformer_blocks = nn.ModuleList(
             [
                 FluxSingleTransformerBlock(

@@ -610,6 +628,24 @@ def __init__(
             ]
         )

+        if use_2nd_guider:
+
+            self.single_transformer_blocks2 = nn.ModuleList(
+                [
+                    FluxSingleTransformerBlock(
+                        dim=self.inner_dim,
+                        num_attention_heads=num_attention_heads,
+                        attention_head_dim=attention_head_dim,
+                    )
+                    for _ in range(num_single_layers)
+                ]
+            )
+        else:
+            self.single_transformer_blocks2 = []
+            for i in range(len(self.single_transformer_blocks)):
+                self.single_transformer_blocks2.append(None)
+
+
         self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
         self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)


@@ -618,6 +654,7 @@ def __init__(
     def forward(
         self,
         hidden_states: torch.Tensor,
+        other_hidden_states: torch.Tensor = None,
         encoder_hidden_states: torch.Tensor = None,
         pooled_projections: torch.Tensor = None,
         timestep: torch.LongTensor = None,

@@ -672,6 +709,10 @@ def forward(
             )

         hidden_states = self.x_embedder(hidden_states)
+
+        if other_hidden_states is not None:
+            # other states
+            other_hidden_states= self.x_embedder(other_hidden_states)

         timestep = timestep.to(hidden_states.dtype) * 1000
         if guidance is not None:

@@ -705,26 +746,49 @@ def forward(
             ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
             joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})

-        for index_block, block in enumerate(self.transformer_blocks):
+
+        for index_block, (block, block2) in enumerate(zip(self.transformer_blocks, self.transformer_blocks2)):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
                 encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
                     block,
                     hidden_states,
+                    other_hidden_states if other_hidden_states is not None else None,
                     encoder_hidden_states,
                     temb,
                     image_rotary_emb,
                     joint_attention_kwargs,
                 )
+                if other_hidden_states is not None:
+                    encoder_hidden_states, other_hidden_states = self._gradient_checkpointing_func(
+                        block2,
+                        other_hidden_states,
+                        hidden_states,
+                        encoder_hidden_states,
+                        temb,
+                        image_rotary_emb,
+                        joint_attention_kwargs,
+                    )

             else:
                 encoder_hidden_states, hidden_states = block(
                     hidden_states=hidden_states,
+                    other_hidden_states=other_hidden_states if other_hidden_states is not None else None,
                     encoder_hidden_states=encoder_hidden_states,
                     temb=temb,
                     image_rotary_emb=image_rotary_emb,
                     joint_attention_kwargs=joint_attention_kwargs,
                 )

+                if other_hidden_states is not None:
+                    encoder_hidden_states, other_hidden_states = block2(
+                        hidden_states=other_hidden_states,
+                        other_hidden_states=hidden_states,
+                        encoder_hidden_states=encoder_hidden_states,
+                        temb=temb,
+                        image_rotary_emb=image_rotary_emb,
+                        joint_attention_kwargs=joint_attention_kwargs,
+                    )
+
             # controlnet residual
             if controlnet_block_samples is not None:
                 interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)

@@ -737,8 +801,9 @@ def forward(
                 else:
                     hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]

-        for index_block, block in enumerate(self.single_transformer_blocks):
+        for index_block, (block, block2) in enumerate(zip(self.single_transformer_blocks, self.single_transformer_blocks2)):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
+
                 encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
                     block,
                     hidden_states,

@@ -748,15 +813,39 @@ def forward(
                     joint_attention_kwargs,
                 )

+                if other_hidden_states is not None:
+
+                    encoder_hidden_states, other_hidden_states = self._gradient_checkpointing_func(
+                        block2,
+                        other_hidden_states,
+                        hidden_states,
+                        encoder_hidden_states,
+                        temb,
+                        image_rotary_emb,
+                        joint_attention_kwargs,
+                    )
+
+
             else:
                 encoder_hidden_states, hidden_states = block(
                     hidden_states=hidden_states,
+                    other_hidden_states=other_hidden_states if other_hidden_states is not None else None,
                     encoder_hidden_states=encoder_hidden_states,
                     temb=temb,
                     image_rotary_emb=image_rotary_emb,
                     joint_attention_kwargs=joint_attention_kwargs,
                 )

+                if other_hidden_states is not None:
+                    encoder_hidden_states, other_hidden_states = block2(
+                        hidden_states=other_hidden_states,
+                        other_hidden_states=hidden_states,
+                        encoder_hidden_states=encoder_hidden_states,
+                        temb=temb,
+                        image_rotary_emb=image_rotary_emb,
+                        joint_attention_kwargs=joint_attention_kwargs,
+                    )
+
             # controlnet residual
             if controlnet_single_block_samples is not None:
                 interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)

@@ -766,11 +855,44 @@ def forward(
         hidden_states = self.norm_out(hidden_states, temb)
         output = self.proj_out(hidden_states)

+        if other_hidden_states is not None:
+            other_hidden_states = self.norm_out(other_hidden_states, temb)
+            other_output = self.proj_out(other_hidden_states)
+
         if USE_PEFT_BACKEND:
             # remove `lora_scale` from each PEFT layer
             unscale_lora_layers(self, lora_scale)

-        if not return_dict:
-            return (output,)
+        if other_hidden_states is not None:
+            if not return_dict:
+                return (output, other_output)

-        return Transformer2DModelOutput(sample=output)
+            return Transformer2DModelOutput(sample=(output, other_output))
+        else:
+            if not return_dict:
+                return (output,)
+
+            return Transformer2DModelOutput(sample=(output,))
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, use_2nd_guider, *model_args, **kwargs):
+        # Step A: load model normally
+        model = super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
+
+        # Step B: copy weights into the new transformer_blocks2
+        if use_2nd_guider:
+            if hasattr(model, "transformer_blocks2"):
+                with torch.no_grad():
+                    for b2, b1 in zip(model.transformer_blocks2, model.transformer_blocks):
+                        for (_, p2), (_, p1) in zip(b2.named_parameters(), b1.named_parameters()):
+                            p2.copy_(p1)
+                print("double_block weights loaded Yayy !!!!")
+
+            if hasattr(model, "single_transformer_blocks2"):
+                with torch.no_grad():
+                    for b2, b1 in zip(model.single_transformer_blocks2, model.single_transformer_blocks):
+                        for (_, p2), (_, p1) in zip(b2.named_parameters(), b1.named_parameters()):
+                            p2.copy_(p1)
+                print("Single_block weights loaded Yayy !!!!")
+
+        return model
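
For reference, a minimal usage sketch of the modified model (not part of the commit). The checkpoint id, tensor shapes, and batch sizes below are illustrative placeholders, and the sketch assumes the diffusers loader tolerates the extra, initially random transformer_blocks2 / single_transformer_blocks2 parameters before the overridden from_pretrained copies the pretrained weights into them. Per the diff, when other_hidden_states is passed and return_dict=False, forward returns the pair (output, other_output).

# Illustrative sketch only; checkpoint id, shapes, and loading behaviour are assumptions.
import torch
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel

# use_2nd_guider=True triggers the weight copy into the duplicated block lists.
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",  # placeholder checkpoint id
    True,                                # use_2nd_guider (positional in the overridden classmethod)
    subfolder="transformer",
)

batch, img_tokens, txt_tokens = 1, 4096, 512
hidden_states = torch.randn(batch, img_tokens, 64)            # packed latents, stream 1
other_hidden_states = torch.randn(batch, img_tokens, 64)      # packed latents, stream 2 (the "guider" stream)
encoder_hidden_states = torch.randn(batch, txt_tokens, 4096)  # T5 text embeddings
pooled_projections = torch.randn(batch, 768)                  # CLIP pooled embedding
timestep = torch.tensor([1.0])
img_ids = torch.zeros(img_tokens, 3)
txt_ids = torch.zeros(txt_tokens, 3)

with torch.no_grad():
    output, other_output = transformer(
        hidden_states=hidden_states,
        other_hidden_states=other_hidden_states,
        encoder_hidden_states=encoder_hidden_states,
        pooled_projections=pooled_projections,
        timestep=timestep,
        img_ids=img_ids,
        txt_ids=txt_ids,
        return_dict=False,
    )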

0 commit comments
