
Commit 2c828c2

refactor
1 parent 0debade commit 2c828c2


6 files changed: +66 -114 lines changed


scripts/convert_sana_to_diffusers.py

Lines changed: 12 additions & 11 deletions
@@ -59,8 +59,8 @@ def main(args):
     converted_state_dict = {}
 
     # Patch embeddings.
-    converted_state_dict["pos_embed.proj.weight"] = state_dict.pop("x_embedder.proj.weight")
-    converted_state_dict["pos_embed.proj.bias"] = state_dict.pop("x_embedder.proj.bias")
+    converted_state_dict["patch_embed.proj.weight"] = state_dict.pop("x_embedder.proj.weight")
+    converted_state_dict["patch_embed.proj.bias"] = state_dict.pop("x_embedder.proj.bias")
 
     # Caption projection.
     converted_state_dict["caption_projection.linear_1.weight"] = state_dict.pop("y_embedder.y_proj.fc1.weight")
@@ -69,18 +69,18 @@ def main(args):
     converted_state_dict["caption_projection.linear_2.bias"] = state_dict.pop("y_embedder.y_proj.fc2.bias")
 
     # AdaLN-single LN
-    converted_state_dict["adaln_single.emb.timestep_embedder.linear_1.weight"] = state_dict.pop(
+    converted_state_dict["time_embed.emb.timestep_embedder.linear_1.weight"] = state_dict.pop(
         "t_embedder.mlp.0.weight"
     )
-    converted_state_dict["adaln_single.emb.timestep_embedder.linear_1.bias"] = state_dict.pop("t_embedder.mlp.0.bias")
-    converted_state_dict["adaln_single.emb.timestep_embedder.linear_2.weight"] = state_dict.pop(
+    converted_state_dict["time_embed.emb.timestep_embedder.linear_1.bias"] = state_dict.pop("t_embedder.mlp.0.bias")
+    converted_state_dict["time_embed.emb.timestep_embedder.linear_2.weight"] = state_dict.pop(
         "t_embedder.mlp.2.weight"
     )
-    converted_state_dict["adaln_single.emb.timestep_embedder.linear_2.bias"] = state_dict.pop("t_embedder.mlp.2.bias")
+    converted_state_dict["time_embed.emb.timestep_embedder.linear_2.bias"] = state_dict.pop("t_embedder.mlp.2.bias")
 
     # Shared norm.
-    converted_state_dict["adaln_single.linear.weight"] = state_dict.pop("t_block.1.weight")
-    converted_state_dict["adaln_single.linear.bias"] = state_dict.pop("t_block.1.bias")
+    converted_state_dict["time_embed.linear.weight"] = state_dict.pop("t_block.1.weight")
+    converted_state_dict["time_embed.linear.bias"] = state_dict.pop("t_block.1.bias")
 
     # y norm
     converted_state_dict["caption_norm.weight"] = state_dict.pop("attention_y_norm.weight")
@@ -166,18 +166,19 @@ def main(args):
         num_cross_attention_heads=model_kwargs[args.model_type]["num_cross_attention_heads"],
         cross_attention_head_dim=model_kwargs[args.model_type]["cross_attention_head_dim"],
         cross_attention_dim=model_kwargs[args.model_type]["cross_attention_dim"],
+        caption_channels=2304,
+        mlp_ratio=2.5,
         attention_bias=False,
         sample_size=32,
         patch_size=1,
         norm_elementwise_affine=False,
         norm_eps=1e-6,
-        caption_channels=2304,
-        expand_ratio=2.5,
     )
+
     if is_accelerate_available():
         load_model_dict_into_meta(transformer, converted_state_dict)
     else:
-        transformer.load_state_dict(converted_state_dict, strict=True)
+        transformer.load_state_dict(converted_state_dict, strict=True, assign=True)
 
     try:
         state_dict.pop("y_embedder.y_embedding")
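
Aside (not part of the commit): the renames above amount to a prefix mapping from the old diffusers module names to the new ones. A minimal sketch of that mapping; `rename_key` is an illustrative helper, not a function in the conversion script, which writes each key out explicitly.

import re

# Hypothetical prefix map, for illustration only.
OLD_TO_NEW_PREFIXES = {
    "pos_embed.": "patch_embed.",
    "adaln_single.": "time_embed.",
}

def rename_key(key: str) -> str:
    # Rewrite the module prefix if it matches one of the renamed submodules.
    for old, new in OLD_TO_NEW_PREFIXES.items():
        if key.startswith(old):
            return new + key[len(old):]
    return key

assert rename_key("adaln_single.linear.weight") == "time_embed.linear.weight"
assert rename_key("pos_embed.proj.bias") == "patch_embed.proj.bias"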

src/diffusers/models/attention_processor.py

Lines changed: 5 additions & 13 deletions
@@ -5361,21 +5361,16 @@ def __call__(
     ) -> torch.Tensor:
         original_dtype = hidden_states.dtype
 
-        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-
         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
 
         query = attn.to_q(hidden_states)
         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
 
-        inner_dim = key.shape[-1]
-        head_dim = inner_dim // attn.heads
-
-        query = query.transpose(-1, -2).reshape(batch_size, attn.heads, head_dim, -1)
-        key = key.transpose(-1, -2).reshape(batch_size, attn.heads, head_dim, -1).transpose(-1, -2)
-        value = value.transpose(-1, -2).reshape(batch_size, attn.heads, head_dim, -1)
+        query = query.transpose(1, 2).unflatten(1, (attn.heads, -1))
+        key = key.transpose(1, 2).unflatten(1, (attn.heads, -1)).transpose(2, 3)
+        value = value.transpose(1, 2).unflatten(1, (attn.heads, -1))
 
         query = self.kernel_func(query)
         key = self.kernel_func(key)
@@ -5386,17 +5381,14 @@ def __call__(
         scores = torch.matmul(value, key)
         hidden_states = torch.matmul(scores, query)
 
-        if hidden_states.dtype in [torch.float16, torch.bfloat16]:
-            hidden_states = hidden_states.float()
-
         hidden_states = hidden_states[:, :, :-1] / (hidden_states[:, :, -1:] + self.eps)
-        hidden_states = hidden_states.view(batch_size, attn.heads * head_dim, -1).permute(0, 2, 1)
+        hidden_states = hidden_states.flatten(1, 2).transpose(1, 2)
         hidden_states = hidden_states.to(original_dtype)
 
         hidden_states = attn.to_out[0](hidden_states)
         hidden_states = attn.to_out[1](hidden_states)
 
-        if hidden_states.dtype == torch.float16:
+        if original_dtype == torch.float16:
             hidden_states = hidden_states.clip(-65504, 65504)
 
         return hidden_states
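
Aside (not from the commit): the new `unflatten`-based reshape produces the same `(batch, heads, head_dim, seq_len)` layout as the removed `reshape`-based code, so the projection handling is unchanged in behavior. A quick self-contained check, with arbitrary example shapes:

import torch

batch_size, seq_len, heads, head_dim = 2, 16, 4, 8
x = torch.randn(batch_size, seq_len, heads * head_dim)

# old layout: transpose to (B, inner_dim, N), then reshape inner_dim into heads x head_dim
old = x.transpose(-1, -2).reshape(batch_size, heads, head_dim, -1)
# new layout: same transpose, then unflatten the channel dimension
new = x.transpose(1, 2).unflatten(1, (heads, -1))

torch.testing.assert_close(old, new)  # both are (batch, heads, head_dim, seq_len)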

src/diffusers/models/autoencoders/autoencoder_dc.py

Lines changed: 2 additions & 30 deletions
@@ -26,39 +26,10 @@
 from ..attention_processor import SanaMultiscaleLinearAttention
 from ..modeling_utils import ModelMixin
 from ..normalization import RMSNorm, get_normalization
+from ..transformers.sana_transformer import GLUMBConv
 from .vae import DecoderOutput, EncoderOutput
 
 
-class GLUMBConv(nn.Module):
-    def __init__(self, in_channels: int, out_channels: int) -> None:
-        super().__init__()
-
-        hidden_channels = 4 * in_channels
-
-        self.nonlinearity = nn.SiLU()
-
-        self.conv_inverted = nn.Conv2d(in_channels, hidden_channels * 2, 1, 1, 0)
-        self.conv_depth = nn.Conv2d(hidden_channels * 2, hidden_channels * 2, 3, 1, 1, groups=hidden_channels * 2)
-        self.conv_point = nn.Conv2d(hidden_channels, out_channels, 1, 1, 0, bias=False)
-        self.norm = RMSNorm(out_channels, eps=1e-5, elementwise_affine=True, bias=True)
-
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        residual = hidden_states
-
-        hidden_states = self.conv_inverted(hidden_states)
-        hidden_states = self.nonlinearity(hidden_states)
-
-        hidden_states = self.conv_depth(hidden_states)
-        hidden_states, gate = torch.chunk(hidden_states, 2, dim=1)
-        hidden_states = hidden_states * self.nonlinearity(gate)
-
-        hidden_states = self.conv_point(hidden_states)
-        # move channel to the last dimension so we apply RMSnorm across channel dimension
-        hidden_states = self.norm(hidden_states.movedim(1, -1)).movedim(-1, 1)
-
-        return hidden_states + residual
-
-
 class ResBlock(nn.Module):
     def __init__(
         self,
@@ -115,6 +86,7 @@ def __init__(
         self.conv_out = GLUMBConv(
             in_channels=in_channels,
             out_channels=in_channels,
+            norm_type="rms_norm",
         )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
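
Aside (not from the commit): with `norm_type="rms_norm"` and the default `expand_ratio=4` / `residual_connection=True`, the shared `GLUMBConv` is meant to behave like the local class removed above. A hedged usage sketch, assuming the module path introduced in this commit and made-up shapes:

import torch
from diffusers.models.transformers.sana_transformer import GLUMBConv

# Illustrative instantiation mirroring AutoencoderDC's conv_out usage above.
block = GLUMBConv(in_channels=32, out_channels=32, norm_type="rms_norm")
sample = torch.randn(1, 32, 8, 8)  # (batch, channels, height, width)
print(block(sample).shape)  # torch.Size([1, 32, 8, 8])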

src/diffusers/models/transformers/sana_transformer.py

Lines changed: 44 additions & 52 deletions
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Dict, Optional, Union
+from typing import Dict, Optional, Tuple, Union
 
 import torch
 from torch import nn
@@ -35,28 +35,27 @@
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
-# Modified from diffusers.models.autoencoders.autoencoder_dc.GLUMBConv
-@maybe_allow_in_graph
-class SanaGLUMBConv(nn.Module):
-    def __init__(self, in_channels: int, out_channels: int, mlp_ratio: float = 2.5) -> None:
+class GLUMBConv(nn.Module):
+    def __init__(self, in_channels: int, out_channels: int, expand_ratio: float = 4, norm_type: Optional[str] = None, residual_connection: bool = True) -> None:
         super().__init__()
 
-        hidden_channels = int(mlp_ratio * in_channels)
+        hidden_channels = int(expand_ratio * in_channels)
+        self.norm_type = norm_type
+        self.residual_connection = residual_connection
 
         self.nonlinearity = nn.SiLU()
 
         self.conv_inverted = nn.Conv2d(in_channels, hidden_channels * 2, 1, 1, 0)
         self.conv_depth = nn.Conv2d(hidden_channels * 2, hidden_channels * 2, 3, 1, 1, groups=hidden_channels * 2)
         self.conv_point = nn.Conv2d(hidden_channels, out_channels, 1, 1, 0, bias=False)
 
-    def forward(self, hidden_states: torch.Tensor, HW: Optional[tuple[int]] = None) -> torch.Tensor:
-        B, N, C = hidden_states.shape
-        if HW is None:
-            H = W = int(N**0.5)
-        else:
-            H, W = HW
+        self.norm = None
+        if norm_type == "rms_norm":
+            self.norm = RMSNorm(out_channels, eps=1e-5, elementwise_affine=True, bias=True)
 
-        hidden_states = hidden_states.reshape(B, H, W, C).permute(0, 3, 1, 2)
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        if self.residual_connection:
+            residual = hidden_states
 
         hidden_states = self.conv_inverted(hidden_states)
         hidden_states = self.nonlinearity(hidden_states)
@@ -66,23 +65,22 @@ def forward(self, hidden_states: torch.Tensor, HW: Optional[tuple[int]] = None)
         hidden_states = hidden_states * self.nonlinearity(gate)
 
         hidden_states = self.conv_point(hidden_states)
-        hidden_states = hidden_states.reshape(B, C, N).permute(0, 2, 1)
-
+
+        if self.norm_type == "rms_norm":
+            # move channel to the last dimension so we apply RMSnorm across channel dimension
+            hidden_states = self.norm(hidden_states.movedim(1, -1)).movedim(-1, 1)
+
+        if self.residual_connection:
+            hidden_states = hidden_states + residual
+
         return hidden_states
 
 
 class SanaTransformerBlock(nn.Module):
     r"""
-    A Transformer block following the Linear Transformer architecture, introduced in Sana
-
-    Reference: https://arxiv.org/abs/2410.10629
-
-    Parameters:
-        dim (`int`): The number of channels in the input and output.
-        num_attention_heads (`int`): The number of heads to use for multi-head attention.
-        attention_head_dim (`int`): The number of channels in each head.
+    Transformer block introduced in [Sana](https://huggingface.co/papers/2410.10629).
     """
-
+
     def __init__(
         self,
         dim: int = 2240,
@@ -127,11 +125,7 @@ def __init__(
         )
 
         # 3. Feed-forward
-        self.ff = SanaGLUMBConv(
-            in_channels=dim,
-            out_channels=dim,
-            mlp_ratio=mlp_ratio,
-        )
+        self.ff = GLUMBConv(dim, dim, mlp_ratio, norm_type=None, residual_connection=False)
 
         self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
 
@@ -142,7 +136,8 @@ def forward(
         encoder_hidden_states: Optional[torch.Tensor] = None,
         encoder_attention_mask: Optional[torch.Tensor] = None,
         timestep: Optional[torch.LongTensor] = None,
-        HW: Optional[tuple[int]] = None,
+        height: int = None,
+        width: int = None,
     ) -> torch.Tensor:
         batch_size = hidden_states.shape[0]
 
@@ -171,15 +166,17 @@ def forward(
         norm_hidden_states = self.norm2(hidden_states)
         norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
 
-        ff_output = self.ff(norm_hidden_states, HW=HW)
+        norm_hidden_states = norm_hidden_states.unflatten(1, (height, width)).permute(0, 3, 1, 2)
+        ff_output = self.ff(norm_hidden_states)
+        ff_output = ff_output.flatten(2, 3).permute(0, 2, 1)
         hidden_states = hidden_states + gate_mlp * ff_output
 
         return hidden_states
 
 
 class SanaTransformer2DModel(ModelMixin, ConfigMixin):
     r"""
-    A 2D Transformer model as introduced in [Sana](https://arxiv.org/abs/2410.10629) family of models.
+    A 2D Transformer model introduced in [Sana](https://huggingface.co/papers/2410.10629) family of models.
 
     Args:
         in_channels (`int`, defaults to `32`):
@@ -204,7 +201,7 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin):
             The expansion ratio to use in the GLUMBConv layer.
         dropout (`float`, defaults to `0.0`):
             The dropout probability.
-        attention_bias (`bool`, defaults to `True`):
+        attention_bias (`bool`, defaults to `False`):
             Whether to use bias in the attention layer.
         sample_size (`int`, defaults to `32`):
             The base size of the input latent.
@@ -233,7 +230,7 @@ def __init__(
         caption_channels: int = 2304,
         mlp_ratio: float = 2.5,
         dropout: float = 0.0,
-        attention_bias: bool = True,
+        attention_bias: bool = False,
         sample_size: int = 32,
         patch_size: int = 1,
         norm_elementwise_affine: bool = False,
@@ -245,7 +242,7 @@ def __init__(
         inner_dim = num_attention_heads * attention_head_dim
 
         # 1. Patch Embedding
-        self.pos_embed = PatchEmbed(
+        self.patch_embed = PatchEmbed(
            height=sample_size,
            width=sample_size,
            patch_size=patch_size,
@@ -255,7 +252,9 @@ def __init__(
             pos_embed_type=None,
         )
 
-        # 2. Caption Embedding
+        # 2. Additional condition embeddings
+        self.time_embed = AdaLayerNormSingle(inner_dim)
+
         self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
         self.caption_norm = RMSNorm(inner_dim, eps=1e-5)
 
@@ -285,8 +284,6 @@ def __init__(
         self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
         self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)
 
-        self.adaln_single = AdaLayerNormSingle(inner_dim)
-
         self.gradient_checkpointing = False
 
     def _set_gradient_checkpointing(self, module, value=False):
@@ -361,7 +358,7 @@ def forward(
         encoder_attention_mask: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         return_dict: bool = True,
-    ):
+    ) -> Union[Tuple[torch.Tensor, ...], Transformer2DModelOutput]:
         # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
         # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
         # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
@@ -387,11 +384,12 @@ def forward(
 
         # 1. Input
         batch_size, num_channels, height, width = hidden_states.shape
-        post_patch_height = height // self.config.patch_size
-        post_patch_width = width // self.config.patch_size
-        hidden_states = self.pos_embed(hidden_states)
+        p = self.config.patch_size
+        post_patch_height, post_patch_width = height // p, width // p
+
+        hidden_states = self.patch_embed(hidden_states)
 
-        timestep, embedded_timestep = self.adaln_single(
+        timestep, embedded_timestep = self.time_embed(
             timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype
         )
 
@@ -418,7 +416,8 @@ def create_block_forward(block):
                 encoder_hidden_states,
                 encoder_attention_mask,
                 timestep,
-                (post_patch_height, post_patch_width),
+                post_patch_height,
+                post_patch_width,
             )
 
         # 3. Normalization
@@ -436,14 +435,7 @@ def create_block_forward(block):
             batch_size, post_patch_height, post_patch_width, self.config.patch_size, self.config.patch_size, -1
         )
         hidden_states = hidden_states.permute(0, 5, 1, 3, 2, 4)
-        output = hidden_states.reshape(
-            shape=(
-                batch_size,
-                -1,
-                post_patch_height * self.config.patch_size,
-                post_patch_width * self.config.patch_size,
-            )
-        )
+        output = hidden_states.reshape(batch_size, -1, post_patch_height * p, post_patch_width * p)
 
         if not return_dict:
             return (output,)
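
Aside (not from the commit): `SanaTransformerBlock.forward` now converts the token sequence into a 2D feature map itself before calling the convolutional feed-forward, instead of passing an `HW` tuple into it. The round trip it relies on can be sketched in isolation, with made-up shapes:

import torch

batch_size, height, width, dim = 2, 4, 4, 8
tokens = torch.randn(batch_size, height * width, dim)

# (B, H*W, C) -> (B, C, H, W) for the GLUMBConv feed-forward ...
feature_map = tokens.unflatten(1, (height, width)).permute(0, 3, 1, 2)
# ... and back to (B, H*W, C) afterwards
restored = feature_map.flatten(2, 3).permute(0, 2, 1)

torch.testing.assert_close(tokens, restored)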

src/diffusers/pipelines/pag/pipeline_pag_sana.py

Lines changed: 0 additions & 1 deletion
@@ -27,7 +27,6 @@
 from ...schedulers import FlowDPMSolverMultistepScheduler
 from ...utils import (
     BACKENDS_MAPPING,
-    deprecate,
     is_bs4_available,
     is_ftfy_available,
     logging,
