@@ -183,16 +183,16 @@ def __init__(
 
     def forward(
         self,
-        norm_image_tokens: torch.FloatTensor,
-        image_tokens_masks: torch.FloatTensor = None,
-        norm_text_tokens: torch.FloatTensor = None,
+        norm_hidden_states: torch.FloatTensor,
+        hidden_states_masks: torch.FloatTensor = None,
+        norm_encoder_hidden_states: torch.FloatTensor = None,
         image_rotary_emb: torch.FloatTensor = None,
     ) -> torch.Tensor:
         return self.processor(
             self,
-            image_tokens=norm_image_tokens,
-            image_tokens_masks=image_tokens_masks,
-            text_tokens=norm_text_tokens,
+            hidden_states=norm_hidden_states,
+            hidden_states_masks=hidden_states_masks,
+            encoder_hidden_states=norm_encoder_hidden_states,
             image_rotary_emb=image_rotary_emb,
         )
 
@@ -203,33 +203,33 @@ class HiDreamAttnProcessor:
     def __call__(
         self,
         attn: HiDreamAttention,
-        image_tokens: torch.FloatTensor,
-        image_tokens_masks: Optional[torch.FloatTensor] = None,
-        text_tokens: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.FloatTensor,
+        hidden_states_masks: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
         image_rotary_emb: torch.FloatTensor = None,
         *args,
         **kwargs,
     ) -> torch.FloatTensor:
-        dtype = image_tokens.dtype
-        batch_size = image_tokens.shape[0]
+        dtype = hidden_states.dtype
+        batch_size = hidden_states.shape[0]
 
-        query_i = attn.q_rms_norm(attn.to_q(image_tokens)).to(dtype=dtype)
-        key_i = attn.k_rms_norm(attn.to_k(image_tokens)).to(dtype=dtype)
-        value_i = attn.to_v(image_tokens)
+        query_i = attn.q_rms_norm(attn.to_q(hidden_states)).to(dtype=dtype)
+        key_i = attn.k_rms_norm(attn.to_k(hidden_states)).to(dtype=dtype)
+        value_i = attn.to_v(hidden_states)
 
         inner_dim = key_i.shape[-1]
         head_dim = inner_dim // attn.heads
 
         query_i = query_i.view(batch_size, -1, attn.heads, head_dim)
         key_i = key_i.view(batch_size, -1, attn.heads, head_dim)
         value_i = value_i.view(batch_size, -1, attn.heads, head_dim)
-        if image_tokens_masks is not None:
-            key_i = key_i * image_tokens_masks.view(batch_size, -1, 1, 1)
+        if hidden_states_masks is not None:
+            key_i = key_i * hidden_states_masks.view(batch_size, -1, 1, 1)
 
         if not attn.single:
-            query_t = attn.q_rms_norm_t(attn.to_q_t(text_tokens)).to(dtype=dtype)
-            key_t = attn.k_rms_norm_t(attn.to_k_t(text_tokens)).to(dtype=dtype)
-            value_t = attn.to_v_t(text_tokens)
+            query_t = attn.q_rms_norm_t(attn.to_q_t(encoder_hidden_states)).to(dtype=dtype)
+            key_t = attn.k_rms_norm_t(attn.to_k_t(encoder_hidden_states)).to(dtype=dtype)
+            value_t = attn.to_v_t(encoder_hidden_states)
 
             query_t = query_t.view(batch_size, -1, attn.heads, head_dim)
             key_t = key_t.view(batch_size, -1, attn.heads, head_dim)
@@ -454,33 +454,33 @@ def __init__(
 
     def forward(
        self,
-        image_tokens: torch.FloatTensor,
-        image_tokens_masks: Optional[torch.FloatTensor] = None,
-        text_tokens: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.FloatTensor,
+        hidden_states_masks: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
         adaln_input: Optional[torch.FloatTensor] = None,
         image_rotary_emb: torch.FloatTensor = None,
     ) -> torch.FloatTensor:
-        wtype = image_tokens.dtype
+        wtype = hidden_states.dtype
         shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i = self.adaLN_modulation(
             adaln_input
         )[:, None].chunk(6, dim=-1)
 
         # 1. MM-Attention
-        norm_image_tokens = self.norm1_i(image_tokens).to(dtype=wtype)
-        norm_image_tokens = norm_image_tokens * (1 + scale_msa_i) + shift_msa_i
+        norm_hidden_states = self.norm1_i(hidden_states).to(dtype=wtype)
+        norm_hidden_states = norm_hidden_states * (1 + scale_msa_i) + shift_msa_i
         attn_output_i = self.attn1(
-            norm_image_tokens,
-            image_tokens_masks,
+            norm_hidden_states,
+            hidden_states_masks,
             image_rotary_emb=image_rotary_emb,
         )
-        image_tokens = gate_msa_i * attn_output_i + image_tokens
+        hidden_states = gate_msa_i * attn_output_i + hidden_states
 
         # 2. Feed-forward
-        norm_image_tokens = self.norm3_i(image_tokens).to(dtype=wtype)
-        norm_image_tokens = norm_image_tokens * (1 + scale_mlp_i) + shift_mlp_i
-        ff_output_i = gate_mlp_i * self.ff_i(norm_image_tokens.to(dtype=wtype))
-        image_tokens = ff_output_i + image_tokens
-        return image_tokens
+        norm_hidden_states = self.norm3_i(hidden_states).to(dtype=wtype)
+        norm_hidden_states = norm_hidden_states * (1 + scale_mlp_i) + shift_mlp_i
+        ff_output_i = gate_mlp_i * self.ff_i(norm_hidden_states.to(dtype=wtype))
+        hidden_states = ff_output_i + hidden_states
+        return hidden_states
 
 
 @maybe_allow_in_graph
@@ -526,13 +526,13 @@ def __init__(
 
     def forward(
         self,
-        image_tokens: torch.FloatTensor,
-        image_tokens_masks: Optional[torch.FloatTensor] = None,
-        text_tokens: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.FloatTensor,
+        hidden_states_masks: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
         adaln_input: Optional[torch.FloatTensor] = None,
         image_rotary_emb: torch.FloatTensor = None,
     ) -> torch.FloatTensor:
-        wtype = image_tokens.dtype
+        wtype = hidden_states.dtype
         (
             shift_msa_i,
             scale_msa_i,
@@ -549,74 +549,37 @@ def forward(
         ) = self.adaLN_modulation(adaln_input)[:, None].chunk(12, dim=-1)
 
         # 1. MM-Attention
-        norm_image_tokens = self.norm1_i(image_tokens).to(dtype=wtype)
-        norm_image_tokens = norm_image_tokens * (1 + scale_msa_i) + shift_msa_i
-        norm_text_tokens = self.norm1_t(text_tokens).to(dtype=wtype)
-        norm_text_tokens = norm_text_tokens * (1 + scale_msa_t) + shift_msa_t
+        norm_hidden_states = self.norm1_i(hidden_states).to(dtype=wtype)
+        norm_hidden_states = norm_hidden_states * (1 + scale_msa_i) + shift_msa_i
+        norm_encoder_hidden_states = self.norm1_t(encoder_hidden_states).to(dtype=wtype)
+        norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + scale_msa_t) + shift_msa_t
 
         attn_output_i, attn_output_t = self.attn1(
-            norm_image_tokens,
-            image_tokens_masks,
-            norm_text_tokens,
+            norm_hidden_states,
+            hidden_states_masks,
+            norm_encoder_hidden_states,
             image_rotary_emb=image_rotary_emb,
         )
 
-        image_tokens = gate_msa_i * attn_output_i + image_tokens
-        text_tokens = gate_msa_t * attn_output_t + text_tokens
+        hidden_states = gate_msa_i * attn_output_i + hidden_states
+        encoder_hidden_states = gate_msa_t * attn_output_t + encoder_hidden_states
 
         # 2. Feed-forward
-        norm_image_tokens = self.norm3_i(image_tokens).to(dtype=wtype)
-        norm_image_tokens = norm_image_tokens * (1 + scale_mlp_i) + shift_mlp_i
-        norm_text_tokens = self.norm3_t(text_tokens).to(dtype=wtype)
-        norm_text_tokens = norm_text_tokens * (1 + scale_mlp_t) + shift_mlp_t
-
-        ff_output_i = gate_mlp_i * self.ff_i(norm_image_tokens)
-        ff_output_t = gate_mlp_t * self.ff_t(norm_text_tokens)
-        image_tokens = ff_output_i + image_tokens
-        text_tokens = ff_output_t + text_tokens
-        return image_tokens, text_tokens
+        norm_hidden_states = self.norm3_i(hidden_states).to(dtype=wtype)
+        norm_hidden_states = norm_hidden_states * (1 + scale_mlp_i) + shift_mlp_i
+        norm_encoder_hidden_states = self.norm3_t(encoder_hidden_states).to(dtype=wtype)
+        norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + scale_mlp_t) + shift_mlp_t
 
-
-@maybe_allow_in_graph
-class HiDreamImageBlock(nn.Module):
-    def __init__(
-        self,
-        dim: int,
-        num_attention_heads: int,
-        attention_head_dim: int,
-        num_routed_experts: int = 4,
-        num_activated_experts: int = 2,
-        block_type: BlockType = BlockType.TransformerBlock,
-    ):
-        super().__init__()
-        block_classes = {
-            BlockType.TransformerBlock: HiDreamImageTransformerBlock,
-            BlockType.SingleTransformerBlock: HiDreamImageSingleTransformerBlock,
-        }
-        self.block = block_classes[block_type](
-            dim, num_attention_heads, attention_head_dim, num_routed_experts, num_activated_experts
-        )
-
-    def forward(
-        self,
-        image_tokens: torch.FloatTensor,
-        image_tokens_masks: Optional[torch.FloatTensor] = None,
-        text_tokens: Optional[torch.FloatTensor] = None,
-        adaln_input: torch.FloatTensor = None,
-        image_rotary_emb: torch.FloatTensor = None,
-    ) -> torch.FloatTensor:
-        return self.block(
-            image_tokens,
-            image_tokens_masks,
-            text_tokens,
-            adaln_input,
-            image_rotary_emb,
-        )
+        ff_output_i = gate_mlp_i * self.ff_i(norm_hidden_states)
+        ff_output_t = gate_mlp_t * self.ff_t(norm_encoder_hidden_states)
+        hidden_states = ff_output_i + hidden_states
+        encoder_hidden_states = ff_output_t + encoder_hidden_states
+        return hidden_states, encoder_hidden_states
 
 
 class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
     _supports_gradient_checkpointing = True
-    _no_split_modules = ["HiDreamImageBlock"]
+    _no_split_modules = ["HiDreamImageTransformerBlock", "HiDreamImageSingleTransformerBlock"]
 
     @register_to_config
     def __init__(
@@ -652,29 +615,27 @@ def __init__(
 
         self.double_stream_blocks = nn.ModuleList(
             [
-                HiDreamImageBlock(
+                HiDreamImageTransformerBlock(
                     dim=self.inner_dim,
                     num_attention_heads=self.config.num_attention_heads,
                     attention_head_dim=self.config.attention_head_dim,
                     num_routed_experts=num_routed_experts,
                     num_activated_experts=num_activated_experts,
-                    block_type=BlockType.TransformerBlock,
                 )
-                for i in range(self.config.num_layers)
+                for _ in range(self.config.num_layers)
             ]
         )
 
         self.single_stream_blocks = nn.ModuleList(
             [
-                HiDreamImageBlock(
+                HiDreamImageSingleTransformerBlock(
                     dim=self.inner_dim,
                     num_attention_heads=self.config.num_attention_heads,
                     attention_head_dim=self.config.attention_head_dim,
                     num_routed_experts=num_routed_experts,
                     num_activated_experts=num_activated_experts,
-                    block_type=BlockType.SingleTransformerBlock,
                 )
-                for i in range(self.config.num_single_layers)
+                for _ in range(self.config.num_single_layers)
             ]
         )
 
@@ -816,8 +777,8 @@ def forward(
         p_embedder = self.p_embedder(pooled_embeds)
         adaln_input = timesteps + p_embedder
 
-        hidden_states, image_tokens_masks, img_sizes = self.patchify(hidden_states, self.max_seq, img_sizes)
-        if image_tokens_masks is None:
+        hidden_states, hidden_states_masks, img_sizes = self.patchify(hidden_states, self.max_seq, img_sizes)
+        if hidden_states_masks is None:
             pH, pW = img_sizes[0]
             img_ids = torch.zeros(pH, pW, 3, device=hidden_states.device)
             img_ids[..., 1] = img_ids[..., 1] + torch.arange(pH, device=hidden_states.device)[:, None]
@@ -869,16 +830,16 @@ def forward(
                 hidden_states, initial_encoder_hidden_states = self._gradient_checkpointing_func(
                     block,
                     hidden_states,
-                    image_tokens_masks,
+                    hidden_states_masks,
                     cur_encoder_hidden_states,
                     adaln_input,
                     image_rotary_emb,
                 )
             else:
                 hidden_states, initial_encoder_hidden_states = block(
-                    image_tokens=hidden_states,
-                    image_tokens_masks=image_tokens_masks,
-                    text_tokens=cur_encoder_hidden_states,
+                    hidden_states=hidden_states,
+                    hidden_states_masks=hidden_states_masks,
+                    encoder_hidden_states=cur_encoder_hidden_states,
                     adaln_input=adaln_input,
                     image_rotary_emb=image_rotary_emb,
                 )
@@ -888,13 +849,13 @@ def forward(
         image_tokens_seq_len = hidden_states.shape[1]
         hidden_states = torch.cat([hidden_states, initial_encoder_hidden_states], dim=1)
         hidden_states_seq_len = hidden_states.shape[1]
-        if image_tokens_masks is not None:
+        if hidden_states_masks is not None:
             encoder_attention_mask_ones = torch.ones(
                 (batch_size, initial_encoder_hidden_states.shape[1] + cur_llama31_encoder_hidden_states.shape[1]),
-                device=image_tokens_masks.device,
-                dtype=image_tokens_masks.dtype,
+                device=hidden_states_masks.device,
+                dtype=hidden_states_masks.dtype,
             )
-            image_tokens_masks = torch.cat([image_tokens_masks, encoder_attention_mask_ones], dim=1)
+            hidden_states_masks = torch.cat([hidden_states_masks, encoder_attention_mask_ones], dim=1)
 
         for bid, block in enumerate(self.single_stream_blocks):
             cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id]
@@ -903,16 +864,16 @@ def forward(
                 hidden_states = self._gradient_checkpointing_func(
                     block,
                     hidden_states,
-                    image_tokens_masks,
+                    hidden_states_masks,
                     None,
                     adaln_input,
                     image_rotary_emb,
                 )
             else:
                 hidden_states = block(
-                    image_tokens=hidden_states,
-                    image_tokens_masks=image_tokens_masks,
-                    text_tokens=None,
+                    hidden_states=hidden_states,
+                    hidden_states_masks=hidden_states_masks,
+                    encoder_hidden_states=None,
                     adaln_input=adaln_input,
                     image_rotary_emb=image_rotary_emb,
                 )
@@ -922,13 +883,13 @@ def forward(
         hidden_states = hidden_states[:, :image_tokens_seq_len, ...]
         output = self.final_layer(hidden_states, adaln_input)
        output = self.unpatchify(output, img_sizes, self.training)
-        if image_tokens_masks is not None:
-            image_tokens_masks = image_tokens_masks[:, :image_tokens_seq_len]
+        if hidden_states_masks is not None:
+            hidden_states_masks = hidden_states_masks[:, :image_tokens_seq_len]
 
         if USE_PEFT_BACKEND:
             # remove `lora_scale` from each PEFT layer
             unscale_lora_layers(self, lora_scale)
 
         if not return_dict:
-            return (output, image_tokens_masks)
-        return Transformer2DModelOutput(sample=output, mask=image_tokens_masks)
+            return (output, hidden_states_masks)
+        return Transformer2DModelOutput(sample=output, mask=hidden_states_masks)
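
Note: this commit renames the public keyword arguments on the block and attention forward methods (image_tokens -> hidden_states, image_tokens_masks -> hidden_states_masks, text_tokens -> encoder_hidden_states), so external callers or custom attention processors still passing the old names will fail with a TypeError. The sketch below is a minimal, purely illustrative compatibility shim and is not part of this commit; the helper name remap_legacy_kwargs and the warning wording are assumptions.

# Hypothetical sketch (not in the diff): remap the pre-rename keyword names to
# the new ones during a deprecation window, warning callers as they go.
import functools
import warnings

_LEGACY_TO_NEW = {
    "image_tokens": "hidden_states",
    "image_tokens_masks": "hidden_states_masks",
    "text_tokens": "encoder_hidden_states",
}


def remap_legacy_kwargs(forward):
    """Wrap a forward/__call__ method so calls using the old keywords still work."""

    @functools.wraps(forward)
    def wrapper(self, *args, **kwargs):
        for old, new in _LEGACY_TO_NEW.items():
            if old in kwargs:
                warnings.warn(f"`{old}` is deprecated, pass `{new}` instead.", FutureWarning)
                kwargs[new] = kwargs.pop(old)
        return forward(self, *args, **kwargs)

    return wrapper

Applied as a decorator on the renamed forward methods, this would keep old call sites running while emitting a FutureWarning; the commit itself drops the old keyword names outright.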