Skip to content

Commit b58868e

Browse files
[Sana bug] bug fix for 2K model config (#10340)
* fix the Positional Embedding bug in 2K model * Change the default model to the BF16 one for more stable training and output * make style * subtract buffer size * add compute_module_persistent_sizes --------- Co-authored-by: yiyixuxu <[email protected]>
1 parent da21d59 commit b58868e

File tree

7 files changed

+93
-18
lines changed

7 files changed

+93
-18
lines changed

docs/source/en/api/models/sana_transformer2d.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ The model can be loaded with the following code snippet.
2222
```python
2323
from diffusers import SanaTransformer2DModel
2424

25-
transformer = SanaTransformer2DModel.from_pretrained("Efficient-Large-Model/Sana_1600M_1024px_diffusers", subfolder="transformer", torch_dtype=torch.float16)
25+
transformer = SanaTransformer2DModel.from_pretrained("Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers", subfolder="transformer", torch_dtype=torch.bfloat16)
2626
```
2727

2828
## SanaTransformer2DModel

docs/source/en/api/pipelines/sana.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,9 @@ Available models:
3232

3333
| Model | Recommended dtype |
3434
|:-----:|:-----------------:|
35+
| [`Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers) | `torch.bfloat16` |
3536
| [`Efficient-Large-Model/Sana_1600M_1024px_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_diffusers) | `torch.float16` |
3637
| [`Efficient-Large-Model/Sana_1600M_1024px_MultiLing_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_MultiLing_diffusers) | `torch.float16` |
37-
| [`Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers) | `torch.bfloat16` |
3838
| [`Efficient-Large-Model/Sana_1600M_512px_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_512px_diffusers) | `torch.float16` |
3939
| [`Efficient-Large-Model/Sana_1600M_512px_MultiLing_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_512px_MultiLing_diffusers) | `torch.float16` |
4040
| [`Efficient-Large-Model/Sana_600M_1024px_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_600M_1024px_diffusers) | `torch.float16` |

scripts/convert_sana_to_diffusers.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,13 +88,18 @@ def main(args):
8888
# y norm
8989
converted_state_dict["caption_norm.weight"] = state_dict.pop("attention_y_norm.weight")
9090

91+
# scheduler
9192
flow_shift = 3.0
93+
94+
# model config
9295
if args.model_type == "SanaMS_1600M_P1_D20":
9396
layer_num = 20
9497
elif args.model_type == "SanaMS_600M_P1_D28":
9598
layer_num = 28
9699
else:
97100
raise ValueError(f"{args.model_type} is not supported.")
101+
# Positional embedding interpolation scale.
102+
interpolation_scale = {512: None, 1024: None, 2048: 1.0}
98103

99104
for depth in range(layer_num):
100105
# Transformer blocks.
@@ -176,6 +181,7 @@ def main(args):
176181
patch_size=1,
177182
norm_elementwise_affine=False,
178183
norm_eps=1e-6,
184+
interpolation_scale=interpolation_scale[args.image_size],
179185
)
180186

181187
if is_accelerate_available():

src/diffusers/models/transformers/sana_transformer.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -242,21 +242,22 @@ def __init__(
242242
patch_size: int = 1,
243243
norm_elementwise_affine: bool = False,
244244
norm_eps: float = 1e-6,
245+
interpolation_scale: Optional[int] = None,
245246
) -> None:
246247
super().__init__()
247248

248249
out_channels = out_channels or in_channels
249250
inner_dim = num_attention_heads * attention_head_dim
250251

251252
# 1. Patch Embedding
253+
interpolation_scale = interpolation_scale if interpolation_scale is not None else max(sample_size // 64, 1)
252254
self.patch_embed = PatchEmbed(
253255
height=sample_size,
254256
width=sample_size,
255257
patch_size=patch_size,
256258
in_channels=in_channels,
257259
embed_dim=inner_dim,
258-
interpolation_scale=None,
259-
pos_embed_type=None,
260+
interpolation_scale=interpolation_scale,
260261
)
261262

262263
# 2. Additional condition embeddings

src/diffusers/pipelines/pag/pipeline_pag_sana.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,13 +59,13 @@
5959
>>> from diffusers import SanaPAGPipeline
6060
6161
>>> pipe = SanaPAGPipeline.from_pretrained(
62-
... "Efficient-Large-Model/Sana_1600M_1024px_diffusers",
62+
... "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers",
6363
... pag_applied_layers=["transformer_blocks.8"],
6464
... torch_dtype=torch.float32,
6565
... )
6666
>>> pipe.to("cuda")
6767
>>> pipe.text_encoder.to(torch.bfloat16)
68-
>>> pipe.transformer = pipe.transformer.to(torch.float16)
68+
>>> pipe.transformer = pipe.transformer.to(torch.bfloat16)
6969
7070
>>> image = pipe(prompt='a cyberpunk cat with a neon sign that says "Sana"')[0]
7171
>>> image[0].save("output.png")

src/diffusers/pipelines/sana/pipeline_sana.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,11 +62,11 @@
6262
>>> from diffusers import SanaPipeline
6363
6464
>>> pipe = SanaPipeline.from_pretrained(
65-
... "Efficient-Large-Model/Sana_1600M_1024px_diffusers", torch_dtype=torch.float32
65+
... "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers", torch_dtype=torch.float32
6666
... )
6767
>>> pipe.to("cuda")
6868
>>> pipe.text_encoder.to(torch.bfloat16)
69-
>>> pipe.transformer = pipe.transformer.to(torch.float16)
69+
>>> pipe.transformer = pipe.transformer.to(torch.bfloat16)
7070
7171
>>> image = pipe(prompt='a cyberpunk cat with a neon sign that says "Sana"')[0]
7272
>>> image[0].save("output.png")

tests/models/test_modeling_common.py

Lines changed: 78 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,14 @@
2222
import unittest
2323
import unittest.mock as mock
2424
import uuid
25-
from typing import Dict, List, Tuple
25+
from collections import defaultdict
26+
from typing import Dict, List, Optional, Tuple, Union
2627

2728
import numpy as np
2829
import requests_mock
2930
import torch
30-
from accelerate.utils import compute_module_sizes
31+
import torch.nn as nn
32+
from accelerate.utils.modeling import _get_proper_dtype, dtype_byte_size
3133
from huggingface_hub import ModelCard, delete_repo, snapshot_download
3234
from huggingface_hub.utils import is_jinja_available
3335
from parameterized import parameterized
@@ -113,6 +115,72 @@ def _test_from_save_pretrained_dynamo(in_queue, out_queue, timeout):
113115
out_queue.join()
114116

115117

118+
def named_persistent_module_tensors(
    module: nn.Module,
    recurse: bool = False,
):
    """
    A helper function that gathers all the tensors (parameters + persistent buffers) of a given module.

    Buffers listed in their owning module's `_non_persistent_buffers_set` (i.e. registered with
    `persistent=False`) are skipped.

    Args:
        module (`torch.nn.Module`):
            The module we want the tensors on.
        recurse (`bool`, *optional*, defaults to `False`):
            Whether or not to go look in every submodule or just return the direct parameters and buffers.

    Yields:
        `(name, tensor)` tuples: first every entry of `module.named_parameters(recurse=recurse)`, then every
        persistent entry of `module.named_buffers(recurse=recurse)`, with names relative to `module`.
    """
    yield from module.named_parameters(recurse=recurse)

    for qualified_name, buffer in module.named_buffers(recurse=recurse):
        # `named_buffers` yields dotted names relative to `module`, but the persistence flag is stored on the
        # submodule that directly owns the buffer, under the buffer's local (undotted) name.
        owner_path, _, local_name = qualified_name.rpartition(".")
        owner = module.get_submodule(owner_path) if owner_path else module
        if local_name not in owner._non_persistent_buffers_set:
            yield qualified_name, buffer
144+
145+
146+
def compute_module_persistent_sizes(
    model: nn.Module,
    dtype: Optional[Union[str, torch.dtype]] = None,
    special_dtypes: Optional[Dict[str, Union[str, torch.dtype]]] = None,
):
    """
    Compute the size of each submodule of a given model (parameters + persistent buffers).

    Args:
        model (`torch.nn.Module`):
            The model whose tensors are measured. Non-persistent buffers are excluded because the tensors come
            from `named_persistent_module_tensors`.
        dtype (`str` or `torch.dtype`, *optional*):
            If given, tensors are sized as if stored in this dtype, except integer/bool tensors, which keep
            their own dtype size, and tensors whose own dtype is already smaller. The value is normalized
            through `_get_proper_dtype`.
        special_dtypes (`Dict[str, Union[str, torch.dtype]]`, *optional*):
            Per-tensor overrides keyed by the full tensor name; these take precedence over `dtype`.

    Returns:
        `defaultdict(int)`: maps every dotted name prefix (module names as well as the full tensor names) to the
        accumulated size computed from `numel() * dtype_byte_size(...)`. The empty string `""` holds the total
        for the whole model.
    """
    if dtype is not None:
        dtype = _get_proper_dtype(dtype)
        dtype_size = dtype_byte_size(dtype)
    if special_dtypes is not None:
        special_dtypes = {key: _get_proper_dtype(dtyp) for key, dtyp in special_dtypes.items()}
        special_dtypes_size = {key: dtype_byte_size(dtyp) for key, dtyp in special_dtypes.items()}
    module_sizes = defaultdict(int)

    for name, tensor in named_persistent_module_tensors(model, recurse=True):
        if special_dtypes is not None and name in special_dtypes:
            size = tensor.numel() * special_dtypes_size[name]
        elif dtype is None:
            size = tensor.numel() * dtype_byte_size(tensor.dtype)
        elif str(tensor.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
            # According to the code in set_module_tensor_to_device, these types won't be converted
            # so use their original size here
            size = tensor.numel() * dtype_byte_size(tensor.dtype)
        else:
            size = tensor.numel() * min(dtype_size, dtype_byte_size(tensor.dtype))
        # Credit the tensor to itself and to every ancestor: "a.b.weight" adds to "", "a", "a.b", "a.b.weight".
        name_parts = name.split(".")
        for idx in range(len(name_parts) + 1):
            module_sizes[".".join(name_parts[:idx])] += size

    return module_sizes
182+
183+
116184
class ModelUtilsTest(unittest.TestCase):
117185
def tearDown(self):
118186
super().tearDown()
@@ -1012,7 +1080,7 @@ def test_cpu_offload(self):
10121080
torch.manual_seed(0)
10131081
base_output = model(**inputs_dict)
10141082

1015-
model_size = compute_module_sizes(model)[""]
1083+
model_size = compute_module_persistent_sizes(model)[""]
10161084
# We test several splits of sizes to make sure it works.
10171085
max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents[1:]]
10181086
with tempfile.TemporaryDirectory() as tmp_dir:
@@ -1042,7 +1110,7 @@ def test_disk_offload_without_safetensors(self):
10421110
torch.manual_seed(0)
10431111
base_output = model(**inputs_dict)
10441112

1045-
model_size = compute_module_sizes(model)[""]
1113+
model_size = compute_module_persistent_sizes(model)[""]
10461114
with tempfile.TemporaryDirectory() as tmp_dir:
10471115
model.cpu().save_pretrained(tmp_dir, safe_serialization=False)
10481116

@@ -1076,7 +1144,7 @@ def test_disk_offload_with_safetensors(self):
10761144
torch.manual_seed(0)
10771145
base_output = model(**inputs_dict)
10781146

1079-
model_size = compute_module_sizes(model)[""]
1147+
model_size = compute_module_persistent_sizes(model)[""]
10801148
with tempfile.TemporaryDirectory() as tmp_dir:
10811149
model.cpu().save_pretrained(tmp_dir)
10821150

@@ -1104,7 +1172,7 @@ def test_model_parallelism(self):
11041172
torch.manual_seed(0)
11051173
base_output = model(**inputs_dict)
11061174

1107-
model_size = compute_module_sizes(model)[""]
1175+
model_size = compute_module_persistent_sizes(model)[""]
11081176
# We test several splits of sizes to make sure it works.
11091177
max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents[1:]]
11101178
with tempfile.TemporaryDirectory() as tmp_dir:
@@ -1132,7 +1200,7 @@ def test_sharded_checkpoints(self):
11321200

11331201
base_output = model(**inputs_dict)
11341202

1135-
model_size = compute_module_sizes(model)[""]
1203+
model_size = compute_module_persistent_sizes(model)[""]
11361204
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small.
11371205
with tempfile.TemporaryDirectory() as tmp_dir:
11381206
model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB")
@@ -1164,7 +1232,7 @@ def test_sharded_checkpoints_with_variant(self):
11641232

11651233
base_output = model(**inputs_dict)
11661234

1167-
model_size = compute_module_sizes(model)[""]
1235+
model_size = compute_module_persistent_sizes(model)[""]
11681236
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small.
11691237
variant = "fp16"
11701238
with tempfile.TemporaryDirectory() as tmp_dir:
@@ -1204,7 +1272,7 @@ def test_sharded_checkpoints_device_map(self):
12041272
torch.manual_seed(0)
12051273
base_output = model(**inputs_dict)
12061274

1207-
model_size = compute_module_sizes(model)[""]
1275+
model_size = compute_module_persistent_sizes(model)[""]
12081276
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small.
12091277
with tempfile.TemporaryDirectory() as tmp_dir:
12101278
model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB")
@@ -1233,7 +1301,7 @@ def test_variant_sharded_ckpt_right_format(self):
12331301
config, _ = self.prepare_init_args_and_inputs_for_common()
12341302
model = self.model_class(**config).eval()
12351303

1236-
model_size = compute_module_sizes(model)[""]
1304+
model_size = compute_module_persistent_sizes(model)[""]
12371305
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small.
12381306
variant = "fp16"
12391307
with tempfile.TemporaryDirectory() as tmp_dir:

0 commit comments

Comments
 (0)