
Commit 5e0bca0

update
1 parent 72c9667 commit 5e0bca0

File tree

1 file changed: +9 −29 lines changed


src/diffusers/models/transformers/transformer_hidream_image.py

Lines changed: 9 additions & 29 deletions
@@ -604,8 +604,7 @@ def __init__(
     ):
         super().__init__()
         self.out_channels = out_channels or in_channels
-        self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
-        self.llama_layers = llama_layers
+        self.inner_dim = num_attention_heads * attention_head_dim
 
         self.t_embedder = HiDreamImageTimestepEmbed(self.inner_dim)
         self.p_embedder = HiDreamImagePooledEmbed(text_emb_dim, self.inner_dim)
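Note: dropping the stored `self.llama_layers` attribute relies on the standard diffusers config machinery. Assuming the model's `__init__` is decorated with `@register_to_config` (the usual `ModelMixin`/`ConfigMixin` pattern), every constructor argument is recorded on `self.config`, so `llama_layers` stays reachable as `self.config.llama_layers` (used in the forward-pass hunk further down) without a duplicate attribute. A minimal sketch of that pattern; `ToyTransformer` is illustrative and not part of the library:

```python
# Minimal sketch (ToyTransformer is illustrative, not part of diffusers) of the
# @register_to_config pattern this change relies on: constructor kwargs are
# recorded on self.config, so storing them again as plain attributes is redundant.
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin


class ToyTransformer(ModelMixin, ConfigMixin):
    @register_to_config
    def __init__(self, num_attention_heads: int = 8, attention_head_dim: int = 64, llama_layers=(0, 1)):
        super().__init__()
        # Use the local arguments directly inside __init__ ...
        self.inner_dim = num_attention_heads * attention_head_dim

    def forward(self, encoder_hidden_states):
        # ... and read the registered values back through self.config later on.
        return [encoder_hidden_states[k] for k in self.config.llama_layers]


model = ToyTransformer()
print(model.config.num_attention_heads, model.config.llama_layers)  # 8 (0, 1)
```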
@@ -621,13 +620,13 @@ def __init__(
                 HiDreamBlock(
                     HiDreamImageTransformerBlock(
                         dim=self.inner_dim,
-                        num_attention_heads=self.config.num_attention_heads,
-                        attention_head_dim=self.config.attention_head_dim,
+                        num_attention_heads=num_attention_heads,
+                        attention_head_dim=attention_head_dim,
                         num_routed_experts=num_routed_experts,
                         num_activated_experts=num_activated_experts,
                     )
                 )
-                for _ in range(self.config.num_layers)
+                for _ in range(num_layers)
             ]
         )
 
@@ -636,43 +635,25 @@ def __init__(
                 HiDreamBlock(
                     HiDreamImageSingleTransformerBlock(
                         dim=self.inner_dim,
-                        num_attention_heads=self.config.num_attention_heads,
-                        attention_head_dim=self.config.attention_head_dim,
+                        num_attention_heads=num_attention_heads,
+                        attention_head_dim=attention_head_dim,
                         num_routed_experts=num_routed_experts,
                         num_activated_experts=num_activated_experts,
                     )
                 )
-                for _ in range(self.config.num_single_layers)
+                for _ in range(num_single_layers)
             ]
         )
 
         self.final_layer = HiDreamImageOutEmbed(self.inner_dim, patch_size, self.out_channels)
 
-        caption_channels = [
-            caption_channels[1],
-        ] * (num_layers + num_single_layers) + [
-            caption_channels[0],
-        ]
+        caption_channels = [caption_channels[1]] * (num_layers + num_single_layers) + [caption_channels[0]]
         caption_projection = []
         for caption_channel in caption_channels:
             caption_projection.append(TextProjection(in_features=caption_channel, hidden_size=self.inner_dim))
         self.caption_projection = nn.ModuleList(caption_projection)
         self.max_seq = max_resolution[0] * max_resolution[1] // (patch_size * patch_size)
 
-    def expand_timesteps(self, timesteps, batch_size, device):
-        if not torch.is_tensor(timesteps):
-            is_mps = device.type == "mps"
-            if isinstance(timesteps, float):
-                dtype = torch.float32 if is_mps else torch.float64
-            else:
-                dtype = torch.int32 if is_mps else torch.int64
-            timesteps = torch.tensor([timesteps], dtype=dtype, device=device)
-        elif len(timesteps.shape) == 0:
-            timesteps = timesteps[None].to(device)
-        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
-        timesteps = timesteps.expand(batch_size)
-        return timesteps
-
     def unpatchify(self, x: torch.Tensor, img_sizes: List[Tuple[int, int]], is_training: bool) -> List[torch.Tensor]:
         if is_training:
             B, S, F = x.shape
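The `caption_channels` rewrite in the hunk above is purely cosmetic: the one-liner builds the same list as the old multi-line form, i.e. one projection channel per transformer block (double-stream plus single-stream) taken from `caption_channels[1]`, followed by a final entry from `caption_channels[0]`. A quick illustration with toy numbers (not the model's real configuration):

```python
# Toy numbers only, to show the shape of the list the one-liner produces.
caption_channels = [4096, 3072]          # hypothetical [channel_0, channel_1]
num_layers, num_single_layers = 2, 3     # hypothetical block counts

channels = [caption_channels[1]] * (num_layers + num_single_layers) + [caption_channels[0]]
assert channels == [3072, 3072, 3072, 3072, 3072, 4096]
assert len(channels) == num_layers + num_single_layers + 1
```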
@@ -773,7 +754,6 @@ def forward(
         hidden_states = out
 
         # 0. time
-        timesteps = self.expand_timesteps(timesteps, batch_size, hidden_states.device)
         timesteps = self.t_embedder(timesteps, hidden_states_type)
         p_embedder = self.p_embedder(pooled_embeds)
         temb = timesteps + p_embedder
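With `expand_timesteps` deleted in the previous hunk and its call removed here, the forward pass no longer coerces Python scalars or 0-d tensors into a batched tensor; presumably the caller is now expected to pass `timesteps` already shaped to the batch. A hedged sketch of what that caller-side preparation could look like (`prepare_timesteps` is a hypothetical helper mirroring the removed code, not part of this commit):

```python
import torch


def prepare_timesteps(t, batch_size: int, device: torch.device) -> torch.Tensor:
    # Hypothetical helper mirroring the deleted expand_timesteps: wrap scalars
    # in a 1-element tensor, lift 0-d tensors to 1-d, then broadcast to the batch.
    if not torch.is_tensor(t):
        dtype = torch.float32 if isinstance(t, float) else torch.int64
        t = torch.tensor([t], dtype=dtype, device=device)
    elif t.ndim == 0:
        t = t[None].to(device)
    return t.expand(batch_size)


timesteps = prepare_timesteps(0.5, batch_size=4, device=torch.device("cpu"))
print(timesteps.shape)  # torch.Size([4])
```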
@@ -793,7 +773,7 @@ def forward(
 
         T5_encoder_hidden_states = encoder_hidden_states[0]
         encoder_hidden_states = encoder_hidden_states[-1]
-        encoder_hidden_states = [encoder_hidden_states[k] for k in self.llama_layers]
+        encoder_hidden_states = [encoder_hidden_states[k] for k in self.config.llama_layers]
 
         if self.caption_projection is not None:
             new_encoder_hidden_states = []
