Commit f6c82a3

Fix some bugs in Flux 2 transformer implementation
1 parent 89e42d9 commit f6c82a3

1 file changed

src/diffusers/models/transformers/transformer_flux2.py

Lines changed: 30 additions & 15 deletions
@@ -133,10 +133,10 @@ def __call__(
         if attn.parallel_proj_in:
             hidden_states = attn.to_qkv_mlp_proj(hidden_states)
             qkv, mlp_hidden_states = torch.split(
-                hidden_states, [3 * attn.inner_dim, attn.mlp_hidden_dim * attn.mlp_mult_factor]
+                hidden_states, [3 * attn.inner_dim, attn.mlp_hidden_dim * attn.mlp_mult_factor], dim=-1
             )
             query, key, value = qkv.chunk(3, dim=-1)
-            mlp_hidden_states = self.mlp_act_fn(mlp_hidden_states)
+            mlp_hidden_states = attn.mlp_act_fn(mlp_hidden_states)
 
         # Get encoder QKV, if available
         encoder_query = encoder_key = encoder_value = None
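
The `dim=-1` addition matters because `torch.split` defaults to `dim=0`: without it, the fused QKV+MLP projection would be sliced along the batch dimension rather than the channel dimension and raise a size-mismatch error for typical batch sizes. A minimal sketch of the split, with illustrative sizes rather than the model's real dimensions:

import torch

# Illustrative sizes only; not the model's actual dimensions
inner_dim, mlp_hidden_dim, mlp_mult_factor = 8, 16, 1
fused = torch.randn(2, 4, 3 * inner_dim + mlp_hidden_dim * mlp_mult_factor)  # fused QKV+MLP projection output

qkv, mlp_hidden_states = torch.split(fused, [3 * inner_dim, mlp_hidden_dim * mlp_mult_factor], dim=-1)
print(qkv.shape, mlp_hidden_states.shape)  # torch.Size([2, 4, 24]) torch.Size([2, 4, 16])

# Without dim=-1, torch.split would try to carve the batch dimension (size 2)
# into chunks of 24 and 16 and raise a size-mismatch error.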
@@ -423,6 +423,7 @@ def forward(
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         joint_attention_kwargs = joint_attention_kwargs or {}
 
+        # Modulation parameters shape: [1, 1, self.dim]
         (shift_msa, scale_msa, gate_msa), (shift_mlp, scale_mlp, gate_mlp) = temb_mod_params_img
         (c_shift_msa, c_scale_msa, c_gate_msa), (c_shift_mlp, c_scale_mlp, c_gate_mlp) = temb_mod_params_txt
 
@@ -448,27 +449,27 @@ def forward(
         attn_output, context_attn_output, ip_attn_output = attention_outputs
 
         # Process attention outputs for the image stream (`hidden_states`).
-        attn_output = gate_msa.unsqueeze(1) * attn_output
+        attn_output = gate_msa * attn_output
         hidden_states = hidden_states + attn_output
 
         norm_hidden_states = self.norm2(hidden_states)
-        norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+        norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
 
         ff_output = self.ff(norm_hidden_states)
-        hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output
+        hidden_states = hidden_states + gate_mlp * ff_output
 
         if len(attention_outputs) == 3:
             hidden_states = hidden_states + ip_attn_output
 
         # Process attention outputs for the text stream (`encoder_hidden_states`).
-        context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
+        context_attn_output = c_gate_msa * context_attn_output
         encoder_hidden_states = encoder_hidden_states + context_attn_output
 
         norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
-        norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
+        norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp) + c_shift_mlp
 
         context_ff_output = self.ff_context(norm_encoder_hidden_states)
-        encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
+        encoder_hidden_states = encoder_hidden_states + c_gate_mlp * context_ff_output
         if encoder_hidden_states.dtype == torch.float16:
             encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
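
Dropping `unsqueeze(1)` and `[:, None]` follows from the shape comment added earlier in this commit: the modulation parameters already come in as `[1, 1, dim]`, so they broadcast directly against `[batch, seq_len, dim]` activations, whereas the old indexing assumed `[batch, dim]` tensors and would instead produce a 4-D result. A small broadcasting sketch with made-up sizes:

import torch

batch, seq_len, dim = 2, 5, 8
hidden_states = torch.randn(batch, seq_len, dim)

# Modulation parameters shaped [1, 1, dim], matching the comment in the diff above
gate_msa = torch.randn(1, 1, dim)
scale_mlp = torch.randn(1, 1, dim)
shift_mlp = torch.randn(1, 1, dim)

gated = gate_msa * hidden_states                          # broadcasts to [2, 5, 8]
modulated = hidden_states * (1 + scale_mlp) + shift_mlp   # broadcasts to [2, 5, 8]
print(gated.shape, modulated.shape)

# gate_msa.unsqueeze(1) would turn [1, 1, dim] into [1, 1, 1, dim], and the product
# would broadcast to a 4-D tensor of shape [1, 2, 5, 8] instead.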

@@ -483,6 +484,7 @@ def __init__(self, theta: int, axes_dim: List[int]):
         self.axes_dim = axes_dim
 
     def forward(self, ids: torch.Tensor) -> torch.Tensor:
+        # Expected ids shape: [S, len(self.axes_dim)]
         cos_out = []
         sin_out = []
         pos = ids.float()
@@ -493,7 +495,7 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor:
         for i in range(len(self.axes_dim)):
             cos, sin = get_1d_rotary_pos_embed(
                 self.axes_dim[i],
-                pos[:, i],
+                pos[..., i],
                 theta=self.theta,
                 repeat_interleave_real=True,
                 use_real=True,
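
With `ids` documented as `[S, len(self.axes_dim)]`, `pos[..., i]` selects one 1-D position track per rotary axis. The ellipsis form behaves the same as `pos[:, i]` for 2-D ids, but it still picks the last axis if ids were ever passed with a leading batch dimension. A rough sketch of the indexing (axis sizes below are illustrative, and the rotary helper call itself is omitted):

import torch

axes_dim = [16, 56, 56]                 # illustrative per-axis rotary dimensions
ids = torch.zeros(10, len(axes_dim))    # [S, num_axes] position ids

pos = ids.float()
for i in range(len(axes_dim)):
    axis_positions = pos[..., i]        # shape [S]; one position track per rotary axis
    print(axis_positions.shape)         # torch.Size([10]) each iteration

# For 3-D ids of shape [B, S, num_axes], pos[..., i] would yield [B, S],
# while pos[:, i] would instead index the sequence axis and give [B, num_axes].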
@@ -736,6 +738,8 @@ def forward(
                 "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
             )
 
+        num_txt_tokens = encoder_hidden_states.shape[1]
+
         # 1. Calculate timestep embedding and modulation parameters
         timestep = timestep.to(hidden_states.dtype) * 1000
         guidance = guidance.to(hidden_states.dtype) * 1000
@@ -751,6 +755,13 @@ def forward(
         encoder_hidden_states = self.context_embedder(encoder_hidden_states)
 
         # 3. Calculate RoPE embeddings from image and text tokens
+        # NOTE: the below logic means that we can't support batched inference with images of different resolutions or
+        # text prompts of different lengths. Is this a use case we want to support?
+        if img_ids.ndim == 3:
+            img_ids = img_ids[0]
+        if txt_ids.ndim == 3:
+            txt_ids = txt_ids[0]
+
         if is_torch_npu_available():
             freqs_cos_image, freqs_sin_image = self.pos_embed(img_ids.cpu())
             image_rotary_emb = (freqs_cos_image.npu(), freqs_sin_image.npu())
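
This guard handles pipelines that pass `img_ids`/`txt_ids` with a batch dimension: a `[B, S, num_axes]` tensor is collapsed to the first element's `[S, num_axes]` grid, which, as the NOTE says, assumes every batch entry shares the same resolution and prompt length. A shape-only sketch with hypothetical sizes:

import torch

# Hypothetical batched ids a pipeline might pass: [batch, seq, num_axes]
img_ids = torch.zeros(2, 4096, 3)
txt_ids = torch.zeros(2, 512, 3)

# Collapse to a single positional grid shared by the whole batch
if img_ids.ndim == 3:
    img_ids = img_ids[0]
if txt_ids.ndim == 3:
    txt_ids = txt_ids[0]

print(img_ids.shape, txt_ids.shape)  # torch.Size([4096, 3]) torch.Size([512, 3])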
@@ -760,8 +771,8 @@ def forward(
             image_rotary_emb = self.pos_embed(img_ids)
             text_rotary_emb = self.pos_embed(txt_ids)
             concat_rotary_emb = (
-                torch.cat([text_rotary_emb[0], image_rotary_emb[0]], dim=2),
-                torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=2),
+                torch.cat([text_rotary_emb[0], image_rotary_emb[0]], dim=0),
+                torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=0),
             )
 
         if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
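
Because the ids are now 2-D, the cos/sin tables computed for the text and image tokens carry no batch dimension (roughly `[num_tokens, rot_dim]` each), so the sequence axis to concatenate along is `dim=0`, with text positions placed before image positions to match the joint sequence ordering used below. A sketch with made-up sizes:

import torch

rot_dim = 128                           # illustrative rotary feature size
txt_cos = torch.randn(512, rot_dim)     # [num_txt_tokens, rot_dim]
img_cos = torch.randn(4096, rot_dim)    # [num_img_tokens, rot_dim]

# Text frequencies first, then image frequencies, along the token axis
concat_cos = torch.cat([txt_cos, img_cos], dim=0)
print(concat_cos.shape)                 # torch.Size([4608, 128])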
@@ -790,26 +801,30 @@ def forward(
                     image_rotary_emb=concat_rotary_emb,
                     joint_attention_kwargs=joint_attention_kwargs,
                 )
+        # Concatenate text and image streams for single-block inference
+        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
 
         # 5. Single Stream Transformer Blocks
         for index_block, block in enumerate(self.single_transformer_blocks):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-                encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
+                hidden_states = self._gradient_checkpointing_func(
                     block,
                     hidden_states,
-                    encoder_hidden_states,
+                    None,
                     single_stream_mod,
                     concat_rotary_emb,
                     joint_attention_kwargs,
                 )
             else:
-                encoder_hidden_states, hidden_states = block(
+                hidden_states = block(
                     hidden_states=hidden_states,
-                    encoder_hidden_states=encoder_hidden_states,
+                    encoder_hidden_states=None,
                     temb_mod_params=single_stream_mod,
                     image_rotary_emb=concat_rotary_emb,
                     joint_attention_kwargs=joint_attention_kwargs,
                 )
+        # Remove text tokens from concatenated stream
+        hidden_states = hidden_states[:, num_txt_tokens:, ...]
 
         # 6. Output layers
         hidden_states = self.norm_out(hidden_states, temb)
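
The single-stream blocks now take one fused sequence instead of separate text and image tensors: the streams are concatenated along the token axis before the loop, processed jointly, and the text prefix is sliced back off afterwards using the `num_txt_tokens` recorded at the top of `forward`. A shape-level sketch with illustrative sizes:

import torch

batch, dim = 1, 64                                      # illustrative sizes
encoder_hidden_states = torch.randn(batch, 512, dim)    # text stream
hidden_states = torch.randn(batch, 4096, dim)           # image stream
num_txt_tokens = encoder_hidden_states.shape[1]

# Text tokens are placed in front of image tokens for the single-stream blocks
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
print(hidden_states.shape)              # torch.Size([1, 4608, 64])

# ... single-stream transformer blocks would run on the fused sequence here ...

# Drop the text prefix so only image tokens reach the output layers
hidden_states = hidden_states[:, num_txt_tokens:, ...]
print(hidden_states.shape)              # torch.Size([1, 4096, 64])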
