
Commit 3b5e03b

Commit message: update
Parent: 642203e

3 files changed, 146 additions and 149 deletions


src/diffusers/models/attention.py

Lines changed: 0 additions & 30 deletions
@@ -1249,33 +1249,3 @@ def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
         for module in self.net:
             hidden_states = module(hidden_states)
         return hidden_states
-
-
-class HiDreamImageFeedForwardSwiGLU(nn.Module):
-    def __init__(
-        self,
-        dim: int,
-        hidden_dim: int,
-        multiple_of: int = 256,
-        ffn_dim_multiplier: Optional[float] = None,
-    ):
-        super().__init__()
-        hidden_dim = int(2 * hidden_dim / 3)
-        # custom dim factor multiplier
-        if ffn_dim_multiplier is not None:
-            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
-        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
-
-        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
-        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
-        self.w3 = nn.Linear(dim, hidden_dim, bias=False)
-        self.apply(self._init_weights)
-
-    def _init_weights(self, m):
-        if isinstance(m, nn.Linear):
-            nn.init.xavier_uniform_(m.weight)
-            if m.bias is not None:
-                nn.init.constant_(m.bias, 0)
-
-    def forward(self, x):
-        return self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
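The feed-forward block deleted here is added back verbatim to src/diffusers/models/transformers/transformer_hidream_image.py (third file below). For orientation, a minimal usage sketch of the moved SwiGLU feed-forward; the dimensions are illustrative and the import path assumes a checkout that includes this commit:

import torch

from diffusers.models.transformers.transformer_hidream_image import HiDreamImageFeedForwardSwiGLU

# Example sizes, not the model's configured widths.
ffn = HiDreamImageFeedForwardSwiGLU(dim=1024, hidden_dim=4096)
x = torch.randn(2, 16, 1024)  # (batch, tokens, dim)

# forward computes w2(silu(w1(x)) * w3(x)); the internal width is
# int(2 * hidden_dim / 3), rounded up to a multiple of `multiple_of` (default 256).
out = ffn(x)
print(out.shape)  # torch.Size([2, 16, 1024])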

src/diffusers/models/embeddings.py

Lines changed: 0 additions & 111 deletions
@@ -2621,114 +2621,3 @@ def forward(self, image_embeds: List[torch.Tensor]):
             projected_image_embeds.append(image_embed)

         return projected_image_embeds
-
-
-class HiDreamImagePooledEmbed(nn.Module):
-    def __init__(self, text_emb_dim, hidden_size):
-        super().__init__()
-        self.pooled_embedder = TimestepEmbedding(in_channels=text_emb_dim, time_embed_dim=hidden_size)
-        self.apply(self._init_weights)
-
-    def _init_weights(self, m):
-        if isinstance(m, nn.Linear):
-            nn.init.normal_(m.weight, std=0.02)
-            if m.bias is not None:
-                nn.init.constant_(m.bias, 0)
-
-    def forward(self, pooled_embed):
-        return self.pooled_embedder(pooled_embed)
-
-
-class HiDreamImageTimestepEmbed(nn.Module):
-    def __init__(self, hidden_size, frequency_embedding_size=256):
-        super().__init__()
-        self.time_proj = Timesteps(num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0)
-        self.timestep_embedder = TimestepEmbedding(in_channels=frequency_embedding_size, time_embed_dim=hidden_size)
-        self.apply(self._init_weights)
-
-    def _init_weights(self, m):
-        if isinstance(m, nn.Linear):
-            nn.init.normal_(m.weight, std=0.02)
-            if m.bias is not None:
-                nn.init.constant_(m.bias, 0)
-
-    def forward(self, timesteps, wdtype):
-        t_emb = self.time_proj(timesteps).to(dtype=wdtype)
-        t_emb = self.timestep_embedder(t_emb)
-        return t_emb
-
-
-class HiDreamImageOutEmbed(nn.Module):
-    def __init__(self, hidden_size, patch_size, out_channels):
-        super().__init__()
-        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
-        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
-        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
-        self.apply(self._init_weights)
-
-    def _init_weights(self, m):
-        if isinstance(m, nn.Linear):
-            nn.init.zeros_(m.weight)
-            if m.bias is not None:
-                nn.init.constant_(m.bias, 0)
-
-    def forward(self, x, adaln_input):
-        shift, scale = self.adaLN_modulation(adaln_input).chunk(2, dim=1)
-        x = self.norm_final(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
-        x = self.linear(x)
-        return x
-
-
-class HiDreamImagePatchEmbed(nn.Module):
-    def __init__(
-        self,
-        patch_size=2,
-        in_channels=4,
-        out_channels=1024,
-    ):
-        super().__init__()
-        self.patch_size = patch_size
-        self.out_channels = out_channels
-        self.proj = nn.Linear(in_channels * patch_size * patch_size, out_channels, bias=True)
-        self.apply(self._init_weights)
-
-    def _init_weights(self, m):
-        if isinstance(m, nn.Linear):
-            nn.init.xavier_uniform_(m.weight)
-            if m.bias is not None:
-                nn.init.constant_(m.bias, 0)
-
-    def forward(self, latent):
-        latent = self.proj(latent)
-        return latent
-
-
-def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
-    assert dim % 2 == 0, "The dimension must be even."
-
-    scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
-    omega = 1.0 / (theta**scale)
-
-    batch_size, seq_length = pos.shape
-    out = torch.einsum("...n,d->...nd", pos, omega)
-    cos_out = torch.cos(out)
-    sin_out = torch.sin(out)
-
-    stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
-    out = stacked_out.view(batch_size, -1, dim // 2, 2, 2)
-    return out.float()
-
-
-class HiDreamImageEmbedND(nn.Module):
-    def __init__(self, theta: int, axes_dim: List[int]):
-        super().__init__()
-        self.theta = theta
-        self.axes_dim = axes_dim
-
-    def forward(self, ids: torch.Tensor) -> torch.Tensor:
-        n_axes = ids.shape[-1]
-        emb = torch.cat(
-            [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
-            dim=-3,
-        )
-        return emb.unsqueeze(2)
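Every class and the rope helper removed here reappears unchanged in transformer_hidream_image.py below. As a quick reference for the rotary-position pieces, here is a minimal sketch of driving HiDreamImageEmbedND with a small id grid; theta, axes_dim, and the 4x4 grid are illustrative assumptions, and the import path assumes a checkout that includes this commit:

import torch

from diffusers.models.transformers.transformer_hidream_image import HiDreamImageEmbedND

pH = pW = 4  # illustrative patch grid
ids = torch.zeros(pH, pW, 3)
ids[..., 1] = torch.arange(pH)[:, None]  # row index per token
ids[..., 2] = torch.arange(pW)[None, :]  # column index per token
ids = ids.reshape(1, pH * pW, 3)         # (batch, tokens, n_axes)

# rope() builds a 2x2 rotation block per frequency and axis; the axes are
# concatenated along the frequency dimension and a head axis is inserted.
pe = HiDreamImageEmbedND(theta=10000, axes_dim=[16, 24, 24])(ids)
print(pe.shape)  # torch.Size([1, 16, 1, 32, 2, 2]); 32 == sum(axes_dim) // 2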

src/diffusers/models/transformers/transformer_hidream_image.py

Lines changed: 146 additions & 8 deletions
@@ -5,27 +5,164 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from einops import repeat

 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
 from ...models.modeling_outputs import Transformer2DModelOutput
 from ...models.modeling_utils import ModelMixin
 from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
-from ..attention import Attention, HiDreamImageFeedForwardSwiGLU
+from ..attention import Attention
 from ..embeddings import (
-    HiDreamImageEmbedND,
-    HiDreamImageOutEmbed,
-    HiDreamImagePatchEmbed,
-    HiDreamImagePooledEmbed,
-    HiDreamImageTimestepEmbed,
+    TimestepEmbedding,
+    Timesteps,
 )


 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


+class HiDreamImageFeedForwardSwiGLU(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        hidden_dim: int,
+        multiple_of: int = 256,
+        ffn_dim_multiplier: Optional[float] = None,
+    ):
+        super().__init__()
+        hidden_dim = int(2 * hidden_dim / 3)
+        # custom dim factor multiplier
+        if ffn_dim_multiplier is not None:
+            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
+        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
+
+        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
+        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
+        self.w3 = nn.Linear(dim, hidden_dim, bias=False)
+        self.apply(self._init_weights)
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            nn.init.xavier_uniform_(m.weight)
+            if m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+
+    def forward(self, x):
+        return self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
+
+
+class HiDreamImagePooledEmbed(nn.Module):
+    def __init__(self, text_emb_dim, hidden_size):
+        super().__init__()
+        self.pooled_embedder = TimestepEmbedding(in_channels=text_emb_dim, time_embed_dim=hidden_size)
+        self.apply(self._init_weights)
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            nn.init.normal_(m.weight, std=0.02)
+            if m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+
+    def forward(self, pooled_embed):
+        return self.pooled_embedder(pooled_embed)
+
+
+class HiDreamImageTimestepEmbed(nn.Module):
+    def __init__(self, hidden_size, frequency_embedding_size=256):
+        super().__init__()
+        self.time_proj = Timesteps(num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0)
+        self.timestep_embedder = TimestepEmbedding(in_channels=frequency_embedding_size, time_embed_dim=hidden_size)
+        self.apply(self._init_weights)
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            nn.init.normal_(m.weight, std=0.02)
+            if m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+
+    def forward(self, timesteps, wdtype):
+        t_emb = self.time_proj(timesteps).to(dtype=wdtype)
+        t_emb = self.timestep_embedder(t_emb)
+        return t_emb
+
+
+class HiDreamImageOutEmbed(nn.Module):
+    def __init__(self, hidden_size, patch_size, out_channels):
+        super().__init__()
+        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
+        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
+        self.apply(self._init_weights)
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            nn.init.zeros_(m.weight)
+            if m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+
+    def forward(self, x, adaln_input):
+        shift, scale = self.adaLN_modulation(adaln_input).chunk(2, dim=1)
+        x = self.norm_final(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
+        x = self.linear(x)
+        return x
+
+
+class HiDreamImagePatchEmbed(nn.Module):
+    def __init__(
+        self,
+        patch_size=2,
+        in_channels=4,
+        out_channels=1024,
+    ):
+        super().__init__()
+        self.patch_size = patch_size
+        self.out_channels = out_channels
+        self.proj = nn.Linear(in_channels * patch_size * patch_size, out_channels, bias=True)
+        self.apply(self._init_weights)
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            nn.init.xavier_uniform_(m.weight)
+            if m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+
+    def forward(self, latent):
+        latent = self.proj(latent)
+        return latent
+
+
+def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
+    assert dim % 2 == 0, "The dimension must be even."
+
+    scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
+    omega = 1.0 / (theta**scale)
+
+    batch_size, seq_length = pos.shape
+    out = torch.einsum("...n,d->...nd", pos, omega)
+    cos_out = torch.cos(out)
+    sin_out = torch.sin(out)
+
+    stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
+    out = stacked_out.view(batch_size, -1, dim // 2, 2, 2)
+    return out.float()
+
+
+class HiDreamImageEmbedND(nn.Module):
+    def __init__(self, theta: int, axes_dim: List[int]):
+        super().__init__()
+        self.theta = theta
+        self.axes_dim = axes_dim
+
+    def forward(self, ids: torch.Tensor) -> torch.Tensor:
+        n_axes = ids.shape[-1]
+        emb = torch.cat(
+            [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
+            dim=-3,
+        )
+        return emb.unsqueeze(2)
+
+
 def apply_rope(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
     xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
     xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)

@@ -706,7 +843,8 @@ def forward(
             img_ids = torch.zeros(pH, pW, 3, device=hidden_states.device)
             img_ids[..., 1] = img_ids[..., 1] + torch.arange(pH, device=hidden_states.device)[:, None]
             img_ids[..., 2] = img_ids[..., 2] + torch.arange(pW, device=hidden_states.device)[None, :]
-            img_ids = repeat(img_ids, "h w c -> b (h w) c", b=batch_size)
+            # repeat(img_ids, "h w c -> b (h w) c", b=batch_size)
+            img_ids = img_ids.reshape(img_ids.shape[0], img_ids.shape[1] * img_ids.shape[2]).unsqueeze(0)
             hidden_states = self.x_embedder(hidden_states)

             T5_encoder_hidden_states = encoder_hidden_states[0]
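The last hunk replaces the einops call in the forward pass with plain tensor ops. A minimal sketch of how the removed repeat(img_ids, "h w c -> b (h w) c", b=batch_size) pattern can be written in PyTorch alone, using illustrative sizes; the final check assumes einops is installed:

import torch

batch_size, pH, pW = 2, 4, 4  # illustrative values

img_ids = torch.zeros(pH, pW, 3)
img_ids[..., 1] = torch.arange(pH)[:, None]
img_ids[..., 2] = torch.arange(pW)[None, :]

# Flatten the spatial grid into one token axis, then broadcast over the batch:
# (pH, pW, 3) -> (1, pH * pW, 3) -> (batch_size, pH * pW, 3)
img_ids_flat = img_ids.reshape(1, pH * pW, 3).expand(batch_size, -1, -1)
print(img_ids_flat.shape)  # torch.Size([2, 16, 3])

# Optional sanity check against the removed einops expression.
from einops import repeat
assert torch.equal(img_ids_flat, repeat(img_ids, "h w c -> b (h w) c", b=batch_size))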
