2 changes: 1 addition & 1 deletion docs/source/en/api/models/sana_transformer2d.md
@@ -22,7 +22,7 @@ The model can be loaded with the following code snippet.
```python
import torch
from diffusers import SanaTransformer2DModel

-transformer = SanaTransformer2DModel.from_pretrained("Efficient-Large-Model/Sana_1600M_1024px_diffusers", subfolder="transformer", torch_dtype=torch.float16)
+transformer = SanaTransformer2DModel.from_pretrained("Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers", subfolder="transformer", torch_dtype=torch.bfloat16)
```

## SanaTransformer2DModel
2 changes: 1 addition & 1 deletion docs/source/en/api/pipelines/sana.md
@@ -32,9 +32,9 @@ Available models:

| Model | Recommended dtype |
|:-----:|:-----------------:|
+| [`Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers) | `torch.bfloat16` |
| [`Efficient-Large-Model/Sana_1600M_1024px_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_diffusers) | `torch.float16` |
| [`Efficient-Large-Model/Sana_1600M_1024px_MultiLing_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_MultiLing_diffusers) | `torch.float16` |
-| [`Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers) | `torch.bfloat16` |
| [`Efficient-Large-Model/Sana_1600M_512px_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_512px_diffusers) | `torch.float16` |
| [`Efficient-Large-Model/Sana_1600M_512px_MultiLing_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_512px_MultiLing_diffusers) | `torch.float16` |
| [`Efficient-Large-Model/Sana_600M_1024px_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_600M_1024px_diffusers) | `torch.float16` |
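As a quick illustration of the table (a minimal sketch, not part of this diff), each checkpoint is loaded with its recommended dtype via `torch_dtype`:

```python
import torch
from diffusers import SanaPipeline

# Pair the checkpoint with the recommended dtype from the table above.
pipe = SanaPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers", torch_dtype=torch.bfloat16
).to("cuda")
```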
6 changes: 6 additions & 0 deletions scripts/convert_sana_to_diffusers.py
@@ -88,13 +88,18 @@ def main(args):
    # y norm
    converted_state_dict["caption_norm.weight"] = state_dict.pop("attention_y_norm.weight")

+    # scheduler
+    flow_shift = 3.0
+
    # model config
    if args.model_type == "SanaMS_1600M_P1_D20":
        layer_num = 20
    elif args.model_type == "SanaMS_600M_P1_D28":
        layer_num = 28
    else:
        raise ValueError(f"{args.model_type} is not supported.")
+    # Positional embedding interpolation scale.
+    interpolation_scale = {512: None, 1024: None, 2048: 1.0}

    for depth in range(layer_num):
        # Transformer blocks.
@@ -176,6 +181,7 @@ def main(args):
        patch_size=1,
        norm_elementwise_affine=False,
        norm_eps=1e-6,
+        interpolation_scale=interpolation_scale[args.image_size],
    )

    if is_accelerate_available():
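The new `flow_shift = 3.0` is consumed further down in the script, outside this hunk. As a hedged sketch of the likely wiring (the scheduler class and its `shift` argument are assumptions based on diffusers' flow-matching schedulers; the diff does not show the actual consumer):

```python
# Hypothetical use of flow_shift; not shown in the hunk above.
from diffusers import FlowMatchEulerDiscreteScheduler

flow_shift = 3.0
scheduler = FlowMatchEulerDiscreteScheduler(shift=flow_shift)
```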
5 changes: 3 additions & 2 deletions src/diffusers/models/transformers/sana_transformer.py
@@ -242,21 +242,22 @@ def __init__(
        patch_size: int = 1,
        norm_elementwise_affine: bool = False,
        norm_eps: float = 1e-6,
+        interpolation_scale: Optional[int] = None,
    ) -> None:
        super().__init__()

        out_channels = out_channels or in_channels
        inner_dim = num_attention_heads * attention_head_dim

        # 1. Patch Embedding
+        interpolation_scale = interpolation_scale if interpolation_scale is not None else max(sample_size // 64, 1)
        self.patch_embed = PatchEmbed(
            height=sample_size,
            width=sample_size,
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=inner_dim,
-            interpolation_scale=None,
-            pos_embed_type=None,
+            interpolation_scale=interpolation_scale,
        )

        # 2. Additional condition embeddings
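To make the two `interpolation_scale` changes above concrete, here is a small self-contained sketch of how the converter's per-resolution map and the model's new fallback interact (the `sample_size` values below are illustrative assumptions, not taken from the actual configs):

```python
# From the converter: explicit overrides per image size; None defers to the model's fallback.
interpolation_scale_map = {512: None, 1024: None, 2048: 1.0}

def resolve_interpolation_scale(configured, sample_size):
    # Mirrors the fallback added in SanaTransformer2DModel.__init__.
    return configured if configured is not None else max(sample_size // 64, 1)

# Hypothetical latent sample sizes for each image size.
for image_size, sample_size in [(512, 16), (1024, 32), (2048, 64)]:
    print(image_size, resolve_interpolation_scale(interpolation_scale_map[image_size], sample_size))
```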
4 changes: 2 additions & 2 deletions src/diffusers/pipelines/pag/pipeline_pag_sana.py
@@ -59,13 +59,13 @@
>>> from diffusers import SanaPAGPipeline

>>> pipe = SanaPAGPipeline.from_pretrained(
-...     "Efficient-Large-Model/Sana_1600M_1024px_diffusers",
+...     "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers",
...     pag_applied_layers=["transformer_blocks.8"],
...     torch_dtype=torch.float32,
... )
>>> pipe.to("cuda")
>>> pipe.text_encoder.to(torch.bfloat16)
->>> pipe.transformer = pipe.transformer.to(torch.float16)
+>>> pipe.transformer = pipe.transformer.to(torch.bfloat16)

>>> image = pipe(prompt='a cyberpunk cat with a neon sign that says "Sana"')[0]
>>> image[0].save("output.png")
4 changes: 2 additions & 2 deletions src/diffusers/pipelines/sana/pipeline_sana.py
@@ -62,11 +62,11 @@
>>> from diffusers import SanaPipeline

>>> pipe = SanaPipeline.from_pretrained(
-...     "Efficient-Large-Model/Sana_1600M_1024px_diffusers", torch_dtype=torch.float32
+...     "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers", torch_dtype=torch.float32
... )
>>> pipe.to("cuda")
>>> pipe.text_encoder.to(torch.bfloat16)
->>> pipe.transformer = pipe.transformer.to(torch.float16)
+>>> pipe.transformer = pipe.transformer.to(torch.bfloat16)

>>> image = pipe(prompt='a cyberpunk cat with a neon sign that says "Sana"')[0]
>>> image[0].save("output.png")
88 changes: 78 additions & 10 deletions tests/models/test_modeling_common.py
@@ -22,12 +22,14 @@
import unittest
import unittest.mock as mock
import uuid
-from typing import Dict, List, Tuple
+from collections import defaultdict
+from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import requests_mock
import torch
-from accelerate.utils import compute_module_sizes
+import torch.nn as nn
+from accelerate.utils.modeling import _get_proper_dtype, dtype_byte_size
from huggingface_hub import ModelCard, delete_repo, snapshot_download
from huggingface_hub.utils import is_jinja_available
from parameterized import parameterized
@@ -113,6 +115,72 @@ def _test_from_save_pretrained_dynamo(in_queue, out_queue, timeout):
    out_queue.join()


+def named_persistent_module_tensors(
+    module: nn.Module,
+    recurse: bool = False,
+):
+    """
+    A helper function that gathers all the tensors (parameters + persistent buffers) of a given module.
+
+    Args:
+        module (`torch.nn.Module`):
+            The module we want the tensors of.
+        recurse (`bool`, *optional*, defaults to `False`):
+            Whether or not to look in every submodule or just return the direct parameters and buffers.
+    """
+    yield from module.named_parameters(recurse=recurse)
+
+    for named_buffer in module.named_buffers(recurse=recurse):
+        name, _ = named_buffer
+        # Find the buffer's parent module by splitting on dots and traversing the model.
+        parent = module
+        if "." in name:
+            parent_name = name.rsplit(".", 1)[0]
+            for part in parent_name.split("."):
+                parent = getattr(parent, part)
+            name = name.split(".")[-1]
+        if name not in parent._non_persistent_buffers_set:
+            yield named_buffer
+
+
+def compute_module_persistent_sizes(
+    model: nn.Module,
+    dtype: Optional[Union[str, torch.dtype]] = None,
+    special_dtypes: Optional[Dict[str, Union[str, torch.dtype]]] = None,
+):
+    """
+    Compute the size of each submodule of a given model (parameters + persistent buffers).
+    """
+    if dtype is not None:
+        dtype = _get_proper_dtype(dtype)
+        dtype_size = dtype_byte_size(dtype)
+    if special_dtypes is not None:
+        special_dtypes = {key: _get_proper_dtype(dtyp) for key, dtyp in special_dtypes.items()}
+        special_dtypes_size = {key: dtype_byte_size(dtyp) for key, dtyp in special_dtypes.items()}
+    module_sizes = defaultdict(int)
+
+    module_list = named_persistent_module_tensors(model, recurse=True)
+
+    for name, tensor in module_list:
+        if special_dtypes is not None and name in special_dtypes:
+            size = tensor.numel() * special_dtypes_size[name]
+        elif dtype is None:
+            size = tensor.numel() * dtype_byte_size(tensor.dtype)
+        elif str(tensor.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
+            # According to the code in set_module_tensor_to_device, these types won't be converted,
+            # so use their original size here.
+            size = tensor.numel() * dtype_byte_size(tensor.dtype)
+        else:
+            size = tensor.numel() * min(dtype_size, dtype_byte_size(tensor.dtype))
+        # Credit the tensor's size to every prefix of its dotted path, including the "" root.
+        name_parts = name.split(".")
+        for idx in range(len(name_parts) + 1):
+            module_sizes[".".join(name_parts[:idx])] += size
+
+    return module_sizes
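A brief usage sketch (hedged; `model` stands for any `nn.Module`, and the submodule path is hypothetical): the returned mapping is keyed by dotted module path, with the empty string covering the whole model, which is why the tests below index the result with `[""]`.

```python
# Usage sketch only; `model` is any nn.Module under test.
sizes = compute_module_persistent_sizes(model)
total_bytes = sizes[""]  # whole-model size in bytes
block_bytes = sizes.get("transformer_blocks.0", 0)  # hypothetical submodule path
```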


class ModelUtilsTest(unittest.TestCase):
    def tearDown(self):
        super().tearDown()
@@ -1012,7 +1080,7 @@ def test_cpu_offload(self):
        torch.manual_seed(0)
        base_output = model(**inputs_dict)

-        model_size = compute_module_sizes(model)[""]
+        model_size = compute_module_persistent_sizes(model)[""]
        # We test several splits of sizes to make sure it works.
        max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents[1:]]
        with tempfile.TemporaryDirectory() as tmp_dir:
@@ -1042,7 +1110,7 @@ def test_disk_offload_without_safetensors(self):
        torch.manual_seed(0)
        base_output = model(**inputs_dict)

-        model_size = compute_module_sizes(model)[""]
+        model_size = compute_module_persistent_sizes(model)[""]
Collaborator: We estimate the checkpoint size here, so it should not include non-persistent buffers. Technically we could pass `remove_non_persistent` to `named_module_tensors` here, but that would require a PR to accelerate, so we are adding a function here for now.
https://github.com/huggingface/accelerate/blob/200c9eb7833cfa505907f6f224ebf5a275aa6d92/src/accelerate/utils/modeling.py#L724

Contributor Author: Is there anything I need to do here? I'm not familiar with what you said above. 🤔
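To illustrate the collaborator's point, here is a minimal sketch (plain PyTorch, unrelated to the test models in this file): non-persistent buffers never appear in `state_dict()`, so they should not count toward the estimated checkpoint size.

```python
import torch
import torch.nn as nn

class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(4))
        self.register_buffer("stats", torch.zeros(4))  # persistent: serialized with the checkpoint
        self.register_buffer("scratch", torch.zeros(4), persistent=False)  # skipped at save time

print(sorted(Demo().state_dict().keys()))  # ['stats', 'weight'] -- 'scratch' is absent
```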

        with tempfile.TemporaryDirectory() as tmp_dir:
            model.cpu().save_pretrained(tmp_dir, safe_serialization=False)

@@ -1076,7 +1144,7 @@ def test_disk_offload_with_safetensors(self):
        torch.manual_seed(0)
        base_output = model(**inputs_dict)

-        model_size = compute_module_sizes(model)[""]
+        model_size = compute_module_persistent_sizes(model)[""]
        with tempfile.TemporaryDirectory() as tmp_dir:
            model.cpu().save_pretrained(tmp_dir)
@@ -1104,7 +1172,7 @@ def test_model_parallelism(self):
        torch.manual_seed(0)
        base_output = model(**inputs_dict)

-        model_size = compute_module_sizes(model)[""]
+        model_size = compute_module_persistent_sizes(model)[""]
        # We test several splits of sizes to make sure it works.
        max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents[1:]]
        with tempfile.TemporaryDirectory() as tmp_dir:
@@ -1132,7 +1200,7 @@ def test_sharded_checkpoints(self):

        base_output = model(**inputs_dict)

-        model_size = compute_module_sizes(model)[""]
+        model_size = compute_module_persistent_sizes(model)[""]
        max_shard_size = int((model_size * 0.75) / (2**10))  # Convert to KB as these test models are small.
        with tempfile.TemporaryDirectory() as tmp_dir:
            model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB")
@@ -1164,7 +1232,7 @@ def test_sharded_checkpoints_with_variant(self):

        base_output = model(**inputs_dict)

-        model_size = compute_module_sizes(model)[""]
+        model_size = compute_module_persistent_sizes(model)[""]
        max_shard_size = int((model_size * 0.75) / (2**10))  # Convert to KB as these test models are small.
        variant = "fp16"
        with tempfile.TemporaryDirectory() as tmp_dir:
@@ -1204,7 +1272,7 @@ def test_sharded_checkpoints_device_map(self):
        torch.manual_seed(0)
        base_output = model(**inputs_dict)

-        model_size = compute_module_sizes(model)[""]
+        model_size = compute_module_persistent_sizes(model)[""]
        max_shard_size = int((model_size * 0.75) / (2**10))  # Convert to KB as these test models are small.
        with tempfile.TemporaryDirectory() as tmp_dir:
            model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB")
@@ -1233,7 +1301,7 @@ def test_variant_sharded_ckpt_right_format(self):
        config, _ = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**config).eval()

-        model_size = compute_module_sizes(model)[""]
+        model_size = compute_module_persistent_sizes(model)[""]
        max_shard_size = int((model_size * 0.75) / (2**10))  # Convert to KB as these test models are small.
        variant = "fp16"
        with tempfile.TemporaryDirectory() as tmp_dir: