import torch
import torch.nn as nn
import torch.nn.functional as F
-from einops import repeat

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...models.modeling_outputs import Transformer2DModelOutput
from ...models.modeling_utils import ModelMixin
from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
-from ..attention import Attention, HiDreamImageFeedForwardSwiGLU
+from ..attention import Attention
from ..embeddings import (
-    HiDreamImageEmbedND,
-    HiDreamImageOutEmbed,
-    HiDreamImagePatchEmbed,
-    HiDreamImagePooledEmbed,
-    HiDreamImageTimestepEmbed,
+    TimestepEmbedding,
+    Timesteps,
)


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


+class HiDreamImageFeedForwardSwiGLU(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        hidden_dim: int,
+        multiple_of: int = 256,
+        ffn_dim_multiplier: Optional[float] = None,
+    ):
+        super().__init__()
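+        # SwiGLU uses three projections (w1, w2, w3), so the hidden width is scaled by
+        # 2/3 and rounded up to a multiple of `multiple_of` to keep the parameter count
+        # comparable to a standard two-layer MLP of width `hidden_dim`.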
+        hidden_dim = int(2 * hidden_dim / 3)
+        # custom dim factor multiplier
+        if ffn_dim_multiplier is not None:
+            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
+        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
+
+        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
+        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
+        self.w3 = nn.Linear(dim, hidden_dim, bias=False)
+        self.apply(self._init_weights)
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            nn.init.xavier_uniform_(m.weight)
+            if m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+
+    def forward(self, x):
+        return self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
+
+
+class HiDreamImagePooledEmbed(nn.Module):
+    def __init__(self, text_emb_dim, hidden_size):
+        super().__init__()
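+        # Project pooled text embeddings to the model width by reusing the TimestepEmbedding MLP.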
+        self.pooled_embedder = TimestepEmbedding(in_channels=text_emb_dim, time_embed_dim=hidden_size)
+        self.apply(self._init_weights)
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            nn.init.normal_(m.weight, std=0.02)
+            if m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+
+    def forward(self, pooled_embed):
+        return self.pooled_embedder(pooled_embed)
+
+
+class HiDreamImageTimestepEmbed(nn.Module):
+    def __init__(self, hidden_size, frequency_embedding_size=256):
+        super().__init__()
+        self.time_proj = Timesteps(num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0)
+        self.timestep_embedder = TimestepEmbedding(in_channels=frequency_embedding_size, time_embed_dim=hidden_size)
+        self.apply(self._init_weights)
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            nn.init.normal_(m.weight, std=0.02)
+            if m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+
+    def forward(self, timesteps, wdtype):
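+        # Sinusoidal frequency features, cast to the working dtype, then projected by an MLP.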
+        t_emb = self.time_proj(timesteps).to(dtype=wdtype)
+        t_emb = self.timestep_embedder(t_emb)
+        return t_emb
+
+
+class HiDreamImageOutEmbed(nn.Module):
+    def __init__(self, hidden_size, patch_size, out_channels):
+        super().__init__()
+        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
+        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
+        self.apply(self._init_weights)
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            nn.init.zeros_(m.weight)
+            if m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+
+    def forward(self, x, adaln_input):
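+        # adaLN-style modulation: the conditioning vector produces a per-sample shift and
+        # scale applied after the affine-free final LayerNorm (weights are zero-initialized).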
+        shift, scale = self.adaLN_modulation(adaln_input).chunk(2, dim=1)
+        x = self.norm_final(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
+        x = self.linear(x)
+        return x
+
+
+class HiDreamImagePatchEmbed(nn.Module):
+    def __init__(
+        self,
+        patch_size=2,
+        in_channels=4,
+        out_channels=1024,
+    ):
+        super().__init__()
+        self.patch_size = patch_size
+        self.out_channels = out_channels
+        self.proj = nn.Linear(in_channels * patch_size * patch_size, out_channels, bias=True)
+        self.apply(self._init_weights)
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            nn.init.xavier_uniform_(m.weight)
+            if m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+
+    def forward(self, latent):
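+        # `latent` is expected to be pre-patchified to (batch, num_patches, in_channels * patch_size**2),
+        # so the patch embedding reduces to a single linear projection.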
+        latent = self.proj(latent)
+        return latent
+
+
+def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
+    assert dim % 2 == 0, "The dimension must be even."
+
+    scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
+    omega = 1.0 / (theta**scale)
+
+    batch_size, seq_length = pos.shape
+    out = torch.einsum("...n,d->...nd", pos, omega)
+    cos_out = torch.cos(out)
+    sin_out = torch.sin(out)
+
+    stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
+    out = stacked_out.view(batch_size, -1, dim // 2, 2, 2)
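+    # Each (position, frequency) pair is packed as a 2x2 rotation matrix
+    # [[cos, -sin], [sin, cos]]; `apply_rope` later contracts these against
+    # the (even, odd) channel pairs of the query/key projections.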
+    return out.float()
+
+
+class HiDreamImageEmbedND(nn.Module):
+    def __init__(self, theta: int, axes_dim: List[int]):
+        super().__init__()
+        self.theta = theta
+        self.axes_dim = axes_dim
+
+    def forward(self, ids: torch.Tensor) -> torch.Tensor:
+        n_axes = ids.shape[-1]
+        emb = torch.cat(
+            [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
+            dim=-3,
+        )
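+        # One rotary table per position axis (e.g. index / height / width),
+        # concatenated along the frequency dimension.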
+        return emb.unsqueeze(2)
+
+
def apply_rope(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
@@ -706,7 +843,8 @@ def forward(
            img_ids = torch.zeros(pH, pW, 3, device=hidden_states.device)
            img_ids[..., 1] = img_ids[..., 1] + torch.arange(pH, device=hidden_states.device)[:, None]
            img_ids[..., 2] = img_ids[..., 2] + torch.arange(pW, device=hidden_states.device)[None, :]
-            img_ids = repeat(img_ids, "h w c -> b (h w) c", b=batch_size)
+            # Equivalent of einops: repeat(img_ids, "h w c -> b (h w) c", b=batch_size)
+            img_ids = img_ids.reshape(1, pH * pW, img_ids.shape[-1]).repeat(batch_size, 1, 1)
        hidden_states = self.x_embedder(hidden_states)

        T5_encoder_hidden_states = encoder_hidden_states[0]
 | 