@@ -604,8 +604,7 @@ def __init__(
     ):
         super().__init__()
         self.out_channels = out_channels or in_channels
-        self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
-        self.llama_layers = llama_layers
+        self.inner_dim = num_attention_heads * attention_head_dim
 
         self.t_embedder = HiDreamImageTimestepEmbed(self.inner_dim)
         self.p_embedder = HiDreamImagePooledEmbed(text_emb_dim, self.inner_dim)
@@ -621,13 +620,13 @@ def __init__(
                 HiDreamBlock(
                     HiDreamImageTransformerBlock(
                         dim=self.inner_dim,
-                        num_attention_heads=self.config.num_attention_heads,
-                        attention_head_dim=self.config.attention_head_dim,
+                        num_attention_heads=num_attention_heads,
+                        attention_head_dim=attention_head_dim,
                         num_routed_experts=num_routed_experts,
                         num_activated_experts=num_activated_experts,
                     )
                 )
-                for _ in range(self.config.num_layers)
+                for _ in range(num_layers)
             ]
         )
 
@@ -636,42 +635,26 @@ def __init__(
                 HiDreamBlock(
                     HiDreamImageSingleTransformerBlock(
                         dim=self.inner_dim,
-                        num_attention_heads=self.config.num_attention_heads,
-                        attention_head_dim=self.config.attention_head_dim,
+                        num_attention_heads=num_attention_heads,
+                        attention_head_dim=attention_head_dim,
                         num_routed_experts=num_routed_experts,
                         num_activated_experts=num_activated_experts,
                     )
                 )
-                for _ in range(self.config.num_single_layers)
+                for _ in range(num_single_layers)
             ]
         )
 
         self.final_layer = HiDreamImageOutEmbed(self.inner_dim, patch_size, self.out_channels)
 
-        caption_channels = [
-            caption_channels[1],
-        ] * (num_layers + num_single_layers) + [
-            caption_channels[0],
-        ]
+        caption_channels = [caption_channels[1]] * (num_layers + num_single_layers) + [caption_channels[0]]
         caption_projection = []
         for caption_channel in caption_channels:
             caption_projection.append(TextProjection(in_features=caption_channel, hidden_size=self.inner_dim))
         self.caption_projection = nn.ModuleList(caption_projection)
         self.max_seq = max_resolution[0] * max_resolution[1] // (patch_size * patch_size)
 
-    def expand_timesteps(self, timesteps, batch_size, device):
-        if not torch.is_tensor(timesteps):
-            is_mps = device.type == "mps"
-            if isinstance(timesteps, float):
-                dtype = torch.float32 if is_mps else torch.float64
-            else:
-                dtype = torch.int32 if is_mps else torch.int64
-            timesteps = torch.tensor([timesteps], dtype=dtype, device=device)
-        elif len(timesteps.shape) == 0:
-            timesteps = timesteps[None].to(device)
-        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
-        timesteps = timesteps.expand(batch_size)
-        return timesteps
+        self.gradient_checkpointing = False
 
     def unpatchify(self, x: torch.Tensor, img_sizes: List[Tuple[int, int]], is_training: bool) -> List[torch.Tensor]:
         if is_training:
@@ -773,7 +756,6 @@ def forward(
             hidden_states = out
 
         # 0. time
-        timesteps = self.expand_timesteps(timesteps, batch_size, hidden_states.device)
         timesteps = self.t_embedder(timesteps, hidden_states_type)
         p_embedder = self.p_embedder(pooled_embeds)
         temb = timesteps + p_embedder
@@ -793,7 +775,7 @@ def forward(
 
         T5_encoder_hidden_states = encoder_hidden_states[0]
         encoder_hidden_states = encoder_hidden_states[-1]
-        encoder_hidden_states = [encoder_hidden_states[k] for k in self.llama_layers]
+        encoder_hidden_states = [encoder_hidden_states[k] for k in self.config.llama_layers]
 
         if self.caption_projection is not None:
             new_encoder_hidden_states = []
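
Why both spellings resolve to the same values: diffusers models register their constructor arguments on self.config (via the @register_to_config decorator), so __init__ can use the plain arguments while forward still reads self.config.llama_layers. Below is a minimal sketch of that pattern under that assumption, using a hypothetical ToyTransformer rather than the actual HiDream model.

# Minimal sketch (hypothetical ToyTransformer, not the HiDream model) of the
# @register_to_config pattern: constructor arguments are recorded on
# self.config, so __init__ can use them as plain locals while other methods
# read the same values back via self.config.
import torch
import torch.nn as nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin


class ToyTransformer(ModelMixin, ConfigMixin):
    @register_to_config
    def __init__(self, num_attention_heads: int = 4, attention_head_dim: int = 8, llama_layers=(0, 2)):
        super().__init__()
        # plain arguments inside __init__
        self.inner_dim = num_attention_heads * attention_head_dim
        self.proj = nn.Linear(self.inner_dim, self.inner_dim)

    def forward(self, hidden_states, encoder_hidden_states):
        # the registered config elsewhere, e.g. selecting encoder layers
        selected = [encoder_hidden_states[k] for k in self.config.llama_layers]
        return self.proj(hidden_states), selected


model = ToyTransformer()
x = torch.randn(1, 5, model.inner_dim)
layers = [torch.randn(1, 7, 16) for _ in range(3)]
out, picked = model(x, layers)  # picked contains layers 0 and 2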