5 | 5 | import torch |
6 | 6 | import torch.nn as nn |
7 | 7 | import torch.nn.functional as F |
8 | | -from einops import repeat |
9 | 8 | |
10 | 9 | from ...configuration_utils import ConfigMixin, register_to_config |
11 | 10 | from ...loaders import FromOriginalModelMixin, PeftAdapterMixin |
12 | 11 | from ...models.modeling_outputs import Transformer2DModelOutput |
13 | 12 | from ...models.modeling_utils import ModelMixin |
14 | 13 | from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers |
15 | 14 | from ...utils.torch_utils import maybe_allow_in_graph |
16 | | -from ..attention import Attention, HiDreamImageFeedForwardSwiGLU |
| 15 | +from ..attention import Attention |
17 | 16 | from ..embeddings import ( |
18 | | -    HiDreamImageEmbedND, |
19 | | -    HiDreamImageOutEmbed, |
20 | | -    HiDreamImagePatchEmbed, |
21 | | -    HiDreamImagePooledEmbed, |
22 | | -    HiDreamImageTimestepEmbed, |
| 17 | +    TimestepEmbedding, |
| 18 | +    Timesteps, |
23 | 19 | ) |
24 | 20 | |
25 | 21 | |
26 | 22 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name |
27 | 23 | |
28 | 24 | |
| 25 | +class HiDreamImageFeedForwardSwiGLU(nn.Module): |
| 26 | +    def __init__( |
| 27 | +        self, |
| 28 | +        dim: int, |
| 29 | +        hidden_dim: int, |
| 30 | +        multiple_of: int = 256, |
| 31 | +        ffn_dim_multiplier: Optional[float] = None, |
| 32 | +    ): |
| 33 | +        super().__init__() |
| 34 | +        hidden_dim = int(2 * hidden_dim / 3) |
| 35 | +        # custom dim factor multiplier |
| 36 | +        if ffn_dim_multiplier is not None: |
| 37 | +            hidden_dim = int(ffn_dim_multiplier * hidden_dim) |
| 38 | +        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) |
| 39 | + |
| 40 | +        self.w1 = nn.Linear(dim, hidden_dim, bias=False) |
| 41 | +        self.w2 = nn.Linear(hidden_dim, dim, bias=False) |
| 42 | +        self.w3 = nn.Linear(dim, hidden_dim, bias=False) |
| 43 | +        self.apply(self._init_weights) |
| 44 | + |
| 45 | +    def _init_weights(self, m): |
| 46 | +        if isinstance(m, nn.Linear): |
| 47 | +            nn.init.xavier_uniform_(m.weight) |
| 48 | +            if m.bias is not None: |
| 49 | +                nn.init.constant_(m.bias, 0) |
| 50 | + |
| 51 | +    def forward(self, x): |
| 52 | +        return self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x)) |
| 53 | + |
| 54 | + |
| 55 | +class HiDreamImagePooledEmbed(nn.Module): |
| 56 | +    def __init__(self, text_emb_dim, hidden_size): |
| 57 | +        super().__init__() |
| 58 | +        self.pooled_embedder = TimestepEmbedding(in_channels=text_emb_dim, time_embed_dim=hidden_size) |
| 59 | +        self.apply(self._init_weights) |
| 60 | + |
| 61 | +    def _init_weights(self, m): |
| 62 | +        if isinstance(m, nn.Linear): |
| 63 | +            nn.init.normal_(m.weight, std=0.02) |
| 64 | +            if m.bias is not None: |
| 65 | +                nn.init.constant_(m.bias, 0) |
| 66 | + |
| 67 | +    def forward(self, pooled_embed): |
| 68 | +        return self.pooled_embedder(pooled_embed) |
| 69 | + |
| 70 | + |
| 71 | +class HiDreamImageTimestepEmbed(nn.Module): |
| 72 | +    def __init__(self, hidden_size, frequency_embedding_size=256): |
| 73 | +        super().__init__() |
| 74 | +        self.time_proj = Timesteps(num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0) |
| 75 | +        self.timestep_embedder = TimestepEmbedding(in_channels=frequency_embedding_size, time_embed_dim=hidden_size) |
| 76 | +        self.apply(self._init_weights) |
| 77 | + |
| 78 | +    def _init_weights(self, m): |
| 79 | +        if isinstance(m, nn.Linear): |
| 80 | +            nn.init.normal_(m.weight, std=0.02) |
| 81 | +            if m.bias is not None: |
| 82 | +                nn.init.constant_(m.bias, 0) |
| 83 | + |
| 84 | +    def forward(self, timesteps, wdtype): |
| 85 | +        t_emb = self.time_proj(timesteps).to(dtype=wdtype) |
| 86 | +        t_emb = self.timestep_embedder(t_emb) |
| 87 | +        return t_emb |
| 88 | + |
| 89 | + |
| 90 | +class HiDreamImageOutEmbed(nn.Module): |
| 91 | +    def __init__(self, hidden_size, patch_size, out_channels): |
| 92 | +        super().__init__() |
| 93 | +        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) |
| 94 | +        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) |
| 95 | +        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)) |
| 96 | +        self.apply(self._init_weights) |
| 97 | + |
| 98 | +    def _init_weights(self, m): |
| 99 | +        if isinstance(m, nn.Linear): |
| 100 | +            nn.init.zeros_(m.weight) |
| 101 | +            if m.bias is not None: |
| 102 | +                nn.init.constant_(m.bias, 0) |
| 103 | + |
| 104 | +    def forward(self, x, adaln_input): |
| 105 | +        shift, scale = self.adaLN_modulation(adaln_input).chunk(2, dim=1) |
| 106 | +        x = self.norm_final(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) |
| 107 | +        x = self.linear(x) |
| 108 | +        return x |
| 109 | + |
| 110 | + |
| 111 | +class HiDreamImagePatchEmbed(nn.Module): |
| 112 | +    def __init__( |
| 113 | +        self, |
| 114 | +        patch_size=2, |
| 115 | +        in_channels=4, |
| 116 | +        out_channels=1024, |
| 117 | +    ): |
| 118 | +        super().__init__() |
| 119 | +        self.patch_size = patch_size |
| 120 | +        self.out_channels = out_channels |
| 121 | +        self.proj = nn.Linear(in_channels * patch_size * patch_size, out_channels, bias=True) |
| 122 | +        self.apply(self._init_weights) |
| 123 | + |
| 124 | +    def _init_weights(self, m): |
| 125 | +        if isinstance(m, nn.Linear): |
| 126 | +            nn.init.xavier_uniform_(m.weight) |
| 127 | +            if m.bias is not None: |
| 128 | +                nn.init.constant_(m.bias, 0) |
| 129 | + |
| 130 | +    def forward(self, latent): |
| 131 | +        latent = self.proj(latent) |
| 132 | +        return latent |
| 133 | + |
| 134 | + |
| 135 | +def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor: |
| 136 | +    assert dim % 2 == 0, "The dimension must be even." |
| 137 | + |
| 138 | +    scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim |
| 139 | +    omega = 1.0 / (theta**scale) |
| 140 | + |
| 141 | +    batch_size, seq_length = pos.shape |
| 142 | +    out = torch.einsum("...n,d->...nd", pos, omega) |
| 143 | +    cos_out = torch.cos(out) |
| 144 | +    sin_out = torch.sin(out) |
| 145 | + |
| 146 | +    stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1) |
| 147 | +    out = stacked_out.view(batch_size, -1, dim // 2, 2, 2) |
| 148 | +    return out.float() |
| 149 | + |
| 150 | + |
| 151 | +class HiDreamImageEmbedND(nn.Module): |
| 152 | +    def __init__(self, theta: int, axes_dim: List[int]): |
| 153 | +        super().__init__() |
| 154 | +        self.theta = theta |
| 155 | +        self.axes_dim = axes_dim |
| 156 | + |
| 157 | +    def forward(self, ids: torch.Tensor) -> torch.Tensor: |
| 158 | +        n_axes = ids.shape[-1] |
| 159 | +        emb = torch.cat( |
| 160 | +            [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], |
| 161 | +            dim=-3, |
| 162 | +        ) |
| 163 | +        return emb.unsqueeze(2) |
| 164 | + |
| 165 | + |
29 | 166 | def apply_rope(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: |
30 | 167 |     xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2) |
31 | 168 |     xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2) |
@@ -706,7 +843,8 @@ def forward( |
706 | 843 |             img_ids = torch.zeros(pH, pW, 3, device=hidden_states.device) |
707 | 844 |             img_ids[..., 1] = img_ids[..., 1] + torch.arange(pH, device=hidden_states.device)[:, None] |
708 | 845 |             img_ids[..., 2] = img_ids[..., 2] + torch.arange(pW, device=hidden_states.device)[None, :] |
709 | | -            img_ids = repeat(img_ids, "h w c -> b (h w) c", b=batch_size) |
| 846 | +            # Plain-PyTorch equivalent of the removed einops call: repeat(img_ids, "h w c -> b (h w) c", b=batch_size) |
| 847 | +            img_ids = img_ids.reshape(1, pH * pW, 3).repeat(batch_size, 1, 1) |
710 | 848 |         hidden_states = self.x_embedder(hidden_states) |
711 | 849 | |
712 | 850 |         T5_encoder_hidden_states = encoder_hidden_states[0] |
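Side note (not part of the commit): a minimal sketch to sanity-check that the pure-PyTorch replacement above matches the removed einops.repeat call. It assumes einops is still installed in the development environment for the comparison; the batch size and patch-grid sizes below are arbitrary.

import torch
from einops import repeat

batch_size, pH, pW = 2, 4, 6
img_ids = torch.zeros(pH, pW, 3)
img_ids[..., 1] = img_ids[..., 1] + torch.arange(pH)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.arange(pW)[None, :]

reference = repeat(img_ids, "h w c -> b (h w) c", b=batch_size)          # removed einops path
replacement = img_ids.reshape(1, pH * pW, 3).repeat(batch_size, 1, 1)    # plain-PyTorch path

assert replacement.shape == (batch_size, pH * pW, 3)
assert torch.equal(reference, replacement)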