|
26 | 26 | FusedAttnProcessor2_0, |
27 | 27 | SanaLinearAttnProcessor2_0, |
28 | 28 | ) |
29 | | -from ..autoencoders.autoencoder_dc import GLUMBConv |
30 | 29 | from ..embeddings import PatchEmbed, PixArtAlphaTextProjection, SinusoidalPositionalEmbedding |
31 | 30 | from ..modeling_outputs import Transformer2DModelOutput |
32 | 31 | from ..modeling_utils import ModelMixin |
@@ -58,40 +57,40 @@ def __init__(self, dim, eps: float, elementwise_affine: bool = True, scale_facto |
58 | 57 | self.weight = nn.Parameter(torch.ones(dim) * scale_factor) |
59 | 58 |
|
60 | 59 |
|
61 | | -# Modified from diffusers.models.autoencoders.ecae.GLUMBConv |
| 60 | +# Modified from diffusers.models.autoencoders.autoencoder_dc.GLUMBConv |
@maybe_allow_in_graph
class SanaGLUMBConv(nn.Module):
    """Gated convolutional feed-forward block applied to a flattened token sequence.

    The input sequence is reshaped to a 2D feature map, expanded by a 1x1
    convolution, filtered by a 3x3 depthwise convolution, split into a value
    half and a gate half (the value is multiplied by the SiLU of the gate),
    projected by a bias-free 1x1 convolution, and flattened back to a sequence.

    Args:
        in_channels: Channel size of the incoming tokens.
        out_channels: Channel size of the returned tokens.
        expand_ratio: Multiplier applied to `in_channels` to obtain the hidden
            width; the result is truncated to an integer.
    """

    def __init__(self, in_channels: int, out_channels: int, expand_ratio: float = 2.5) -> None:
        super().__init__()

        hidden_channels = int(expand_ratio * in_channels)

        self.nonlinearity = nn.SiLU()

        # The first two convolutions carry twice the hidden width because their
        # output is later split into a value half and a gate half.
        self.conv_inverted = nn.Conv2d(in_channels, hidden_channels * 2, 1, 1, 0)
        self.conv_depth = nn.Conv2d(hidden_channels * 2, hidden_channels * 2, 3, 1, 1, groups=hidden_channels * 2)
        self.conv_point = nn.Conv2d(hidden_channels, out_channels, 1, 1, 0, bias=False)

    def forward(self, hidden_states: torch.Tensor, HW: Optional[tuple[int, int]] = None) -> torch.Tensor:
        """Run the block on a `(B, N, C)` tensor and return a `(B, N, out_channels)` tensor.

        Args:
            hidden_states: Tensor with three dimensions `(B, N, C)`.
            HW: Optional `(H, W)` pair used to fold the `N` tokens into a 2D
                map. When omitted, a square map with side `int(N**0.5)` is used.

        Returns:
            Tensor with the same batch and token dimensions as the input and the
            channel dimension produced by `conv_point`.
        """
        B, N, C = hidden_states.shape
        if HW is None:
            H = W = int(N**0.5)
        else:
            H, W = HW

        # (B, N, C) -> (B, C, H, W) so that the 2D convolutions can be applied.
        hidden_states = hidden_states.reshape(B, H, W, C).permute(0, 3, 1, 2)

        hidden_states = self.conv_inverted(hidden_states)
        hidden_states = self.nonlinearity(hidden_states)

        hidden_states = self.conv_depth(hidden_states)
        hidden_states, gate = torch.chunk(hidden_states, 2, dim=1)
        hidden_states = hidden_states * self.nonlinearity(gate)

        hidden_states = self.conv_point(hidden_states)

        # Flatten with the channel count actually produced by `conv_point`;
        # reusing the input `C` here breaks whenever out_channels != in_channels.
        projected_channels = hidden_states.shape[1]
        hidden_states = hidden_states.reshape(B, projected_channels, N).permute(0, 2, 1)

        return hidden_states
95 | 94 |
|
96 | 95 |
|
97 | 96 | # Modified from diffusers.models.attention.BasicTransformerBlock |
@@ -130,8 +129,6 @@ def __init__( |
130 | 129 | use_pe: bool = False, |
131 | 130 | num_positional_embeddings: Optional[int] = None, |
132 | 131 | expand_ratio: float = 2.5, |
133 | | - ff_bias: tuple =(True, True, False), |
134 | | - ff_norm: tuple =(None, None, None), |
135 | 132 | ): |
136 | 133 | super().__init__() |
137 | 134 | self.dim = dim |
@@ -186,9 +183,6 @@ def __init__( |
186 | 183 | in_channels=dim, |
187 | 184 | out_channels=dim, |
188 | 185 | expand_ratio=expand_ratio, |
189 | | - use_bias=ff_bias, |
190 | | - norm=ff_norm, |
191 | | - act_func=activation_fn, |
192 | 186 | ) |
193 | 187 |
|
194 | 188 | # 5. Scale-shift for Sana. |
@@ -362,8 +356,6 @@ def __init__( |
362 | 356 | attention_type: Optional[str] = "default", |
363 | 357 | use_pe: Optional[bool] = False, |
364 | 358 | expand_ratio=2.5, |
365 | | - ff_bias: tuple =(True, True, False), |
366 | | - ff_norm: tuple =(None, None, None), |
367 | 359 | ): |
368 | 360 | super().__init__() |
369 | 361 |
|
@@ -428,8 +420,6 @@ def __init__( |
428 | 420 | norm_eps=self.config.norm_eps, |
429 | 421 | use_pe=self.config.use_pe, |
430 | 422 | expand_ratio=self.config.expand_ratio, |
431 | | - ff_bias=self.config.ff_bias, |
432 | | - ff_norm=self.config.ff_norm, |
433 | 423 | ) |
434 | 424 | for _ in range(self.config.num_layers) |
435 | 425 | ] |
|
0 commit comments