Skip to content

Commit 7021362

Browse files
committed
HiDream Image
1 parent 723dbdd commit 7021362

File tree

12 files changed

+1806
-0
lines changed

12 files changed

+1806
-0
lines changed

src/diffusers/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,7 @@
169169
"FluxControlNetModel",
170170
"FluxMultiControlNetModel",
171171
"FluxTransformer2DModel",
172+
"HiDreamImageTransformer2DModel",
172173
"HunyuanDiT2DControlNetModel",
173174
"HunyuanDiT2DModel",
174175
"HunyuanDiT2DMultiControlNetModel",
@@ -366,6 +367,7 @@
366367
"FluxInpaintPipeline",
367368
"FluxPipeline",
368369
"FluxPriorReduxPipeline",
370+
"HiDreamImagePipeline",
369371
"HunyuanDiTControlNetPipeline",
370372
"HunyuanDiTPAGPipeline",
371373
"HunyuanDiTPipeline",
@@ -745,6 +747,7 @@
745747
FluxControlNetModel,
746748
FluxMultiControlNetModel,
747749
FluxTransformer2DModel,
750+
HiDreamImageTransformer2DModel,
748751
HunyuanDiT2DControlNetModel,
749752
HunyuanDiT2DModel,
750753
HunyuanDiT2DMultiControlNetModel,
@@ -921,6 +924,7 @@
921924
FluxInpaintPipeline,
922925
FluxPipeline,
923926
FluxPriorReduxPipeline,
927+
HiDreamImagePipeline,
924928
HunyuanDiTControlNetPipeline,
925929
HunyuanDiTPAGPipeline,
926930
HunyuanDiTPipeline,

src/diffusers/models/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@
7575
_import_structure["transformers.transformer_cogview4"] = ["CogView4Transformer2DModel"]
7676
_import_structure["transformers.transformer_easyanimate"] = ["EasyAnimateTransformer3DModel"]
7777
_import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"]
78+
_import_structure["transformers.transformer_hidream_image"] = ["HiDreamImageTransformer2DModel"]
7879
_import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"]
7980
_import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"]
8081
_import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"]
@@ -149,6 +150,7 @@
149150
DualTransformer2DModel,
150151
EasyAnimateTransformer3DModel,
151152
FluxTransformer2DModel,
153+
HiDreamImageTransformer2DModel,
152154
HunyuanDiT2DModel,
153155
HunyuanVideoTransformer3DModel,
154156
LatteTransformer3DModel,

src/diffusers/models/attention.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1249,3 +1249,33 @@ def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
12491249
for module in self.net:
12501250
hidden_states = module(hidden_states)
12511251
return hidden_states
1252+
1253+
1254+
class HiDreamImageFeedForwardSwiGLU(nn.Module):
    """SwiGLU feed-forward block: ``w2(silu(w1(x)) * w3(x))``.

    The inner width starts at ``2/3 * hidden_dim`` (compensating for the
    extra gate projection), is optionally scaled by ``ffn_dim_multiplier``,
    and is finally rounded up to the nearest multiple of ``multiple_of``.
    """

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int = 256,
        ffn_dim_multiplier: Optional[float] = None,
    ):
        super().__init__()
        inner_dim = int(2 * hidden_dim / 3)
        if ffn_dim_multiplier is not None:
            # Custom scaling factor for the inner width.
            inner_dim = int(ffn_dim_multiplier * inner_dim)
        # Ceil-divide then multiply: round up to a multiple of ``multiple_of``.
        inner_dim = multiple_of * (-(-inner_dim // multiple_of))

        # NOTE: registration order (w1, w2, w3) is kept stable so that
        # ``self.apply`` below initialises weights in a reproducible order.
        self.w1 = nn.Linear(dim, inner_dim, bias=False)
        self.w2 = nn.Linear(inner_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, inner_dim, bias=False)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Xavier-uniform weights; the bias branch is defensive — these
        # linears are created without biases.
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        gate = torch.nn.functional.silu(self.w1(x))
        return self.w2(gate * self.w3(x))

src/diffusers/models/embeddings.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2621,3 +2621,59 @@ def forward(self, image_embeds: List[torch.Tensor]):
26212621
projected_image_embeds.append(image_embed)
26222622

26232623
return projected_image_embeds
2624+
2625+
2626+
class HiDreamImagePooledEmbed(nn.Module):
    """Projects a pooled text embedding to the transformer hidden size.

    Reuses the project's ``TimestepEmbedding`` module as the projection,
    then re-initialises all linear layers with a small normal distribution.
    """

    def __init__(self, text_emb_dim, hidden_size):
        super().__init__()
        self.pooled_embedder = TimestepEmbedding(in_channels=text_emb_dim, time_embed_dim=hidden_size)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Normal(std=0.02) weights, zero biases, for every linear layer.
        if not isinstance(m, nn.Linear):
            return
        nn.init.normal_(m.weight, std=0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

    def forward(self, pooled_embed):
        return self.pooled_embedder(pooled_embed)
2640+
2641+
2642+
class HiDreamImageTimestepEmbed(nn.Module):
    """Sinusoidal timestep features followed by an MLP projection."""

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.time_proj = Timesteps(
            num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0
        )
        self.timestep_embedder = TimestepEmbedding(
            in_channels=frequency_embedding_size, time_embed_dim=hidden_size
        )
        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Normal(std=0.02) weights, zero biases, for every linear layer.
        if not isinstance(m, nn.Linear):
            return
        nn.init.normal_(m.weight, std=0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

    def forward(self, timesteps, wdtype):
        # Cast the sinusoidal features to the model weight dtype before the MLP.
        return self.timestep_embedder(self.time_proj(timesteps).to(dtype=wdtype))
2659+
2660+
2661+
class HiDreamImageOutEmbed(nn.Module):
    """adaLN-modulated final projection from hidden states to patch outputs.

    ``adaln_input`` produces a per-sample (shift, scale) pair that modulates
    the layer-normalised hidden states before the output linear layer.
    All linear layers are zero-initialised, so the block outputs exact
    zeros at initialisation.
    """

    def __init__(self, hidden_size, patch_size, out_channels):
        super().__init__()
        # No affine params on the norm: scale/shift come from adaLN instead.
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Zero weights and biases for every linear layer.
        if isinstance(m, nn.Linear):
            nn.init.zeros_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x, adaln_input):
        modulation = self.adaLN_modulation(adaln_input)
        shift, scale = modulation.chunk(2, dim=1)
        normed = self.norm_final(x)
        modulated = normed * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
        return self.linear(modulated)

src/diffusers/models/transformers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from .transformer_cogview4 import CogView4Transformer2DModel
2222
from .transformer_easyanimate import EasyAnimateTransformer3DModel
2323
from .transformer_flux import FluxTransformer2DModel
24+
from .transformer_hidream_image import HiDreamImageTransformer2DModel
2425
from .transformer_hunyuan_video import HunyuanVideoTransformer3DModel
2526
from .transformer_ltx import LTXVideoTransformer3DModel
2627
from .transformer_lumina2 import Lumina2Transformer2DModel

0 commit comments

Comments
 (0)