@@ -183,16 +183,16 @@ def __init__(
 
     def forward(
         self,
-        norm_image_tokens: torch.FloatTensor,
-        image_tokens_masks: torch.FloatTensor = None,
-        norm_text_tokens: torch.FloatTensor = None,
+        norm_hidden_states: torch.FloatTensor,
+        hidden_states_masks: torch.FloatTensor = None,
+        norm_encoder_hidden_states: torch.FloatTensor = None,
         image_rotary_emb: torch.FloatTensor = None,
     ) -> torch.Tensor:
         return self.processor(
             self,
-            image_tokens=norm_image_tokens,
-            image_tokens_masks=image_tokens_masks,
-            text_tokens=norm_text_tokens,
+            hidden_states=norm_hidden_states,
+            hidden_states_masks=hidden_states_masks,
+            encoder_hidden_states=norm_encoder_hidden_states,
             image_rotary_emb=image_rotary_emb,
         )
 
@@ -203,33 +203,33 @@ class HiDreamAttnProcessor:
     def __call__(
         self,
         attn: HiDreamAttention,
-        image_tokens: torch.FloatTensor,
-        image_tokens_masks: Optional[torch.FloatTensor] = None,
-        text_tokens: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.FloatTensor,
+        hidden_states_masks: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
         image_rotary_emb: torch.FloatTensor = None,
         *args,
         **kwargs,
     ) -> torch.FloatTensor:
-        dtype = image_tokens.dtype
-        batch_size = image_tokens.shape[0]
+        dtype = hidden_states.dtype
+        batch_size = hidden_states.shape[0]
 
-        query_i = attn.q_rms_norm(attn.to_q(image_tokens)).to(dtype=dtype)
-        key_i = attn.k_rms_norm(attn.to_k(image_tokens)).to(dtype=dtype)
-        value_i = attn.to_v(image_tokens)
+        query_i = attn.q_rms_norm(attn.to_q(hidden_states)).to(dtype=dtype)
+        key_i = attn.k_rms_norm(attn.to_k(hidden_states)).to(dtype=dtype)
+        value_i = attn.to_v(hidden_states)
 
         inner_dim = key_i.shape[-1]
         head_dim = inner_dim // attn.heads
 
         query_i = query_i.view(batch_size, -1, attn.heads, head_dim)
         key_i = key_i.view(batch_size, -1, attn.heads, head_dim)
         value_i = value_i.view(batch_size, -1, attn.heads, head_dim)
-        if image_tokens_masks is not None:
-            key_i = key_i * image_tokens_masks.view(batch_size, -1, 1, 1)
+        if hidden_states_masks is not None:
+            key_i = key_i * hidden_states_masks.view(batch_size, -1, 1, 1)
 
         if not attn.single:
-            query_t = attn.q_rms_norm_t(attn.to_q_t(text_tokens)).to(dtype=dtype)
-            key_t = attn.k_rms_norm_t(attn.to_k_t(text_tokens)).to(dtype=dtype)
-            value_t = attn.to_v_t(text_tokens)
+            query_t = attn.q_rms_norm_t(attn.to_q_t(encoder_hidden_states)).to(dtype=dtype)
+            key_t = attn.k_rms_norm_t(attn.to_k_t(encoder_hidden_states)).to(dtype=dtype)
+            value_t = attn.to_v_t(encoder_hidden_states)
 
             query_t = query_t.view(batch_size, -1, attn.heads, head_dim)
             key_t = key_t.view(batch_size, -1, attn.heads, head_dim)
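The mask multiplication above zeroes the key vectors at padded positions. A small toy illustration of just that reshape-and-broadcast step, with made-up shapes (illustration only, not part of the diff):

import torch

# Padded positions (mask == 0) get their key vectors zeroed via broadcasting.
batch_size, seq_len, heads, head_dim = 2, 4, 2, 8
key_i = torch.randn(batch_size, seq_len, heads, head_dim)
hidden_states_masks = torch.tensor([[1.0, 1.0, 1.0, 0.0],
                                    [1.0, 1.0, 0.0, 0.0]])
key_i = key_i * hidden_states_masks.view(batch_size, -1, 1, 1)
assert torch.all(key_i[0, 3] == 0) and torch.all(key_i[1, 2:] == 0)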
@@ -454,33 +454,33 @@ def __init__(
 
     def forward(
         self,
-        image_tokens: torch.FloatTensor,
-        image_tokens_masks: Optional[torch.FloatTensor] = None,
-        text_tokens: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.FloatTensor,
+        hidden_states_masks: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
         adaln_input: Optional[torch.FloatTensor] = None,
         image_rotary_emb: torch.FloatTensor = None,
     ) -> torch.FloatTensor:
-        wtype = image_tokens.dtype
+        wtype = hidden_states.dtype
         shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i = self.adaLN_modulation(
             adaln_input
         )[:, None].chunk(6, dim=-1)
 
         # 1. MM-Attention
-        norm_image_tokens = self.norm1_i(image_tokens).to(dtype=wtype)
-        norm_image_tokens = norm_image_tokens * (1 + scale_msa_i) + shift_msa_i
+        norm_hidden_states = self.norm1_i(hidden_states).to(dtype=wtype)
+        norm_hidden_states = norm_hidden_states * (1 + scale_msa_i) + shift_msa_i
         attn_output_i = self.attn1(
-            norm_image_tokens,
-            image_tokens_masks,
+            norm_hidden_states,
+            hidden_states_masks,
             image_rotary_emb=image_rotary_emb,
         )
-        image_tokens = gate_msa_i * attn_output_i + image_tokens
+        hidden_states = gate_msa_i * attn_output_i + hidden_states
 
         # 2. Feed-forward
-        norm_image_tokens = self.norm3_i(image_tokens).to(dtype=wtype)
-        norm_image_tokens = norm_image_tokens * (1 + scale_mlp_i) + shift_mlp_i
-        ff_output_i = gate_mlp_i * self.ff_i(norm_image_tokens.to(dtype=wtype))
-        image_tokens = ff_output_i + image_tokens
-        return image_tokens
+        norm_hidden_states = self.norm3_i(hidden_states).to(dtype=wtype)
+        norm_hidden_states = norm_hidden_states * (1 + scale_mlp_i) + shift_mlp_i
+        ff_output_i = gate_mlp_i * self.ff_i(norm_hidden_states.to(dtype=wtype))
+        hidden_states = ff_output_i + hidden_states
+        return hidden_states
 
 
 @maybe_allow_in_graph
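For reference, the `adaLN_modulation(...)[:, None].chunk(6, dim=-1)` pattern in the context above splits one projection of the conditioning vector into shift/scale/gate pieces that broadcast over the token axis. A minimal sketch with made-up sizes (illustration only, not the library code):

import torch

# One projection of the conditioning vector is chunked into shift/scale/gate pieces.
dim, batch_size, num_tokens = 16, 2, 10
adaln_out = torch.randn(batch_size, 6 * dim)           # stand-in for self.adaLN_modulation(adaln_input)
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = adaln_out[:, None].chunk(6, dim=-1)
x = torch.randn(batch_size, num_tokens, dim)
x_mod = x * (1 + scale_msa) + shift_msa                # (batch, 1, dim) broadcasts over the token axis
print(x_mod.shape)                                     # torch.Size([2, 10, 16])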
@@ -526,13 +526,13 @@ def __init__(
 
     def forward(
         self,
-        image_tokens: torch.FloatTensor,
-        image_tokens_masks: Optional[torch.FloatTensor] = None,
-        text_tokens: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.FloatTensor,
+        hidden_states_masks: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
         adaln_input: Optional[torch.FloatTensor] = None,
         image_rotary_emb: torch.FloatTensor = None,
     ) -> torch.FloatTensor:
-        wtype = image_tokens.dtype
+        wtype = hidden_states.dtype
         (
             shift_msa_i,
             scale_msa_i,
@@ -549,74 +549,37 @@ def forward(
         ) = self.adaLN_modulation(adaln_input)[:, None].chunk(12, dim=-1)
 
         # 1. MM-Attention
-        norm_image_tokens = self.norm1_i(image_tokens).to(dtype=wtype)
-        norm_image_tokens = norm_image_tokens * (1 + scale_msa_i) + shift_msa_i
-        norm_text_tokens = self.norm1_t(text_tokens).to(dtype=wtype)
-        norm_text_tokens = norm_text_tokens * (1 + scale_msa_t) + shift_msa_t
+        norm_hidden_states = self.norm1_i(hidden_states).to(dtype=wtype)
+        norm_hidden_states = norm_hidden_states * (1 + scale_msa_i) + shift_msa_i
+        norm_encoder_hidden_states = self.norm1_t(encoder_hidden_states).to(dtype=wtype)
+        norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + scale_msa_t) + shift_msa_t
 
         attn_output_i, attn_output_t = self.attn1(
-            norm_image_tokens,
-            image_tokens_masks,
-            norm_text_tokens,
+            norm_hidden_states,
+            hidden_states_masks,
+            norm_encoder_hidden_states,
             image_rotary_emb=image_rotary_emb,
         )
 
-        image_tokens = gate_msa_i * attn_output_i + image_tokens
-        text_tokens = gate_msa_t * attn_output_t + text_tokens
+        hidden_states = gate_msa_i * attn_output_i + hidden_states
+        encoder_hidden_states = gate_msa_t * attn_output_t + encoder_hidden_states
 
         # 2. Feed-forward
-        norm_image_tokens = self.norm3_i(image_tokens).to(dtype=wtype)
-        norm_image_tokens = norm_image_tokens * (1 + scale_mlp_i) + shift_mlp_i
-        norm_text_tokens = self.norm3_t(text_tokens).to(dtype=wtype)
-        norm_text_tokens = norm_text_tokens * (1 + scale_mlp_t) + shift_mlp_t
-
-        ff_output_i = gate_mlp_i * self.ff_i(norm_image_tokens)
-        ff_output_t = gate_mlp_t * self.ff_t(norm_text_tokens)
-        image_tokens = ff_output_i + image_tokens
-        text_tokens = ff_output_t + text_tokens
-        return image_tokens, text_tokens
+        norm_hidden_states = self.norm3_i(hidden_states).to(dtype=wtype)
+        norm_hidden_states = norm_hidden_states * (1 + scale_mlp_i) + shift_mlp_i
+        norm_encoder_hidden_states = self.norm3_t(encoder_hidden_states).to(dtype=wtype)
+        norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + scale_mlp_t) + shift_mlp_t
 
-
-@maybe_allow_in_graph
-class HiDreamImageBlock(nn.Module):
-    def __init__(
-        self,
-        dim: int,
-        num_attention_heads: int,
-        attention_head_dim: int,
-        num_routed_experts: int = 4,
-        num_activated_experts: int = 2,
-        block_type: BlockType = BlockType.TransformerBlock,
-    ):
-        super().__init__()
-        block_classes = {
-            BlockType.TransformerBlock: HiDreamImageTransformerBlock,
-            BlockType.SingleTransformerBlock: HiDreamImageSingleTransformerBlock,
-        }
-        self.block = block_classes[block_type](
-            dim, num_attention_heads, attention_head_dim, num_routed_experts, num_activated_experts
-        )
-
-    def forward(
-        self,
-        image_tokens: torch.FloatTensor,
-        image_tokens_masks: Optional[torch.FloatTensor] = None,
-        text_tokens: Optional[torch.FloatTensor] = None,
-        adaln_input: torch.FloatTensor = None,
-        image_rotary_emb: torch.FloatTensor = None,
-    ) -> torch.FloatTensor:
-        return self.block(
-            image_tokens,
-            image_tokens_masks,
-            text_tokens,
-            adaln_input,
-            image_rotary_emb,
-        )
+        ff_output_i = gate_mlp_i * self.ff_i(norm_hidden_states)
+        ff_output_t = gate_mlp_t * self.ff_t(norm_encoder_hidden_states)
+        hidden_states = ff_output_i + hidden_states
+        encoder_hidden_states = ff_output_t + encoder_hidden_states
+        return hidden_states, encoder_hidden_states
 
 
 class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
     _supports_gradient_checkpointing = True
-    _no_split_modules = ["HiDreamImageBlock"]
+    _no_split_modules = ["HiDreamImageTransformerBlock", "HiDreamImageSingleTransformerBlock"]
 
     @register_to_config
     def __init__(
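Dropping the `HiDreamImageBlock` wrapper also shortens parameter paths: weights that used to live under `double_stream_blocks.<i>.block.*` or `single_stream_blocks.<i>.block.*` now sit directly on the block. A hedged sketch of remapping an older checkpoint's keys, assuming the removed `.block` segment is the only difference (the helper name is hypothetical; verify against your actual state dict before relying on it):

def remap_hidream_block_keys(state_dict):
    # Drop the ".block" segment introduced by the removed HiDreamImageBlock wrapper, e.g.
    # "double_stream_blocks.0.block.attn1.to_q.weight" -> "double_stream_blocks.0.attn1.to_q.weight".
    remapped = {}
    for key, value in state_dict.items():
        if (key.startswith("double_stream_blocks.") or key.startswith("single_stream_blocks.")) and ".block." in key:
            key = key.replace(".block.", ".", 1)
        remapped[key] = value
    return remapped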
@@ -652,29 +615,27 @@ def __init__(
 
         self.double_stream_blocks = nn.ModuleList(
             [
-                HiDreamImageBlock(
+                HiDreamImageTransformerBlock(
                     dim=self.inner_dim,
                     num_attention_heads=self.config.num_attention_heads,
                     attention_head_dim=self.config.attention_head_dim,
                     num_routed_experts=num_routed_experts,
                     num_activated_experts=num_activated_experts,
-                    block_type=BlockType.TransformerBlock,
                 )
-                for i in range(self.config.num_layers)
+                for _ in range(self.config.num_layers)
             ]
         )
 
         self.single_stream_blocks = nn.ModuleList(
             [
-                HiDreamImageBlock(
+                HiDreamImageSingleTransformerBlock(
                     dim=self.inner_dim,
                     num_attention_heads=self.config.num_attention_heads,
                     attention_head_dim=self.config.attention_head_dim,
                     num_routed_experts=num_routed_experts,
                     num_activated_experts=num_activated_experts,
-                    block_type=BlockType.SingleTransformerBlock,
                 )
-                for i in range(self.config.num_single_layers)
+                for _ in range(self.config.num_single_layers)
             ]
         )
 
@@ -816,8 +777,8 @@ def forward(
         p_embedder = self.p_embedder(pooled_embeds)
         adaln_input = timesteps + p_embedder
 
-        hidden_states, image_tokens_masks, img_sizes = self.patchify(hidden_states, self.max_seq, img_sizes)
-        if image_tokens_masks is None:
+        hidden_states, hidden_states_masks, img_sizes = self.patchify(hidden_states, self.max_seq, img_sizes)
+        if hidden_states_masks is None:
             pH, pW = img_sizes[0]
             img_ids = torch.zeros(pH, pW, 3, device=hidden_states.device)
             img_ids[..., 1] = img_ids[..., 1] + torch.arange(pH, device=hidden_states.device)[:, None]
@@ -869,16 +830,16 @@ def forward(
                 hidden_states, initial_encoder_hidden_states = self._gradient_checkpointing_func(
                     block,
                     hidden_states,
-                    image_tokens_masks,
+                    hidden_states_masks,
                     cur_encoder_hidden_states,
                     adaln_input,
                     image_rotary_emb,
                 )
             else:
                 hidden_states, initial_encoder_hidden_states = block(
-                    image_tokens=hidden_states,
-                    image_tokens_masks=image_tokens_masks,
-                    text_tokens=cur_encoder_hidden_states,
+                    hidden_states=hidden_states,
+                    hidden_states_masks=hidden_states_masks,
+                    encoder_hidden_states=cur_encoder_hidden_states,
                     adaln_input=adaln_input,
                     image_rotary_emb=image_rotary_emb,
                 )
@@ -888,13 +849,13 @@ def forward(
         image_tokens_seq_len = hidden_states.shape[1]
         hidden_states = torch.cat([hidden_states, initial_encoder_hidden_states], dim=1)
         hidden_states_seq_len = hidden_states.shape[1]
-        if image_tokens_masks is not None:
+        if hidden_states_masks is not None:
             encoder_attention_mask_ones = torch.ones(
                 (batch_size, initial_encoder_hidden_states.shape[1] + cur_llama31_encoder_hidden_states.shape[1]),
-                device=image_tokens_masks.device,
-                dtype=image_tokens_masks.dtype,
+                device=hidden_states_masks.device,
+                dtype=hidden_states_masks.dtype,
             )
-            image_tokens_masks = torch.cat([image_tokens_masks, encoder_attention_mask_ones], dim=1)
+            hidden_states_masks = torch.cat([hidden_states_masks, encoder_attention_mask_ones], dim=1)
 
         for bid, block in enumerate(self.single_stream_blocks):
             cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id]
@@ -903,16 +864,16 @@ def forward(
                 hidden_states = self._gradient_checkpointing_func(
                     block,
                     hidden_states,
-                    image_tokens_masks,
+                    hidden_states_masks,
                     None,
                     adaln_input,
                     image_rotary_emb,
                 )
             else:
                 hidden_states = block(
-                    image_tokens=hidden_states,
-                    image_tokens_masks=image_tokens_masks,
-                    text_tokens=None,
+                    hidden_states=hidden_states,
+                    hidden_states_masks=hidden_states_masks,
+                    encoder_hidden_states=None,
                     adaln_input=adaln_input,
                     image_rotary_emb=image_rotary_emb,
                 )
@@ -922,13 +883,13 @@ def forward(
         hidden_states = hidden_states[:, :image_tokens_seq_len, ...]
         output = self.final_layer(hidden_states, adaln_input)
         output = self.unpatchify(output, img_sizes, self.training)
-        if image_tokens_masks is not None:
-            image_tokens_masks = image_tokens_masks[:, :image_tokens_seq_len]
+        if hidden_states_masks is not None:
+            hidden_states_masks = hidden_states_masks[:, :image_tokens_seq_len]
 
         if USE_PEFT_BACKEND:
             # remove `lora_scale` from each PEFT layer
             unscale_lora_layers(self, lora_scale)
 
         if not return_dict:
-            return (output, image_tokens_masks)
-        return Transformer2DModelOutput(sample=output, mask=image_tokens_masks)
+            return (output, hidden_states_masks)
+        return Transformer2DModelOutput(sample=output, mask=hidden_states_masks)
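After the rename, a caller that asks for a tuple gets the mask back under its new name. A minimal usage sketch, assuming `model` is an already-constructed `HiDreamImageTransformer2DModel` and that `latents` plus the remaining conditioning kwargs (not shown in this diff) are prepared elsewhere:

# Tuple output: the second element is the (possibly None) token mask, now hidden_states_masks.
output, hidden_states_masks = model(
    hidden_states=latents,   # plus the usual text / pooled / timestep conditioning kwargs
    return_dict=False,
)

# Dataclass-style output exposes the same values as .sample and .mask.
result = model(hidden_states=latents, return_dict=True)
sample, mask = result.sample, result.mask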