Commit c1abec6
committed: 🤷🏻‍♂️
1 parent 745bcec commit c1abec6

File tree: 1 file changed, +78 -117 lines changed

src/diffusers/models/transformers/transformer_hidream_image.py

Lines changed: 78 additions & 117 deletions
@@ -183,16 +183,16 @@ def __init__(
 
     def forward(
         self,
-        norm_image_tokens: torch.FloatTensor,
-        image_tokens_masks: torch.FloatTensor = None,
-        norm_text_tokens: torch.FloatTensor = None,
+        norm_hidden_states: torch.FloatTensor,
+        hidden_states_masks: torch.FloatTensor = None,
+        norm_encoder_hidden_states: torch.FloatTensor = None,
         image_rotary_emb: torch.FloatTensor = None,
     ) -> torch.Tensor:
         return self.processor(
             self,
-            image_tokens=norm_image_tokens,
-            image_tokens_masks=image_tokens_masks,
-            text_tokens=norm_text_tokens,
+            hidden_states=norm_hidden_states,
+            hidden_states_masks=hidden_states_masks,
+            encoder_hidden_states=norm_encoder_hidden_states,
             image_rotary_emb=image_rotary_emb,
         )
 
@@ -203,33 +203,33 @@ class HiDreamAttnProcessor:
     def __call__(
         self,
         attn: HiDreamAttention,
-        image_tokens: torch.FloatTensor,
-        image_tokens_masks: Optional[torch.FloatTensor] = None,
-        text_tokens: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.FloatTensor,
+        hidden_states_masks: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
         image_rotary_emb: torch.FloatTensor = None,
         *args,
         **kwargs,
     ) -> torch.FloatTensor:
-        dtype = image_tokens.dtype
-        batch_size = image_tokens.shape[0]
+        dtype = hidden_states.dtype
+        batch_size = hidden_states.shape[0]
 
-        query_i = attn.q_rms_norm(attn.to_q(image_tokens)).to(dtype=dtype)
-        key_i = attn.k_rms_norm(attn.to_k(image_tokens)).to(dtype=dtype)
-        value_i = attn.to_v(image_tokens)
+        query_i = attn.q_rms_norm(attn.to_q(hidden_states)).to(dtype=dtype)
+        key_i = attn.k_rms_norm(attn.to_k(hidden_states)).to(dtype=dtype)
+        value_i = attn.to_v(hidden_states)
 
         inner_dim = key_i.shape[-1]
         head_dim = inner_dim // attn.heads
 
         query_i = query_i.view(batch_size, -1, attn.heads, head_dim)
         key_i = key_i.view(batch_size, -1, attn.heads, head_dim)
         value_i = value_i.view(batch_size, -1, attn.heads, head_dim)
-        if image_tokens_masks is not None:
-            key_i = key_i * image_tokens_masks.view(batch_size, -1, 1, 1)
+        if hidden_states_masks is not None:
+            key_i = key_i * hidden_states_masks.view(batch_size, -1, 1, 1)
 
         if not attn.single:
-            query_t = attn.q_rms_norm_t(attn.to_q_t(text_tokens)).to(dtype=dtype)
-            key_t = attn.k_rms_norm_t(attn.to_k_t(text_tokens)).to(dtype=dtype)
-            value_t = attn.to_v_t(text_tokens)
+            query_t = attn.q_rms_norm_t(attn.to_q_t(encoder_hidden_states)).to(dtype=dtype)
+            key_t = attn.k_rms_norm_t(attn.to_k_t(encoder_hidden_states)).to(dtype=dtype)
+            value_t = attn.to_v_t(encoder_hidden_states)
 
             query_t = query_t.view(batch_size, -1, attn.heads, head_dim)
             key_t = key_t.view(batch_size, -1, attn.heads, head_dim)
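
Note on the renamed processor interface: after this hunk, HiDreamAttnProcessor.__call__ takes the diffusers-wide names hidden_states / hidden_states_masks / encoder_hidden_states instead of image_tokens / image_tokens_masks / text_tokens. A minimal sketch of a delegating processor written against the new signature (the logging subclass itself is hypothetical and not part of this commit):

from diffusers.models.transformers.transformer_hidream_image import HiDreamAttnProcessor


class ShapeLoggingHiDreamAttnProcessor(HiDreamAttnProcessor):
    # Hypothetical example: print the renamed inputs' shapes, then delegate to the stock
    # processor. Extra *args/**kwargs are accepted but ignored in this sketch.
    def __call__(self, attn, hidden_states, hidden_states_masks=None,
                 encoder_hidden_states=None, image_rotary_emb=None, *args, **kwargs):
        print("hidden_states:", tuple(hidden_states.shape))
        if encoder_hidden_states is not None:
            print("encoder_hidden_states:", tuple(encoder_hidden_states.shape))
        return super().__call__(
            attn,
            hidden_states,
            hidden_states_masks=hidden_states_masks,
            encoder_hidden_states=encoder_hidden_states,
            image_rotary_emb=image_rotary_emb,
        )
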
@@ -454,33 +454,33 @@ def __init__(
 
     def forward(
         self,
-        image_tokens: torch.FloatTensor,
-        image_tokens_masks: Optional[torch.FloatTensor] = None,
-        text_tokens: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.FloatTensor,
+        hidden_states_masks: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
         adaln_input: Optional[torch.FloatTensor] = None,
         image_rotary_emb: torch.FloatTensor = None,
     ) -> torch.FloatTensor:
-        wtype = image_tokens.dtype
+        wtype = hidden_states.dtype
         shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i = self.adaLN_modulation(
             adaln_input
         )[:, None].chunk(6, dim=-1)
 
         # 1. MM-Attention
-        norm_image_tokens = self.norm1_i(image_tokens).to(dtype=wtype)
-        norm_image_tokens = norm_image_tokens * (1 + scale_msa_i) + shift_msa_i
+        norm_hidden_states = self.norm1_i(hidden_states).to(dtype=wtype)
+        norm_hidden_states = norm_hidden_states * (1 + scale_msa_i) + shift_msa_i
         attn_output_i = self.attn1(
-            norm_image_tokens,
-            image_tokens_masks,
+            norm_hidden_states,
+            hidden_states_masks,
             image_rotary_emb=image_rotary_emb,
         )
-        image_tokens = gate_msa_i * attn_output_i + image_tokens
+        hidden_states = gate_msa_i * attn_output_i + hidden_states
 
         # 2. Feed-forward
-        norm_image_tokens = self.norm3_i(image_tokens).to(dtype=wtype)
-        norm_image_tokens = norm_image_tokens * (1 + scale_mlp_i) + shift_mlp_i
-        ff_output_i = gate_mlp_i * self.ff_i(norm_image_tokens.to(dtype=wtype))
-        image_tokens = ff_output_i + image_tokens
-        return image_tokens
+        norm_hidden_states = self.norm3_i(hidden_states).to(dtype=wtype)
+        norm_hidden_states = norm_hidden_states * (1 + scale_mlp_i) + shift_mlp_i
+        ff_output_i = gate_mlp_i * self.ff_i(norm_hidden_states.to(dtype=wtype))
+        hidden_states = ff_output_i + hidden_states
+        return hidden_states
 
 
 @maybe_allow_in_graph
@@ -526,13 +526,13 @@ def __init__(
 
     def forward(
        self,
-        image_tokens: torch.FloatTensor,
-        image_tokens_masks: Optional[torch.FloatTensor] = None,
-        text_tokens: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.FloatTensor,
+        hidden_states_masks: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
         adaln_input: Optional[torch.FloatTensor] = None,
         image_rotary_emb: torch.FloatTensor = None,
     ) -> torch.FloatTensor:
-        wtype = image_tokens.dtype
+        wtype = hidden_states.dtype
         (
             shift_msa_i,
             scale_msa_i,
@@ -549,74 +549,37 @@ def forward(
         ) = self.adaLN_modulation(adaln_input)[:, None].chunk(12, dim=-1)
 
         # 1. MM-Attention
-        norm_image_tokens = self.norm1_i(image_tokens).to(dtype=wtype)
-        norm_image_tokens = norm_image_tokens * (1 + scale_msa_i) + shift_msa_i
-        norm_text_tokens = self.norm1_t(text_tokens).to(dtype=wtype)
-        norm_text_tokens = norm_text_tokens * (1 + scale_msa_t) + shift_msa_t
+        norm_hidden_states = self.norm1_i(hidden_states).to(dtype=wtype)
+        norm_hidden_states = norm_hidden_states * (1 + scale_msa_i) + shift_msa_i
+        norm_encoder_hidden_states = self.norm1_t(encoder_hidden_states).to(dtype=wtype)
+        norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + scale_msa_t) + shift_msa_t
 
         attn_output_i, attn_output_t = self.attn1(
-            norm_image_tokens,
-            image_tokens_masks,
-            norm_text_tokens,
+            norm_hidden_states,
+            hidden_states_masks,
+            norm_encoder_hidden_states,
             image_rotary_emb=image_rotary_emb,
         )
 
-        image_tokens = gate_msa_i * attn_output_i + image_tokens
-        text_tokens = gate_msa_t * attn_output_t + text_tokens
+        hidden_states = gate_msa_i * attn_output_i + hidden_states
+        encoder_hidden_states = gate_msa_t * attn_output_t + encoder_hidden_states
 
         # 2. Feed-forward
-        norm_image_tokens = self.norm3_i(image_tokens).to(dtype=wtype)
-        norm_image_tokens = norm_image_tokens * (1 + scale_mlp_i) + shift_mlp_i
-        norm_text_tokens = self.norm3_t(text_tokens).to(dtype=wtype)
-        norm_text_tokens = norm_text_tokens * (1 + scale_mlp_t) + shift_mlp_t
-
-        ff_output_i = gate_mlp_i * self.ff_i(norm_image_tokens)
-        ff_output_t = gate_mlp_t * self.ff_t(norm_text_tokens)
-        image_tokens = ff_output_i + image_tokens
-        text_tokens = ff_output_t + text_tokens
-        return image_tokens, text_tokens
+        norm_hidden_states = self.norm3_i(hidden_states).to(dtype=wtype)
+        norm_hidden_states = norm_hidden_states * (1 + scale_mlp_i) + shift_mlp_i
+        norm_encoder_hidden_states = self.norm3_t(encoder_hidden_states).to(dtype=wtype)
+        norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + scale_mlp_t) + shift_mlp_t
 
-
-@maybe_allow_in_graph
-class HiDreamImageBlock(nn.Module):
-    def __init__(
-        self,
-        dim: int,
-        num_attention_heads: int,
-        attention_head_dim: int,
-        num_routed_experts: int = 4,
-        num_activated_experts: int = 2,
-        block_type: BlockType = BlockType.TransformerBlock,
-    ):
-        super().__init__()
-        block_classes = {
-            BlockType.TransformerBlock: HiDreamImageTransformerBlock,
-            BlockType.SingleTransformerBlock: HiDreamImageSingleTransformerBlock,
-        }
-        self.block = block_classes[block_type](
-            dim, num_attention_heads, attention_head_dim, num_routed_experts, num_activated_experts
-        )
-
-    def forward(
-        self,
-        image_tokens: torch.FloatTensor,
-        image_tokens_masks: Optional[torch.FloatTensor] = None,
-        text_tokens: Optional[torch.FloatTensor] = None,
-        adaln_input: torch.FloatTensor = None,
-        image_rotary_emb: torch.FloatTensor = None,
-    ) -> torch.FloatTensor:
-        return self.block(
-            image_tokens,
-            image_tokens_masks,
-            text_tokens,
-            adaln_input,
-            image_rotary_emb,
-        )
+        ff_output_i = gate_mlp_i * self.ff_i(norm_hidden_states)
+        ff_output_t = gate_mlp_t * self.ff_t(norm_encoder_hidden_states)
+        hidden_states = ff_output_i + hidden_states
+        encoder_hidden_states = ff_output_t + encoder_hidden_states
+        return hidden_states, encoder_hidden_states
 
 
 class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
     _supports_gradient_checkpointing = True
-    _no_split_modules = ["HiDreamImageBlock"]
+    _no_split_modules = ["HiDreamImageTransformerBlock", "HiDreamImageSingleTransformerBlock"]
 
     @register_to_config
     def __init__(
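
With HiDreamImageBlock and its BlockType dispatch removed in the hunk above, callers that relied on that indirection now pick the block class directly. A hypothetical equivalent of the deleted wrapper, assuming only the two block classes defined in this module and the same positional arguments the wrapper used to forward:

from diffusers.models.transformers.transformer_hidream_image import (
    HiDreamImageSingleTransformerBlock,
    HiDreamImageTransformerBlock,
)


def make_hidream_block(dim, num_attention_heads, attention_head_dim,
                       num_routed_experts=4, num_activated_experts=2, single=False):
    # Hypothetical stand-in for the removed HiDreamImageBlock/BlockType dispatch:
    # select the block class and pass through the same constructor arguments.
    block_cls = HiDreamImageSingleTransformerBlock if single else HiDreamImageTransformerBlock
    return block_cls(dim, num_attention_heads, attention_head_dim,
                     num_routed_experts, num_activated_experts)
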
@@ -652,29 +615,27 @@ def __init__(
 
         self.double_stream_blocks = nn.ModuleList(
             [
-                HiDreamImageBlock(
+                HiDreamImageTransformerBlock(
                     dim=self.inner_dim,
                     num_attention_heads=self.config.num_attention_heads,
                     attention_head_dim=self.config.attention_head_dim,
                     num_routed_experts=num_routed_experts,
                     num_activated_experts=num_activated_experts,
-                    block_type=BlockType.TransformerBlock,
                 )
-                for i in range(self.config.num_layers)
+                for _ in range(self.config.num_layers)
             ]
         )
 
         self.single_stream_blocks = nn.ModuleList(
             [
-                HiDreamImageBlock(
+                HiDreamImageSingleTransformerBlock(
                     dim=self.inner_dim,
                     num_attention_heads=self.config.num_attention_heads,
                     attention_head_dim=self.config.attention_head_dim,
                     num_routed_experts=num_routed_experts,
                     num_activated_experts=num_activated_experts,
-                    block_type=BlockType.SingleTransformerBlock,
                 )
-                for i in range(self.config.num_single_layers)
+                for _ in range(self.config.num_single_layers)
             ]
         )
 
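
One side effect worth noting: the deleted wrapper held the real block under a .block attribute, so state dicts saved before this change name parameters like double_stream_blocks.0.block.attn1.to_q.weight, while the new modules expect double_stream_blocks.0.attn1.to_q.weight. A hypothetical remapping helper (not part of this commit) for loading such an older checkpoint:

def strip_block_wrapper_keys(state_dict):
    # Hypothetical helper: drop the ".block." segment that the removed
    # HiDreamImageBlock wrapper introduced into stream-block parameter names.
    remapped = {}
    for key, value in state_dict.items():
        if key.startswith(("double_stream_blocks.", "single_stream_blocks.")):
            key = key.replace(".block.", ".", 1)
        remapped[key] = value
    return remapped
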
@@ -816,8 +777,8 @@ def forward(
         p_embedder = self.p_embedder(pooled_embeds)
         adaln_input = timesteps + p_embedder
 
-        hidden_states, image_tokens_masks, img_sizes = self.patchify(hidden_states, self.max_seq, img_sizes)
-        if image_tokens_masks is None:
+        hidden_states, hidden_states_masks, img_sizes = self.patchify(hidden_states, self.max_seq, img_sizes)
+        if hidden_states_masks is None:
             pH, pW = img_sizes[0]
             img_ids = torch.zeros(pH, pW, 3, device=hidden_states.device)
             img_ids[..., 1] = img_ids[..., 1] + torch.arange(pH, device=hidden_states.device)[:, None]
@@ -869,16 +830,16 @@ def forward(
                 hidden_states, initial_encoder_hidden_states = self._gradient_checkpointing_func(
                     block,
                     hidden_states,
-                    image_tokens_masks,
+                    hidden_states_masks,
                     cur_encoder_hidden_states,
                     adaln_input,
                     image_rotary_emb,
                 )
             else:
                 hidden_states, initial_encoder_hidden_states = block(
-                    image_tokens=hidden_states,
-                    image_tokens_masks=image_tokens_masks,
-                    text_tokens=cur_encoder_hidden_states,
+                    hidden_states=hidden_states,
+                    hidden_states_masks=hidden_states_masks,
+                    encoder_hidden_states=cur_encoder_hidden_states,
                     adaln_input=adaln_input,
                     image_rotary_emb=image_rotary_emb,
                 )
@@ -888,13 +849,13 @@ def forward(
         image_tokens_seq_len = hidden_states.shape[1]
         hidden_states = torch.cat([hidden_states, initial_encoder_hidden_states], dim=1)
         hidden_states_seq_len = hidden_states.shape[1]
-        if image_tokens_masks is not None:
+        if hidden_states_masks is not None:
             encoder_attention_mask_ones = torch.ones(
                 (batch_size, initial_encoder_hidden_states.shape[1] + cur_llama31_encoder_hidden_states.shape[1]),
-                device=image_tokens_masks.device,
-                dtype=image_tokens_masks.dtype,
+                device=hidden_states_masks.device,
+                dtype=hidden_states_masks.dtype,
             )
-            image_tokens_masks = torch.cat([image_tokens_masks, encoder_attention_mask_ones], dim=1)
+            hidden_states_masks = torch.cat([hidden_states_masks, encoder_attention_mask_ones], dim=1)
 
         for bid, block in enumerate(self.single_stream_blocks):
             cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id]
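
The ones appended in the hunk above simply extend the renamed mask over the text tokens that were just concatenated along the sequence dimension; those entries are all ones, so the text tokens are never masked out. A toy shape check with illustrative sizes only:

import torch

batch_size, image_len, text_len = 2, 16, 6
hidden_states_masks = torch.ones(batch_size, image_len)          # image-token mask from patchify
encoder_attention_mask_ones = torch.ones(batch_size, text_len)   # text tokens always attended
full_mask = torch.cat([hidden_states_masks, encoder_attention_mask_ones], dim=1)
assert full_mask.shape == (batch_size, image_len + text_len)
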
@@ -903,16 +864,16 @@ def forward(
                 hidden_states = self._gradient_checkpointing_func(
                     block,
                     hidden_states,
-                    image_tokens_masks,
+                    hidden_states_masks,
                     None,
                     adaln_input,
                     image_rotary_emb,
                 )
             else:
                 hidden_states = block(
-                    image_tokens=hidden_states,
-                    image_tokens_masks=image_tokens_masks,
-                    text_tokens=None,
+                    hidden_states=hidden_states,
+                    hidden_states_masks=hidden_states_masks,
+                    encoder_hidden_states=None,
                     adaln_input=adaln_input,
                     image_rotary_emb=image_rotary_emb,
                 )
@@ -922,13 +883,13 @@ def forward(
         hidden_states = hidden_states[:, :image_tokens_seq_len, ...]
         output = self.final_layer(hidden_states, adaln_input)
         output = self.unpatchify(output, img_sizes, self.training)
-        if image_tokens_masks is not None:
-            image_tokens_masks = image_tokens_masks[:, :image_tokens_seq_len]
+        if hidden_states_masks is not None:
+            hidden_states_masks = hidden_states_masks[:, :image_tokens_seq_len]
 
         if USE_PEFT_BACKEND:
             # remove `lora_scale` from each PEFT layer
             unscale_lora_layers(self, lora_scale)
 
         if not return_dict:
-            return (output, image_tokens_masks)
-        return Transformer2DModelOutput(sample=output, mask=image_tokens_masks)
+            return (output, hidden_states_masks)
+        return Transformer2DModelOutput(sample=output, mask=hidden_states_masks)
