Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/en/api/models/sana_transformer2d.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ The model can be loaded with the following code snippet.
```python
from diffusers import SanaTransformer2DModel

transformer = SanaTransformer2DModel.from_pretrained("Efficient-Large-Model/Sana_1600M_1024px_diffusers", subfolder="transformer", torch_dtype=torch.float16)
transformer = SanaTransformer2DModel.from_pretrained("Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers", subfolder="transformer", torch_dtype=torch.bfloat16)
```

## SanaTransformer2DModel
Expand Down
2 changes: 1 addition & 1 deletion docs/source/en/api/pipelines/sana.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ Available models:

| Model | Recommended dtype |
|:-----:|:-----------------:|
| [`Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers) | `torch.bfloat16` |
| [`Efficient-Large-Model/Sana_1600M_1024px_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_diffusers) | `torch.float16` |
| [`Efficient-Large-Model/Sana_1600M_1024px_MultiLing_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_MultiLing_diffusers) | `torch.float16` |
| [`Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers) | `torch.bfloat16` |
| [`Efficient-Large-Model/Sana_1600M_512px_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_512px_diffusers) | `torch.float16` |
| [`Efficient-Large-Model/Sana_1600M_512px_MultiLing_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_512px_MultiLing_diffusers) | `torch.float16` |
| [`Efficient-Large-Model/Sana_600M_1024px_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_600M_1024px_diffusers) | `torch.float16` |
Expand Down
6 changes: 6 additions & 0 deletions scripts/convert_sana_to_diffusers.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,13 +88,18 @@ def main(args):
# y norm
converted_state_dict["caption_norm.weight"] = state_dict.pop("attention_y_norm.weight")

# scheduler
flow_shift = 3.0

# model config
if args.model_type == "SanaMS_1600M_P1_D20":
layer_num = 20
elif args.model_type == "SanaMS_600M_P1_D28":
layer_num = 28
else:
raise ValueError(f"{args.model_type} is not supported.")
# Positional embedding interpolation scale.
interpolation_scale = {512: None, 1024: None, 2048: 1.0}

for depth in range(layer_num):
# Transformer blocks.
Expand Down Expand Up @@ -176,6 +181,7 @@ def main(args):
patch_size=1,
norm_elementwise_affine=False,
norm_eps=1e-6,
interpolation_scale=interpolation_scale[args.image_size],
)

if is_accelerate_available():
Expand Down
9 changes: 7 additions & 2 deletions src/diffusers/models/transformers/sana_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,21 +242,26 @@ def __init__(
patch_size: int = 1,
norm_elementwise_affine: bool = False,
norm_eps: float = 1e-6,
interpolation_scale: Optional[int] = None,
) -> None:
super().__init__()

out_channels = out_channels or in_channels
inner_dim = num_attention_heads * attention_head_dim

# 1. Patch Embedding
interpolation_scale = (
interpolation_scale
if interpolation_scale is not None
else max(sample_size // 64, 1)
)
self.patch_embed = PatchEmbed(
height=sample_size,
width=sample_size,
patch_size=patch_size,
in_channels=in_channels,
embed_dim=inner_dim,
interpolation_scale=None,
pos_embed_type=None,
interpolation_scale=interpolation_scale,
)

# 2. Additional condition embeddings
Expand Down
4 changes: 2 additions & 2 deletions src/diffusers/pipelines/pag/pipeline_pag_sana.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,13 @@
>>> from diffusers import SanaPAGPipeline
>>> pipe = SanaPAGPipeline.from_pretrained(
... "Efficient-Large-Model/Sana_1600M_1024px_diffusers",
... "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers",
... pag_applied_layers=["transformer_blocks.8"],
... torch_dtype=torch.float32,
... )
>>> pipe.to("cuda")
>>> pipe.text_encoder.to(torch.bfloat16)
>>> pipe.transformer = pipe.transformer.to(torch.float16)
>>> pipe.transformer = pipe.transformer.to(torch.bfloat16)
>>> image = pipe(prompt='a cyberpunk cat with a neon sign that says "Sana"')[0]
>>> image[0].save("output.png")
Expand Down
4 changes: 2 additions & 2 deletions src/diffusers/pipelines/sana/pipeline_sana.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,11 @@
>>> from diffusers import SanaPipeline
>>> pipe = SanaPipeline.from_pretrained(
... "Efficient-Large-Model/Sana_1600M_1024px_diffusers", torch_dtype=torch.float32
... "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers", torch_dtype=torch.float32
... )
>>> pipe.to("cuda")
>>> pipe.text_encoder.to(torch.bfloat16)
>>> pipe.transformer = pipe.transformer.to(torch.float16)
>>> pipe.transformer = pipe.transformer.to(torch.bfloat16)
>>> image = pipe(prompt='a cyberpunk cat with a neon sign that says "Sana"')[0]
>>> image[0].save("output.png")
Expand Down
Loading