
Commit 32af5ce

committed
address review comments
1 parent 9eb0b8b commit 32af5ce


2 files changed: +153 −133 lines


src/diffusers/models/transformers/transformer_hidream_image.py

Lines changed: 80 additions & 73 deletions
@@ -1,5 +1,4 @@
-import math
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple, Union

 import torch
 import torch.nn as nn
@@ -12,10 +11,7 @@
 from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import Attention
-from ..embeddings import (
-    TimestepEmbedding,
-    Timesteps,
-)
+from ..embeddings import TimestepEmbedding, Timesteps


 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -40,7 +36,7 @@ def __init__(
         self.w2 = nn.Linear(hidden_dim, dim, bias=False)
         self.w3 = nn.Linear(dim, hidden_dim, bias=False)

-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))


@@ -49,7 +45,7 @@ def __init__(self, text_emb_dim, hidden_size):
         super().__init__()
         self.pooled_embedder = TimestepEmbedding(in_channels=text_emb_dim, time_embed_dim=hidden_size)

-    def forward(self, pooled_embed):
+    def forward(self, pooled_embed: torch.Tensor) -> torch.Tensor:
         return self.pooled_embedder(pooled_embed)


@@ -59,7 +55,7 @@ def __init__(self, hidden_size, frequency_embedding_size=256):
         self.time_proj = Timesteps(num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0)
         self.timestep_embedder = TimestepEmbedding(in_channels=frequency_embedding_size, time_embed_dim=hidden_size)

-    def forward(self, timesteps, wdtype):
+    def forward(self, timesteps: torch.Tensor, wdtype: Optional[torch.dtype] = None):
         t_emb = self.time_proj(timesteps).to(dtype=wdtype)
         t_emb = self.timestep_embedder(t_emb)
         return t_emb
@@ -72,11 +68,11 @@ def __init__(self, hidden_size, patch_size, out_channels):
         self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
         self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))

-    def forward(self, x, adaln_input):
-        shift, scale = self.adaLN_modulation(adaln_input).chunk(2, dim=1)
-        x = self.norm_final(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
-        x = self.linear(x)
-        return x
+    def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor) -> torch.Tensor:
+        shift, scale = self.adaLN_modulation(temb).chunk(2, dim=1)
+        hidden_states = self.norm_final(hidden_states) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
+        hidden_states = self.linear(hidden_states)
+        return hidden_states


 class HiDreamImagePatchEmbed(nn.Module):
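
Aside: the final layer's forward above follows the standard adaLN modulation pattern, where a conditioning embedding is projected to a per-channel shift and scale that modulate a LayerNorm'd hidden state. A minimal self-contained sketch of that pattern (all sizes and names below are illustrative, not taken from the file):

import torch
import torch.nn as nn

# Illustrative adaLN modulation: a conditioning vector (temb) is projected to a
# per-channel shift and scale that modulate a normalized hidden state.
hidden_size = 32
norm_final = nn.LayerNorm(hidden_size, eps=1e-6, elementwise_affine=False)
adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))

hidden_states = torch.randn(2, 16, hidden_size)  # (batch, tokens, channels)
temb = torch.randn(2, hidden_size)               # (batch, channels)

shift, scale = adaLN_modulation(temb).chunk(2, dim=1)
out = norm_final(hidden_states) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
print(out.shape)  # torch.Size([2, 16, 32])
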
@@ -183,10 +179,10 @@ def __init__(

     def forward(
         self,
-        norm_hidden_states: torch.FloatTensor,
-        hidden_states_masks: torch.FloatTensor = None,
-        norm_encoder_hidden_states: torch.FloatTensor = None,
-        image_rotary_emb: torch.FloatTensor = None,
+        norm_hidden_states: torch.Tensor,
+        hidden_states_masks: torch.Tensor = None,
+        norm_encoder_hidden_states: torch.Tensor = None,
+        image_rotary_emb: torch.Tensor = None,
     ) -> torch.Tensor:
         return self.processor(
             self,
@@ -203,13 +199,13 @@ class HiDreamAttnProcessor:
     def __call__(
         self,
         attn: HiDreamAttention,
-        hidden_states: torch.FloatTensor,
-        hidden_states_masks: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        image_rotary_emb: torch.FloatTensor = None,
+        hidden_states: torch.Tensor,
+        hidden_states_masks: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        image_rotary_emb: torch.Tensor = None,
         *args,
         **kwargs,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         dtype = hidden_states.dtype
         batch_size = hidden_states.shape[0]

@@ -286,13 +282,7 @@ def __init__(self, embed_dim, num_routed_experts=4, num_activated_experts=2, aux
         # topk selection algorithm
         self.norm_topk_prob = False
         self.gating_dim = embed_dim
-        self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))
-        self.reset_parameters()
-
-    def reset_parameters(self) -> None:
-        import torch.nn.init as init
-
-        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
+        self.weight = nn.Parameter(torch.randn(self.n_routed_experts, self.gating_dim) / embed_dim**0.5)

     def forward(self, hidden_states):
         bsz, seq_len, h = hidden_states.shape
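
Aside: a quick numerical sketch of the two gate-weight initializations in the hunk above (shapes are illustrative). Kaiming-uniform with a=sqrt(5) on an (experts, dim) matrix has a standard deviation of roughly 1/sqrt(3*dim), while randn / dim**0.5 has roughly 1/sqrt(dim), so both stay small and shrink as the gating dimension grows:

import math
import torch
import torch.nn as nn

n_routed_experts, gating_dim = 8, 2560  # illustrative sizes

# Removed initialization: empty tensor filled by Kaiming-uniform.
old_weight = torch.empty(n_routed_experts, gating_dim)
nn.init.kaiming_uniform_(old_weight, a=math.sqrt(5))

# New initialization: scaled Gaussian.
new_weight = torch.randn(n_routed_experts, gating_dim) / gating_dim**0.5

print(old_weight.std().item(), 1 / math.sqrt(3 * gating_dim))  # both ~0.0114
print(new_weight.std().item(), 1 / math.sqrt(gating_dim))      # both ~0.0198
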
@@ -409,11 +399,6 @@ def forward(self, caption):
         return hidden_states


-class BlockType:
-    TransformerBlock = 1
-    SingleTransformerBlock = 2
-
-
 @maybe_allow_in_graph
 class HiDreamImageSingleTransformerBlock(nn.Module):
     def __init__(
@@ -427,8 +412,6 @@ def __init__(
         super().__init__()
         self.num_attention_heads = num_attention_heads
         self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim, 6 * dim, bias=True))
-        nn.init.zeros_(self.adaLN_modulation[1].weight)
-        nn.init.zeros_(self.adaLN_modulation[1].bias)

         # 1. Attention
         self.norm1_i = nn.LayerNorm(dim, eps=1e-06, elementwise_affine=False)
@@ -454,16 +437,16 @@ def __init__(

     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        hidden_states_masks: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        adaln_input: Optional[torch.FloatTensor] = None,
-        image_rotary_emb: torch.FloatTensor = None,
-    ) -> torch.FloatTensor:
+        hidden_states: torch.Tensor,
+        hidden_states_masks: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        temb: Optional[torch.Tensor] = None,
+        image_rotary_emb: torch.Tensor = None,
+    ) -> torch.Tensor:
         wtype = hidden_states.dtype
-        shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i = self.adaLN_modulation(
-            adaln_input
-        )[:, None].chunk(6, dim=-1)
+        shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i = self.adaLN_modulation(temb)[
+            :, None
+        ].chunk(6, dim=-1)

         # 1. MM-Attention
         norm_hidden_states = self.norm1_i(hidden_states).to(dtype=wtype)
@@ -496,8 +479,6 @@ def __init__(
         super().__init__()
         self.num_attention_heads = num_attention_heads
         self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim, 12 * dim, bias=True))
-        nn.init.zeros_(self.adaLN_modulation[1].weight)
-        nn.init.zeros_(self.adaLN_modulation[1].bias)

         # 1. Attention
         self.norm1_i = nn.LayerNorm(dim, eps=1e-06, elementwise_affine=False)
@@ -526,12 +507,12 @@ def __init__(

     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        hidden_states_masks: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        adaln_input: Optional[torch.FloatTensor] = None,
-        image_rotary_emb: torch.FloatTensor = None,
-    ) -> torch.FloatTensor:
+        hidden_states: torch.Tensor,
+        hidden_states_masks: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        temb: Optional[torch.Tensor] = None,
+        image_rotary_emb: torch.Tensor = None,
+    ) -> torch.Tensor:
         wtype = hidden_states.dtype
         (
             shift_msa_i,
@@ -546,7 +527,7 @@ def forward(
             shift_mlp_t,
             scale_mlp_t,
             gate_mlp_t,
-        ) = self.adaLN_modulation(adaln_input)[:, None].chunk(12, dim=-1)
+        ) = self.adaLN_modulation(temb)[:, None].chunk(12, dim=-1)

         # 1. MM-Attention
         norm_hidden_states = self.norm1_i(hidden_states).to(dtype=wtype)
@@ -577,6 +558,28 @@ def forward(
         return hidden_states, encoder_hidden_states


+class HiDreamBlock(nn.Module):
+    def __init__(self, block: Union[HiDreamImageTransformerBlock, HiDreamImageSingleTransformerBlock]):
+        super().__init__()
+        self.block = block
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        hidden_states_masks: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        temb: Optional[torch.Tensor] = None,
+        image_rotary_emb: torch.Tensor = None,
+    ) -> torch.Tensor:
+        return self.block(
+            hidden_states=hidden_states,
+            hidden_states_masks=hidden_states_masks,
+            encoder_hidden_states=encoder_hidden_states,
+            temb=temb,
+            image_rotary_emb=image_rotary_emb,
+        )
+
+
 class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
     _supports_gradient_checkpointing = True
     _no_split_modules = ["HiDreamImageTransformerBlock", "HiDreamImageSingleTransformerBlock"]
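
Aside: the new wrapper delegates to the wrapped block with an identical keyword interface, which lets double- and single-stream blocks be called through one signature. A minimal stand-alone sketch of that delegation pattern, using hypothetical dummy classes rather than the ones in this file:

import torch
import torch.nn as nn

class DummyBlock(nn.Module):
    # Hypothetical inner block; only the keyword interface mirrors the diff.
    def forward(self, hidden_states, hidden_states_masks=None, encoder_hidden_states=None, temb=None, image_rotary_emb=None):
        return hidden_states if temb is None else hidden_states + temb.unsqueeze(1)

class BlockWrapper(nn.Module):
    def __init__(self, block):
        super().__init__()
        self.block = block

    def forward(self, hidden_states, hidden_states_masks=None, encoder_hidden_states=None, temb=None, image_rotary_emb=None):
        # Forward every argument unchanged so the wrapper is transparent.
        return self.block(
            hidden_states=hidden_states,
            hidden_states_masks=hidden_states_masks,
            encoder_hidden_states=encoder_hidden_states,
            temb=temb,
            image_rotary_emb=image_rotary_emb,
        )

wrapped = BlockWrapper(DummyBlock())
out = wrapped(torch.randn(2, 16, 32), temb=torch.randn(2, 32))
print(out.shape)  # torch.Size([2, 16, 32])
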
@@ -615,25 +618,29 @@ def __init__(

         self.double_stream_blocks = nn.ModuleList(
             [
-                HiDreamImageTransformerBlock(
-                    dim=self.inner_dim,
-                    num_attention_heads=self.config.num_attention_heads,
-                    attention_head_dim=self.config.attention_head_dim,
-                    num_routed_experts=num_routed_experts,
-                    num_activated_experts=num_activated_experts,
+                HiDreamBlock(
+                    HiDreamImageTransformerBlock(
+                        dim=self.inner_dim,
+                        num_attention_heads=self.config.num_attention_heads,
+                        attention_head_dim=self.config.attention_head_dim,
+                        num_routed_experts=num_routed_experts,
+                        num_activated_experts=num_activated_experts,
+                    )
                 )
                 for _ in range(self.config.num_layers)
             ]
         )

         self.single_stream_blocks = nn.ModuleList(
             [
-                HiDreamImageSingleTransformerBlock(
-                    dim=self.inner_dim,
-                    num_attention_heads=self.config.num_attention_heads,
-                    attention_head_dim=self.config.attention_head_dim,
-                    num_routed_experts=num_routed_experts,
-                    num_activated_experts=num_activated_experts,
+                HiDreamBlock(
+                    HiDreamImageSingleTransformerBlock(
+                        dim=self.inner_dim,
+                        num_attention_heads=self.config.num_attention_heads,
+                        attention_head_dim=self.config.attention_head_dim,
+                        num_routed_experts=num_routed_experts,
+                        num_activated_experts=num_activated_experts,
+                    )
                 )
                 for _ in range(self.config.num_single_layers)
             ]
@@ -769,7 +776,7 @@ def forward(
         timesteps = self.expand_timesteps(timesteps, batch_size, hidden_states.device)
         timesteps = self.t_embedder(timesteps, hidden_states_type)
         p_embedder = self.p_embedder(pooled_embeds)
-        adaln_input = timesteps + p_embedder
+        temb = timesteps + p_embedder

         hidden_states, hidden_states_masks, img_sizes = self.patchify(hidden_states, self.max_seq, img_sizes)
         if hidden_states_masks is None:
@@ -826,15 +833,15 @@ def forward(
                     hidden_states,
                     hidden_states_masks,
                     cur_encoder_hidden_states,
-                    adaln_input,
+                    temb,
                     image_rotary_emb,
                 )
             else:
                 hidden_states, initial_encoder_hidden_states = block(
                     hidden_states=hidden_states,
                     hidden_states_masks=hidden_states_masks,
                     encoder_hidden_states=cur_encoder_hidden_states,
-                    adaln_input=adaln_input,
+                    temb=temb,
                     image_rotary_emb=image_rotary_emb,
                 )
             initial_encoder_hidden_states = initial_encoder_hidden_states[:, :initial_encoder_hidden_states_seq_len]
@@ -860,22 +867,22 @@ def forward(
                     hidden_states,
                     hidden_states_masks,
                     None,
-                    adaln_input,
+                    temb,
                     image_rotary_emb,
                 )
             else:
                 hidden_states = block(
                     hidden_states=hidden_states,
                     hidden_states_masks=hidden_states_masks,
                     encoder_hidden_states=None,
-                    adaln_input=adaln_input,
+                    temb=temb,
                     image_rotary_emb=image_rotary_emb,
                 )
             hidden_states = hidden_states[:, :hidden_states_seq_len]
             block_id += 1

         hidden_states = hidden_states[:, :image_tokens_seq_len, ...]
-        output = self.final_layer(hidden_states, adaln_input)
+        output = self.final_layer(hidden_states, temb)
         output = self.unpatchify(output, img_sizes, self.training)
         if hidden_states_masks is not None:
             hidden_states_masks = hidden_states_masks[:, :image_tokens_seq_len]
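
Aside: the forward pass above keeps two call paths per block, a checkpointed one that passes inputs positionally and a plain keyword call. A minimal sketch of that branching, assuming torch.utils.checkpoint.checkpoint and a placeholder module rather than the model's internal checkpointing helper:

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

block = nn.Linear(32, 32)  # placeholder for a transformer block
hidden_states = torch.randn(2, 32, requires_grad=True)
gradient_checkpointing = True

if torch.is_grad_enabled() and gradient_checkpointing:
    # Checkpointed path: inputs are forwarded positionally to the module.
    hidden_states_out = checkpoint(block, hidden_states, use_reentrant=False)
else:
    # Plain path: call the module directly with keywords.
    hidden_states_out = block(hidden_states)

hidden_states_out.sum().backward()
print(hidden_states.grad.shape)  # torch.Size([2, 32])
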

0 commit comments
