@@ -604,8 +604,7 @@ def __init__(
     ):
         super().__init__()
         self.out_channels = out_channels or in_channels
-        self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
-        self.llama_layers = llama_layers
+        self.inner_dim = num_attention_heads * attention_head_dim

         self.t_embedder = HiDreamImageTimestepEmbed(self.inner_dim)
         self.p_embedder = HiDreamImagePooledEmbed(text_emb_dim, self.inner_dim)
@@ -621,13 +620,13 @@ def __init__(
                 HiDreamBlock(
                     HiDreamImageTransformerBlock(
                         dim=self.inner_dim,
-                        num_attention_heads=self.config.num_attention_heads,
-                        attention_head_dim=self.config.attention_head_dim,
+                        num_attention_heads=num_attention_heads,
+                        attention_head_dim=attention_head_dim,
                         num_routed_experts=num_routed_experts,
                         num_activated_experts=num_activated_experts,
                     )
                 )
-                for _ in range(self.config.num_layers)
+                for _ in range(num_layers)
             ]
         )

@@ -636,43 +635,25 @@ def __init__(
                 HiDreamBlock(
                     HiDreamImageSingleTransformerBlock(
                         dim=self.inner_dim,
-                        num_attention_heads=self.config.num_attention_heads,
-                        attention_head_dim=self.config.attention_head_dim,
+                        num_attention_heads=num_attention_heads,
+                        attention_head_dim=attention_head_dim,
                         num_routed_experts=num_routed_experts,
                         num_activated_experts=num_activated_experts,
                     )
                 )
-                for _ in range(self.config.num_single_layers)
+                for _ in range(num_single_layers)
             ]
         )

         self.final_layer = HiDreamImageOutEmbed(self.inner_dim, patch_size, self.out_channels)

-        caption_channels = [
-            caption_channels[1],
-        ] * (num_layers + num_single_layers) + [
-            caption_channels[0],
-        ]
+        caption_channels = [caption_channels[1]] * (num_layers + num_single_layers) + [caption_channels[0]]
         caption_projection = []
         for caption_channel in caption_channels:
             caption_projection.append(TextProjection(in_features=caption_channel, hidden_size=self.inner_dim))
         self.caption_projection = nn.ModuleList(caption_projection)
         self.max_seq = max_resolution[0] * max_resolution[1] // (patch_size * patch_size)

-    def expand_timesteps(self, timesteps, batch_size, device):
-        if not torch.is_tensor(timesteps):
-            is_mps = device.type == "mps"
-            if isinstance(timesteps, float):
-                dtype = torch.float32 if is_mps else torch.float64
-            else:
-                dtype = torch.int32 if is_mps else torch.int64
-            timesteps = torch.tensor([timesteps], dtype=dtype, device=device)
-        elif len(timesteps.shape) == 0:
-            timesteps = timesteps[None].to(device)
-        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
-        timesteps = timesteps.expand(batch_size)
-        return timesteps
-
     def unpatchify(self, x: torch.Tensor, img_sizes: List[Tuple[int, int]], is_training: bool) -> List[torch.Tensor]:
         if is_training:
             B, S, F = x.shape
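# Note: not part of the diff above. A quick illustration of what the collapsed
# `caption_channels` one-liner builds (the sizes below are hypothetical): every
# double-stream and single-stream block gets a projection sized for the Llama
# hidden states (`caption_channels[1]`), plus one trailing projection for the
# T5 embeddings (`caption_channels[0]`).
num_layers, num_single_layers = 16, 32      # hypothetical block counts
caption_channels = [2048, 4096]             # hypothetical [T5 dim, Llama dim]
expanded = [caption_channels[1]] * (num_layers + num_single_layers) + [caption_channels[0]]
assert len(expanded) == num_layers + num_single_layers + 1   # one TextProjection per entry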
@@ -773,7 +754,6 @@ def forward(
             hidden_states = out

         # 0. time
-        timesteps = self.expand_timesteps(timesteps, batch_size, hidden_states.device)
         timesteps = self.t_embedder(timesteps, hidden_states_type)
         p_embedder = self.p_embedder(pooled_embeds)
         temb = timesteps + p_embedder
@@ -793,7 +773,7 @@ def forward(

         T5_encoder_hidden_states = encoder_hidden_states[0]
         encoder_hidden_states = encoder_hidden_states[-1]
-        encoder_hidden_states = [encoder_hidden_states[k] for k in self.llama_layers]
+        encoder_hidden_states = [encoder_hidden_states[k] for k in self.config.llama_layers]

         if self.caption_projection is not None:
             new_encoder_hidden_states = []
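# Note: not part of the diff above. With `expand_timesteps` removed, `forward`
# hands `timesteps` straight to `self.t_embedder`, so the calling pipeline is
# presumably expected to pass an already-batched timestep tensor. A minimal
# sketch of the broadcasting the deleted helper performed (the function name
# here is illustrative, not from the diff):
import torch

def broadcast_timesteps(timesteps, batch_size: int, device: torch.device) -> torch.Tensor:
    if not torch.is_tensor(timesteps):
        # float64/int64 are not supported on MPS, so fall back to 32-bit dtypes there
        is_mps = device.type == "mps"
        if isinstance(timesteps, float):
            dtype = torch.float32 if is_mps else torch.float64
        else:
            dtype = torch.int32 if is_mps else torch.int64
        timesteps = torch.tensor([timesteps], dtype=dtype, device=device)
    elif timesteps.ndim == 0:
        timesteps = timesteps[None].to(device)
    # expand (rather than repeat) broadcasts to the batch in an ONNX/Core ML friendly way
    return timesteps.expand(batch_size)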