Skip to content
Merged
Changes from 27 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
642203e
HiDream Image
hlky Apr 8, 2025
3b5e03b
update
hlky Apr 10, 2025
a40c95f
-einops
hlky Apr 10, 2025
eb798ed
Merge branch 'main' into hidream
hlky Apr 10, 2025
a90372e
py3.8
hlky Apr 10, 2025
f94d68e
Merge branch 'hidream' of https://github.com/hlky/diffusers into hidream
hlky Apr 10, 2025
b8aa38d
fix -einops
hlky Apr 10, 2025
9d43a32
mixins, offload_seq, option_components
hlky Apr 10, 2025
e1766a1
docs
hlky Apr 10, 2025
8fbc630
Apply style fixes
github-actions[bot] Apr 10, 2025
b6b9b45
trigger tests
hlky Apr 10, 2025
8dd065b
Apply suggestions from code review
hlky Apr 10, 2025
8b2670d
joint_attention_kwargs -> attention_kwargs, fixes
hlky Apr 10, 2025
f2aa727
fast tests
hlky Apr 10, 2025
8e328f3
-_init_weights
hlky Apr 10, 2025
7c4eced
style tests
hlky Apr 10, 2025
07c670e
move reshape logic
hlky Apr 10, 2025
efc44ea
update slice 😴
hlky Apr 10, 2025
745bcec
supports_dduf
hlky Apr 10, 2025
c1abec6
🤷🏻‍♂️
hlky Apr 11, 2025
9eb0b8b
Update src/diffusers/models/transformers/transformer_hidream_image.py
hlky Apr 11, 2025
32af5ce
address review comments
a-r-r-o-w Apr 11, 2025
3ec1896
update tests
a-r-r-o-w Apr 11, 2025
2d65aa2
doc updates
a-r-r-o-w Apr 11, 2025
72c9667
Merge branch 'main' into hidream
a-r-r-o-w Apr 11, 2025
5e0bca0
update
a-r-r-o-w Apr 12, 2025
3044fe0
Merge branch 'main' into refactor/hidream
a-r-r-o-w Apr 12, 2025
a68c103
Update src/diffusers/models/transformers/transformer_hidream_image.py
a-r-r-o-w Apr 12, 2025
13a9016
Apply style fixes
github-actions[bot] Apr 12, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 9 additions & 29 deletions src/diffusers/models/transformers/transformer_hidream_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -604,8 +604,7 @@ def __init__(
):
super().__init__()
self.out_channels = out_channels or in_channels
self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
self.llama_layers = llama_layers
self.inner_dim = num_attention_heads * attention_head_dim

self.t_embedder = HiDreamImageTimestepEmbed(self.inner_dim)
self.p_embedder = HiDreamImagePooledEmbed(text_emb_dim, self.inner_dim)
Expand All @@ -621,13 +620,13 @@ def __init__(
HiDreamBlock(
HiDreamImageTransformerBlock(
dim=self.inner_dim,
num_attention_heads=self.config.num_attention_heads,
attention_head_dim=self.config.attention_head_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
num_routed_experts=num_routed_experts,
num_activated_experts=num_activated_experts,
)
)
for _ in range(self.config.num_layers)
for _ in range(num_layers)
]
)

Expand All @@ -636,43 +635,25 @@ def __init__(
HiDreamBlock(
HiDreamImageSingleTransformerBlock(
dim=self.inner_dim,
num_attention_heads=self.config.num_attention_heads,
attention_head_dim=self.config.attention_head_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
num_routed_experts=num_routed_experts,
num_activated_experts=num_activated_experts,
)
)
for _ in range(self.config.num_single_layers)
for _ in range(num_single_layers)
]
)

self.final_layer = HiDreamImageOutEmbed(self.inner_dim, patch_size, self.out_channels)

caption_channels = [
caption_channels[1],
] * (num_layers + num_single_layers) + [
caption_channels[0],
]
caption_channels = [caption_channels[1]] * (num_layers + num_single_layers) + [caption_channels[0]]
caption_projection = []
for caption_channel in caption_channels:
caption_projection.append(TextProjection(in_features=caption_channel, hidden_size=self.inner_dim))
self.caption_projection = nn.ModuleList(caption_projection)
self.max_seq = max_resolution[0] * max_resolution[1] // (patch_size * patch_size)

def expand_timesteps(self, timesteps, batch_size, device):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not needed because:

  • we already pass tensors from the pipeline
  • we already expand the shape based on batch_size in MPS friendly manner
  • we already pass torch.int64 timesteps

if not torch.is_tensor(timesteps):
is_mps = device.type == "mps"
if isinstance(timesteps, float):
dtype = torch.float32 if is_mps else torch.float64
else:
dtype = torch.int32 if is_mps else torch.int64
timesteps = torch.tensor([timesteps], dtype=dtype, device=device)
elif len(timesteps.shape) == 0:
timesteps = timesteps[None].to(device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timesteps = timesteps.expand(batch_size)
return timesteps

def unpatchify(self, x: torch.Tensor, img_sizes: List[Tuple[int, int]], is_training: bool) -> List[torch.Tensor]:
if is_training:
B, S, F = x.shape
Expand Down Expand Up @@ -773,7 +754,6 @@ def forward(
hidden_states = out

# 0. time
timesteps = self.expand_timesteps(timesteps, batch_size, hidden_states.device)
timesteps = self.t_embedder(timesteps, hidden_states_type)
p_embedder = self.p_embedder(pooled_embeds)
temb = timesteps + p_embedder
Expand All @@ -793,7 +773,7 @@ def forward(

T5_encoder_hidden_states = encoder_hidden_states[0]
encoder_hidden_states = encoder_hidden_states[-1]
encoder_hidden_states = [encoder_hidden_states[k] for k in self.llama_layers]
encoder_hidden_states = [encoder_hidden_states[k] for k in self.config.llama_layers]

if self.caption_projection is not None:
new_encoder_hidden_states = []
Expand Down