diff --git a/src/diffusers/models/transformers/transformer_hidream_image.py b/src/diffusers/models/transformers/transformer_hidream_image.py
index 2bdf7d152268..43949f797c3d 100644
--- a/src/diffusers/models/transformers/transformer_hidream_image.py
+++ b/src/diffusers/models/transformers/transformer_hidream_image.py
@@ -604,8 +604,7 @@ def __init__(
     ):
         super().__init__()
         self.out_channels = out_channels or in_channels
-        self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
-        self.llama_layers = llama_layers
+        self.inner_dim = num_attention_heads * attention_head_dim
 
         self.t_embedder = HiDreamImageTimestepEmbed(self.inner_dim)
         self.p_embedder = HiDreamImagePooledEmbed(text_emb_dim, self.inner_dim)
@@ -621,13 +620,13 @@ def __init__(
                 HiDreamBlock(
                     HiDreamImageTransformerBlock(
                         dim=self.inner_dim,
-                        num_attention_heads=self.config.num_attention_heads,
-                        attention_head_dim=self.config.attention_head_dim,
+                        num_attention_heads=num_attention_heads,
+                        attention_head_dim=attention_head_dim,
                         num_routed_experts=num_routed_experts,
                         num_activated_experts=num_activated_experts,
                     )
                 )
-                for _ in range(self.config.num_layers)
+                for _ in range(num_layers)
             ]
         )
 
@@ -636,42 +635,26 @@ def __init__(
                 HiDreamBlock(
                     HiDreamImageSingleTransformerBlock(
                         dim=self.inner_dim,
-                        num_attention_heads=self.config.num_attention_heads,
-                        attention_head_dim=self.config.attention_head_dim,
+                        num_attention_heads=num_attention_heads,
+                        attention_head_dim=attention_head_dim,
                         num_routed_experts=num_routed_experts,
                         num_activated_experts=num_activated_experts,
                     )
                 )
-                for _ in range(self.config.num_single_layers)
+                for _ in range(num_single_layers)
             ]
         )
 
         self.final_layer = HiDreamImageOutEmbed(self.inner_dim, patch_size, self.out_channels)
 
-        caption_channels = [
-            caption_channels[1],
-        ] * (num_layers + num_single_layers) + [
-            caption_channels[0],
-        ]
+        caption_channels = [caption_channels[1]] * (num_layers + num_single_layers) + [caption_channels[0]]
         caption_projection = []
         for caption_channel in caption_channels:
             caption_projection.append(TextProjection(in_features=caption_channel, hidden_size=self.inner_dim))
         self.caption_projection = nn.ModuleList(caption_projection)
         self.max_seq = max_resolution[0] * max_resolution[1] // (patch_size * patch_size)
 
-    def expand_timesteps(self, timesteps, batch_size, device):
-        if not torch.is_tensor(timesteps):
-            is_mps = device.type == "mps"
-            if isinstance(timesteps, float):
-                dtype = torch.float32 if is_mps else torch.float64
-            else:
-                dtype = torch.int32 if is_mps else torch.int64
-            timesteps = torch.tensor([timesteps], dtype=dtype, device=device)
-        elif len(timesteps.shape) == 0:
-            timesteps = timesteps[None].to(device)
-        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
-        timesteps = timesteps.expand(batch_size)
-        return timesteps
+        self.gradient_checkpointing = False
 
     def unpatchify(self, x: torch.Tensor, img_sizes: List[Tuple[int, int]], is_training: bool) -> List[torch.Tensor]:
         if is_training:
@@ -773,7 +756,6 @@ def forward(
             hidden_states = out
 
         # 0. time
-        timesteps = self.expand_timesteps(timesteps, batch_size, hidden_states.device)
        timesteps = self.t_embedder(timesteps, hidden_states_type)
         p_embedder = self.p_embedder(pooled_embeds)
         temb = timesteps + p_embedder
@@ -793,7 +775,7 @@ def forward(
 
         T5_encoder_hidden_states = encoder_hidden_states[0]
         encoder_hidden_states = encoder_hidden_states[-1]
-        encoder_hidden_states = [encoder_hidden_states[k] for k in self.llama_layers]
+        encoder_hidden_states = [encoder_hidden_states[k] for k in self.config.llama_layers]
 
         if self.caption_projection is not None:
             new_encoder_hidden_states = []