-import math
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
 from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import Attention
-from ..embeddings import (
-    TimestepEmbedding,
-    Timesteps,
-)
+from ..embeddings import TimestepEmbedding, Timesteps
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -40,7 +36,7 @@ def __init__(
         self.w2 = nn.Linear(hidden_dim, dim, bias=False)
         self.w3 = nn.Linear(dim, hidden_dim, bias=False)
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
 
 
@@ -49,7 +45,7 @@ def __init__(self, text_emb_dim, hidden_size):
         super().__init__()
         self.pooled_embedder = TimestepEmbedding(in_channels=text_emb_dim, time_embed_dim=hidden_size)
 
-    def forward(self, pooled_embed):
+    def forward(self, pooled_embed: torch.Tensor) -> torch.Tensor:
         return self.pooled_embedder(pooled_embed)
 
 
@@ -59,7 +55,7 @@ def __init__(self, hidden_size, frequency_embedding_size=256):
         self.time_proj = Timesteps(num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0)
         self.timestep_embedder = TimestepEmbedding(in_channels=frequency_embedding_size, time_embed_dim=hidden_size)
 
-    def forward(self, timesteps, wdtype):
+    def forward(self, timesteps: torch.Tensor, wdtype: Optional[torch.dtype] = None):
         t_emb = self.time_proj(timesteps).to(dtype=wdtype)
         t_emb = self.timestep_embedder(t_emb)
         return t_emb
@@ -72,11 +68,11 @@ def __init__(self, hidden_size, patch_size, out_channels):
         self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
         self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
 
-    def forward(self, x, adaln_input):
-        shift, scale = self.adaLN_modulation(adaln_input).chunk(2, dim=1)
-        x = self.norm_final(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
-        x = self.linear(x)
-        return x
+    def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor) -> torch.Tensor:
+        shift, scale = self.adaLN_modulation(temb).chunk(2, dim=1)
+        hidden_states = self.norm_final(hidden_states) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
+        hidden_states = self.linear(hidden_states)
+        return hidden_states
 
 
 class HiDreamImagePatchEmbed(nn.Module):
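For readers unfamiliar with the adaLN pattern used by the output layer above, here is a minimal, self-contained sketch of the shift/scale modulation; the sizes are assumed toy values, not the model's configuration.

import torch
import torch.nn as nn

# Toy sizes only; real values come from the model config.
hidden_size, patch_size, out_channels = 32, 2, 4
norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False)
linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))

hidden_states = torch.randn(1, 16, hidden_size)  # (batch, image tokens, hidden)
temb = torch.randn(1, hidden_size)               # timestep + pooled text conditioning

# temb is projected to a per-sample (shift, scale) pair that broadcasts over the token axis.
shift, scale = adaLN_modulation(temb).chunk(2, dim=1)
out = linear(norm_final(hidden_states) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1))
print(out.shape)  # torch.Size([1, 16, 16])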
@@ -183,10 +179,10 @@ def __init__(
 
     def forward(
         self,
-        norm_hidden_states: torch.FloatTensor,
-        hidden_states_masks: torch.FloatTensor = None,
-        norm_encoder_hidden_states: torch.FloatTensor = None,
-        image_rotary_emb: torch.FloatTensor = None,
+        norm_hidden_states: torch.Tensor,
+        hidden_states_masks: torch.Tensor = None,
+        norm_encoder_hidden_states: torch.Tensor = None,
+        image_rotary_emb: torch.Tensor = None,
     ) -> torch.Tensor:
         return self.processor(
             self,
@@ -203,13 +199,13 @@ class HiDreamAttnProcessor:
     def __call__(
         self,
         attn: HiDreamAttention,
-        hidden_states: torch.FloatTensor,
-        hidden_states_masks: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        image_rotary_emb: torch.FloatTensor = None,
+        hidden_states: torch.Tensor,
+        hidden_states_masks: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        image_rotary_emb: torch.Tensor = None,
         *args,
         **kwargs,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         dtype = hidden_states.dtype
         batch_size = hidden_states.shape[0]
 
@@ -286,13 +282,7 @@ def __init__(self, embed_dim, num_routed_experts=4, num_activated_experts=2, aux
         # topk selection algorithm
         self.norm_topk_prob = False
         self.gating_dim = embed_dim
-        self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))
-        self.reset_parameters()
-
-    def reset_parameters(self) -> None:
-        import torch.nn.init as init
-
-        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
+        self.weight = nn.Parameter(torch.randn(self.n_routed_experts, self.gating_dim) / embed_dim**0.5)
 
     def forward(self, hidden_states):
         bsz, seq_len, h = hidden_states.shape
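A quick standalone check of the scale produced by the gate-weight initialization above; the `embed_dim` value is an assumed example, not read from any checkpoint.

import torch

n_routed_experts, embed_dim = 4, 2560  # assumed example values
weight = torch.randn(n_routed_experts, embed_dim) / embed_dim**0.5
# Entries are N(0, 1) scaled by embed_dim ** -0.5, so the resulting std is about 0.02 here.
print(weight.std())  # ≈ embed_dim ** -0.5 ≈ 0.0198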
@@ -409,11 +399,6 @@ def forward(self, caption):
         return hidden_states
 
 
-class BlockType:
-    TransformerBlock = 1
-    SingleTransformerBlock = 2
-
-
 @maybe_allow_in_graph
 class HiDreamImageSingleTransformerBlock(nn.Module):
     def __init__(
@@ -427,8 +412,6 @@ def __init__(
         super().__init__()
         self.num_attention_heads = num_attention_heads
         self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim, 6 * dim, bias=True))
-        nn.init.zeros_(self.adaLN_modulation[1].weight)
-        nn.init.zeros_(self.adaLN_modulation[1].bias)
 
         # 1. Attention
         self.norm1_i = nn.LayerNorm(dim, eps=1e-06, elementwise_affine=False)
@@ -454,16 +437,16 @@ def __init__(
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        hidden_states_masks: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        adaln_input: Optional[torch.FloatTensor] = None,
-        image_rotary_emb: torch.FloatTensor = None,
-    ) -> torch.FloatTensor:
+        hidden_states: torch.Tensor,
+        hidden_states_masks: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        temb: Optional[torch.Tensor] = None,
+        image_rotary_emb: torch.Tensor = None,
+    ) -> torch.Tensor:
         wtype = hidden_states.dtype
-        shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i = self.adaLN_modulation(
-            adaln_input
-        )[:, None].chunk(6, dim=-1)
+        shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i = self.adaLN_modulation(temb)[
+            :, None
+        ].chunk(6, dim=-1)
 
         # 1. MM-Attention
         norm_hidden_states = self.norm1_i(hidden_states).to(dtype=wtype)
@@ -496,8 +479,6 @@ def __init__(
         super().__init__()
         self.num_attention_heads = num_attention_heads
         self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim, 12 * dim, bias=True))
-        nn.init.zeros_(self.adaLN_modulation[1].weight)
-        nn.init.zeros_(self.adaLN_modulation[1].bias)
 
         # 1. Attention
         self.norm1_i = nn.LayerNorm(dim, eps=1e-06, elementwise_affine=False)
@@ -526,12 +507,12 @@ def __init__(
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        hidden_states_masks: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        adaln_input: Optional[torch.FloatTensor] = None,
-        image_rotary_emb: torch.FloatTensor = None,
-    ) -> torch.FloatTensor:
+        hidden_states: torch.Tensor,
+        hidden_states_masks: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        temb: Optional[torch.Tensor] = None,
+        image_rotary_emb: torch.Tensor = None,
+    ) -> torch.Tensor:
         wtype = hidden_states.dtype
         (
             shift_msa_i,
@@ -546,7 +527,7 @@ def forward(
             shift_mlp_t,
             scale_mlp_t,
             gate_mlp_t,
-        ) = self.adaLN_modulation(adaln_input)[:, None].chunk(12, dim=-1)
+        ) = self.adaLN_modulation(temb)[:, None].chunk(12, dim=-1)
 
         # 1. MM-Attention
         norm_hidden_states = self.norm1_i(hidden_states).to(dtype=wtype)
@@ -577,6 +558,28 @@ def forward(
         return hidden_states, encoder_hidden_states
 
 
+class HiDreamBlock(nn.Module):
+    def __init__(self, block: Union[HiDreamImageTransformerBlock, HiDreamImageSingleTransformerBlock]):
+        super().__init__()
+        self.block = block
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        hidden_states_masks: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        temb: Optional[torch.Tensor] = None,
+        image_rotary_emb: torch.Tensor = None,
+    ) -> torch.Tensor:
+        return self.block(
+            hidden_states=hidden_states,
+            hidden_states_masks=hidden_states_masks,
+            encoder_hidden_states=encoder_hidden_states,
+            temb=temb,
+            image_rotary_emb=image_rotary_emb,
+        )
+
+
 class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
     _supports_gradient_checkpointing = True
     _no_split_modules = ["HiDreamImageTransformerBlock", "HiDreamImageSingleTransformerBlock"]
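Illustrative use of the new `HiDreamBlock` wrapper, as a sketch only (toy sizes assumed, and the classes from this module assumed to be in scope): it gives the double- and single-stream blocks one uniform keyword signature, so the block lists below and the gradient-checkpointing path can call either variant identically.

wrapped = HiDreamBlock(
    HiDreamImageSingleTransformerBlock(dim=256, num_attention_heads=4, attention_head_dim=64)
)
# Every keyword is simply forwarded to the inner block:
# wrapped(hidden_states=..., hidden_states_masks=..., encoder_hidden_states=...,
#         temb=..., image_rotary_emb=...)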
@@ -615,25 +618,29 @@ def __init__(
 
         self.double_stream_blocks = nn.ModuleList(
             [
-                HiDreamImageTransformerBlock(
-                    dim=self.inner_dim,
-                    num_attention_heads=self.config.num_attention_heads,
-                    attention_head_dim=self.config.attention_head_dim,
-                    num_routed_experts=num_routed_experts,
-                    num_activated_experts=num_activated_experts,
+                HiDreamBlock(
+                    HiDreamImageTransformerBlock(
+                        dim=self.inner_dim,
+                        num_attention_heads=self.config.num_attention_heads,
+                        attention_head_dim=self.config.attention_head_dim,
+                        num_routed_experts=num_routed_experts,
+                        num_activated_experts=num_activated_experts,
+                    )
                 )
                 for _ in range(self.config.num_layers)
             ]
         )
 
         self.single_stream_blocks = nn.ModuleList(
             [
-                HiDreamImageSingleTransformerBlock(
-                    dim=self.inner_dim,
-                    num_attention_heads=self.config.num_attention_heads,
-                    attention_head_dim=self.config.attention_head_dim,
-                    num_routed_experts=num_routed_experts,
-                    num_activated_experts=num_activated_experts,
+                HiDreamBlock(
+                    HiDreamImageSingleTransformerBlock(
+                        dim=self.inner_dim,
+                        num_attention_heads=self.config.num_attention_heads,
+                        attention_head_dim=self.config.attention_head_dim,
+                        num_routed_experts=num_routed_experts,
+                        num_activated_experts=num_activated_experts,
+                    )
                 )
                 for _ in range(self.config.num_single_layers)
             ]
@@ -769,7 +776,7 @@ def forward(
         timesteps = self.expand_timesteps(timesteps, batch_size, hidden_states.device)
         timesteps = self.t_embedder(timesteps, hidden_states_type)
         p_embedder = self.p_embedder(pooled_embeds)
-        adaln_input = timesteps + p_embedder
+        temb = timesteps + p_embedder
 
         hidden_states, hidden_states_masks, img_sizes = self.patchify(hidden_states, self.max_seq, img_sizes)
         if hidden_states_masks is None:
@@ -826,15 +833,15 @@ def forward(
                     hidden_states,
                     hidden_states_masks,
                     cur_encoder_hidden_states,
-                    adaln_input,
+                    temb,
                     image_rotary_emb,
                 )
             else:
                 hidden_states, initial_encoder_hidden_states = block(
                     hidden_states=hidden_states,
                     hidden_states_masks=hidden_states_masks,
                     encoder_hidden_states=cur_encoder_hidden_states,
-                    adaln_input=adaln_input,
+                    temb=temb,
                     image_rotary_emb=image_rotary_emb,
                 )
             initial_encoder_hidden_states = initial_encoder_hidden_states[:, :initial_encoder_hidden_states_seq_len]
@@ -860,22 +867,22 @@ def forward(
                     hidden_states,
                     hidden_states_masks,
                     None,
-                    adaln_input,
+                    temb,
                     image_rotary_emb,
                 )
             else:
                 hidden_states = block(
                     hidden_states=hidden_states,
                     hidden_states_masks=hidden_states_masks,
                     encoder_hidden_states=None,
-                    adaln_input=adaln_input,
+                    temb=temb,
                     image_rotary_emb=image_rotary_emb,
                 )
             hidden_states = hidden_states[:, :hidden_states_seq_len]
             block_id += 1
 
         hidden_states = hidden_states[:, :image_tokens_seq_len, ...]
-        output = self.final_layer(hidden_states, adaln_input)
+        output = self.final_layer(hidden_states, temb)
         output = self.unpatchify(output, img_sizes, self.training)
         if hidden_states_masks is not None:
             hidden_states_masks = hidden_states_masks[:, :image_tokens_seq_len]
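A standalone sketch (assumed toy sizes) of how the conditioning tensor now named `temb` is composed: a sinusoidal timestep embedding plus a pooled text-prompt embedding, both projected to the transformer width before driving the adaLN modulations and the final layer.

import torch
from diffusers.models.embeddings import TimestepEmbedding, Timesteps

hidden_size, freq_dim, text_emb_dim = 64, 256, 32  # assumed toy sizes
time_proj = Timesteps(num_channels=freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
t_embedder = TimestepEmbedding(in_channels=freq_dim, time_embed_dim=hidden_size)
p_embedder = TimestepEmbedding(in_channels=text_emb_dim, time_embed_dim=hidden_size)

timesteps = torch.tensor([10.0])
pooled_embeds = torch.randn(1, text_emb_dim)
temb = t_embedder(time_proj(timesteps)) + p_embedder(pooled_embeds)  # (1, hidden_size)
print(temb.shape)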