
Commit f698524

refactor; add docs; add tests; update conversion script
1 parent 0f9daec commit f698524

File tree

12 files changed: +629, -476 lines


docs/source/en/_toctree.yml

Lines changed: 4 additions & 0 deletions
@@ -282,6 +282,8 @@
       title: PriorTransformer
     - local: api/models/sd3_transformer2d
       title: SD3Transformer2DModel
+    - local: api/models/sana_transformer2d
+      title: SanaTransformer2DModel
     - local: api/models/stable_audio_transformer
       title: StableAudioDiTModel
     - local: api/models/transformer2d
@@ -428,6 +430,8 @@
       title: PixArt-α
     - local: api/pipelines/pixart_sigma
       title: PixArt-Σ
+    - local: api/pipelines/sana
+      title: Sana
     - local: api/pipelines/self_attention_guidance
       title: Self-Attention Guidance
     - local: api/pipelines/semantic_stable_diffusion
docs/source/en/api/models/sana_transformer2d.md

Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@
+<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
+the License. You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+specific language governing permissions and limitations under the License. -->
+
+# SanaTransformer2DModel
+
+A Diffusion Transformer model for 2D data, introduced in [SANA: Efficient High-Resolution Image Synthesis with Linear Diffusion Transformers](https://huggingface.co/papers/2410.10629) from NVIDIA and MIT HAN Lab, by Enze Xie, Junsong Chen, Junyu Chen, Han Cai, Haotian Tang, Yujun Lin, Zhekai Zhang, Muyang Li, Ligeng Zhu, Yao Lu, Song Han.
+
+The abstract from the paper is:
+
+*We introduce Sana, a text-to-image framework that can efficiently generate images up to 4096×4096 resolution. Sana can synthesize high-resolution, high-quality images with strong text-image alignment at a remarkably fast speed, deployable on laptop GPU. Core designs include: (1) Deep compression autoencoder: unlike traditional AEs, which compress images only 8×, we trained an AE that can compress images 32×, effectively reducing the number of latent tokens. (2) Linear DiT: we replace all vanilla attention in DiT with linear attention, which is more efficient at high resolutions without sacrificing quality. (3) Decoder-only text encoder: we replaced T5 with modern decoder-only small LLM as the text encoder and designed complex human instruction with in-context learning to enhance the image-text alignment. (4) Efficient training and sampling: we propose Flow-DPM-Solver to reduce sampling steps, with efficient caption labeling and selection to accelerate convergence. As a result, Sana-0.6B is very competitive with modern giant diffusion model (e.g. Flux-12B), being 20 times smaller and 100+ times faster in measured throughput. Moreover, Sana-0.6B can be deployed on a 16GB laptop GPU, taking less than 1 second to generate a 1024×1024 resolution image. Sana enables content creation at low cost. Code and model will be publicly released.*
+
+The model can be loaded with the following code snippet.
+
+```python
+TODO(aryan)
+```
+
+## SanaTransformer2DModel
+
+[[autodoc]] SanaTransformer2DModel
+
+## Transformer2DModelOutput
+
+[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
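
The `TODO(aryan)` snippet above is a placeholder in this commit. A minimal sketch of what loading just the transformer might look like, assuming the class is exported from the top-level `diffusers` namespace and the checkpoint was written by the conversion script updated below (the path is a placeholder, not a released repository):

```python
# Hedged sketch, not part of this commit: load the converted SANA transformer alone.
# "path/to/converted/checkpoint" is a placeholder for the conversion script's --dump_path.
import torch
from diffusers import SanaTransformer2DModel

transformer = SanaTransformer2DModel.from_pretrained(
    "path/to/converted/checkpoint",
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)
```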
docs/source/en/api/pipelines/sana.md

Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
+<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License. -->
+
+# SanaPipeline
+
+[SANA: Efficient High-Resolution Image Synthesis with Linear Diffusion Transformers](https://huggingface.co/papers/2410.10629) from NVIDIA and MIT HAN Lab, by Enze Xie, Junsong Chen, Junyu Chen, Han Cai, Haotian Tang, Yujun Lin, Zhekai Zhang, Muyang Li, Ligeng Zhu, Yao Lu, Song Han.
+
+The abstract from the paper is:
+
+*We introduce Sana, a text-to-image framework that can efficiently generate images up to 4096×4096 resolution. Sana can synthesize high-resolution, high-quality images with strong text-image alignment at a remarkably fast speed, deployable on laptop GPU. Core designs include: (1) Deep compression autoencoder: unlike traditional AEs, which compress images only 8×, we trained an AE that can compress images 32×, effectively reducing the number of latent tokens. (2) Linear DiT: we replace all vanilla attention in DiT with linear attention, which is more efficient at high resolutions without sacrificing quality. (3) Decoder-only text encoder: we replaced T5 with modern decoder-only small LLM as the text encoder and designed complex human instruction with in-context learning to enhance the image-text alignment. (4) Efficient training and sampling: we propose Flow-DPM-Solver to reduce sampling steps, with efficient caption labeling and selection to accelerate convergence. As a result, Sana-0.6B is very competitive with modern giant diffusion model (e.g. Flux-12B), being 20 times smaller and 100+ times faster in measured throughput. Moreover, Sana-0.6B can be deployed on a 16GB laptop GPU, taking less than 1 second to generate a 1024×1024 resolution image. Sana enables content creation at low cost. Code and model will be publicly released.*
+
+<Tip>
+
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
+
+</Tip>
+
+This pipeline was contributed by [lawrence-cj](https://github.com/lawrence-cj). The original codebase can be found [here](https://github.com/NVlabs/Sana). The original weights can be found under [hf.co/Efficient-Large-Model](https://huggingface.co/Efficient-Large-Model).
+
+## SanaPipeline
+
+[[autodoc]] SanaPipeline
+  - all
+  - __call__
+
+## SanaPipelineOutput
+
+[[autodoc]] pipelines.sana.pipeline_output.SanaPipelineOutput
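
A short usage sketch for the pipeline, mirroring the smoke test this commit removes from `scripts/convert_sana_to_diffusers.py` below (the checkpoint path is a placeholder for a converted or published checkpoint):

```python
# Hedged sketch based on the smoke test removed from the conversion script.
import torch
from diffusers import SanaPipeline

pipe = SanaPipeline.from_pretrained("path/to/converted/checkpoint", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Same call the removed smoke test used.
image = pipe(
    "a dog",
    height=1024,
    width=1024,
    guidance_scale=5.0,
)[0]
image[0].save("sana.png")
```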

scripts/convert_sana_to_diffusers.py

Lines changed: 23 additions & 33 deletions
@@ -38,6 +38,7 @@
 def main(args):
     ckpt_id = ckpt_ids[0]
     cache_dir_path = os.path.expanduser("~/.cache/huggingface/hub")
+
     if args.orig_ckpt_path is None:
         snapshot_download(
             repo_id=ckpt_id,
@@ -52,6 +53,7 @@ def main(args):
         )
     else:
         file_path = args.orig_ckpt_path
+
     all_state_dict = torch.load(file_path, weights_only=True)
     state_dict = all_state_dict.pop("state_dict")
     converted_state_dict = {}
@@ -96,8 +98,8 @@ def main(args):
         converted_state_dict[f"transformer_blocks.{depth}.scale_shift_table"] = state_dict.pop(
             f"blocks.{depth}.scale_shift_table"
         )
+
         # Linear Attention is all you need 🤘
-
         # Self attention.
         q, k, v = torch.chunk(state_dict.pop(f"blocks.{depth}.attn.qkv.weight"), 3, dim=0)
         converted_state_dict[f"transformer_blocks.{depth}.attn1.to_q.weight"] = q
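
The hunk above splits the original checkpoint's fused `qkv` projection into the separate `to_q`/`to_k`/`to_v` weights that diffusers expects. A standalone sketch of that split (the sizes are illustrative, not read from the checkpoint):

```python
# Illustrative only: splitting a fused QKV projection weight row-wise.
import torch

inner_dim, hidden_dim = 1152, 1152  # example sizes
fused_qkv_weight = torch.randn(3 * inner_dim, hidden_dim)

q, k, v = torch.chunk(fused_qkv_weight, 3, dim=0)
assert q.shape == k.shape == v.shape == (inner_dim, hidden_dim)
```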
@@ -156,27 +158,20 @@ def main(args):
     # Transformer
     with CTX():
         transformer = SanaTransformer2DModel(
+            in_channels=32,
+            out_channels=32,
             num_attention_heads=model_kwargs[args.model_type]["num_attention_heads"],
             attention_head_dim=model_kwargs[args.model_type]["attention_head_dim"],
+            num_layers=model_kwargs[args.model_type]["num_layers"],
             num_cross_attention_heads=model_kwargs[args.model_type]["num_cross_attention_heads"],
             cross_attention_head_dim=model_kwargs[args.model_type]["cross_attention_head_dim"],
-            in_channels=32,
-            out_channels=32,
-            num_layers=model_kwargs[args.model_type]["num_layers"],
             cross_attention_dim=model_kwargs[args.model_type]["cross_attention_dim"],
             attention_bias=False,
             sample_size=32,
             patch_size=1,
-            upcast_attention=False,
-            norm_type="ada_norm_single",
             norm_elementwise_affine=False,
             norm_eps=1e-6,
-            use_additional_conditions=False,
             caption_channels=2304,
-            use_caption_norm=True,
-            caption_norm_scale_factor=0.1,
-            attention_type="default",
-            use_pe=False,
             expand_ratio=2.5,
         )
     if is_accelerate_available():
@@ -203,24 +198,17 @@ def main(args):
                 attrs=["bold"],
             )
         )
-        transformer.to(weight_dtype).save_pretrained(os.path.join(args.dump_path, "transformer"))
+        transformer.save_pretrained(os.path.join(args.dump_path, "transformer"), safe_serialization=True, max_shard_size="5GB", variant=variant)
     else:
         print(colored(f"Saving the whole SanaPipeline containing {args.model_type}", "green", attrs=["bold"]))
        # VAE
-        ae = AutoencoderDC.from_pretrained(
-            "mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers",
-            torch_dtype=torch.bfloat16,
-        ).to(device)
+        ae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers",)

        # Text Encoder
         text_encoder_model_path = "google/gemma-2-2b-it"
         tokenizer = AutoTokenizer.from_pretrained(text_encoder_model_path)
         tokenizer.padding_side = "right"
-        text_encoder = (
-            AutoModelForCausalLM.from_pretrained(text_encoder_model_path, torch_dtype=torch.bfloat16)
-            .get_decoder()
-            .to(device)
-        )
+        text_encoder = AutoModelForCausalLM.from_pretrained(text_encoder_model_path).get_decoder()

        # Scheduler
         if args.scheduler_type == "flow-dpm_solver":
@@ -234,27 +222,27 @@ def main(args):
         else:
             raise ValueError(f"Scheduler type {args.scheduler_type} is not supported")

-        # transformer
-        transformer.to(device).to(weight_dtype)
-
         pipe = SanaPipeline(
             tokenizer=tokenizer,
             text_encoder=text_encoder,
             transformer=transformer,
             vae=ae,
             scheduler=scheduler,
         )
+        pipe.save_pretrained(args.dump_path, safe_serialization=True, max_shard_size="5GB", variant=variant)

-        image = pipe(
-            "a dog",
-            height=1024,
-            width=1024,
-            guidance_scale=5.0,
-        )[0]

-        image[0].save("sana.png")
+DTYPE_MAPPING = {
+    "fp32": torch.float32,
+    "fp16": torch.float16,
+    "bf16": torch.bfloat16,
+}

-        pipe.save_pretrained(args.dump_path)
+VARIANT_MAPPING = {
+    "fp32": None,
+    "fp16": "fp16",
+    "bf16": "bf16",
+}


 if __name__ == "__main__":
@@ -279,6 +267,7 @@ def main(args):
     )
     parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.")
     parser.add_argument("--save_full_pipeline", action="store_true", help="save all the pipelien elemets in one.")
+    parser.add_argument("--dtype", default="fp32", type=str, choices=["fp32", "fp16", "bf16"], help="Weight dtype.")

     args = parser.parse_args()

@@ -302,6 +291,7 @@ def main(args):
     }

     device = "cuda" if torch.cuda.is_available() else "cpu"
-    weight_dtype = torch.float16
+    weight_dtype = DTYPE_MAPPING[args.dtype]
+    variant = VARIANT_MAPPING[args.dtype]

     main(args)
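
Because the script can now write weights with an `fp16`/`bf16` variant, loading a converted checkpoint later needs the matching `variant` argument. A hedged sketch of that round trip (the path is a placeholder for the `--dump_path` used during conversion):

```python
# Hedged sketch: loading a checkpoint that was converted with `--dtype bf16`.
import torch
from diffusers import SanaPipeline

pipe = SanaPipeline.from_pretrained(
    "path/to/converted/checkpoint",  # the --dump_path used during conversion
    variant="bf16",
    torch_dtype=torch.bfloat16,
)
```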

src/diffusers/models/attention_processor.py

Lines changed: 8 additions & 38 deletions
@@ -5358,77 +5358,47 @@ def __call__(
         hidden_states: torch.Tensor,
         encoder_hidden_states: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
-        temb: Optional[torch.Tensor] = None,
-        *args,
-        **kwargs,
     ) -> torch.Tensor:
-        if len(args) > 0 or kwargs.get("scale", None) is not None:
-            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
-            deprecate("scale", "1.0.0", deprecation_message)
-
-        residual = hidden_states
-        if attn.spatial_norm is not None:
-            hidden_states = attn.spatial_norm(hidden_states, temb)
-
         input_ndim = hidden_states.ndim
+        original_dtype = hidden_states.dtype

-        if input_ndim == 4:
-            batch_size, channel, height, width = hidden_states.shape
-            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
-        batch_size, sequence_length, _ = (
+        batch_size, _, _ = (
             hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
         )

-        query = attn.to_q(hidden_states)
-
         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
-        elif attn.norm_cross:
-            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

+        query = attn.to_q(hidden_states)
         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)

         inner_dim = key.shape[-1]
         head_dim = inner_dim // attn.heads

-        dtype = query.dtype
-
         query = query.transpose(-1, -2).reshape(batch_size, attn.heads, head_dim, -1)
         key = key.transpose(-1, -2).reshape(batch_size, attn.heads, head_dim, -1).transpose(-1, -2)
         value = value.transpose(-1, -2).reshape(batch_size, attn.heads, head_dim, -1)

-        query = self.kernel_func(query)  # B, h, h_d, N
+        query = self.kernel_func(query)
         key = self.kernel_func(key)

-        # need torch.float
         query, key, value = query.float(), key.float(), value.float()

         value = F.pad(value, (0, 0, 0, 1), mode="constant", value=self.pad_val)
-        vk = torch.matmul(value, key)
-        hidden_states = torch.matmul(vk, query)
+        scores = torch.matmul(value, key)
+        hidden_states = torch.matmul(scores, query)

         if hidden_states.dtype in [torch.float16, torch.bfloat16]:
             hidden_states = hidden_states.float()
+
         hidden_states = hidden_states[:, :, :-1] / (hidden_states[:, :, -1:] + self.eps)
-
         hidden_states = hidden_states.view(batch_size, attn.heads * head_dim, -1).permute(0, 2, 1)
-        hidden_states = hidden_states.to(dtype)
+        hidden_states = hidden_states.to(original_dtype)

-        # linear proj
         hidden_states = attn.to_out[0](hidden_states)
-        # dropout
         hidden_states = attn.to_out[1](hidden_states)

-        if input_ndim == 4:
-            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
-        if attn.residual_connection:
-            hidden_states = hidden_states + residual
-
-        hidden_states = hidden_states / attn.rescale_output_factor
-
         if hidden_states.dtype == torch.float16:
             hidden_states = hidden_states.clip(-65504, 65504)

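The refactored processor keeps SANA's linear attention: queries and keys go through `kernel_func`, the value matrix is padded with an extra row so one pair of matmuls produces both the attention output and its normalizer, and the result is divided by that last row. A self-contained sketch of the computation, assuming ReLU as the kernel function and a pad value of 1.0 (both assumptions about the processor's configuration):

```python
# Hedged, self-contained sketch of the linear attention math used above.
import torch
import torch.nn.functional as F

batch, heads, head_dim, seq_len = 2, 8, 32, 256
eps = 1e-15  # assumed epsilon

# Tensors already in the (batch, heads, head_dim, seq_len) layout the processor uses.
query = F.relu(torch.randn(batch, heads, head_dim, seq_len))                  # kernel_func(Q)
key = F.relu(torch.randn(batch, heads, head_dim, seq_len)).transpose(-1, -2)  # kernel_func(K), transposed
value = torch.randn(batch, heads, head_dim, seq_len)

# Append a row of ones so the same matmuls also accumulate the normalizer.
value = F.pad(value, (0, 0, 0, 1), mode="constant", value=1.0)  # (B, h, head_dim + 1, N)

scores = torch.matmul(value, key)  # (B, h, head_dim + 1, head_dim)
out = torch.matmul(scores, query)  # (B, h, head_dim + 1, N)

# The last row holds the per-position normalizer; divide it out and drop it.
out = out[:, :, :-1] / (out[:, :, -1:] + eps)
print(out.shape)  # torch.Size([2, 8, 32, 256])
```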
src/diffusers/models/normalization.py

Lines changed: 0 additions & 39 deletions
@@ -590,42 +590,3 @@ def get_normalization(
     else:
         raise ValueError(f"{norm_type=} is not supported.")
     return norm
-
-
-class RMSNormScaled(nn.Module):
-    def __init__(self, dim, eps: float, elementwise_affine: bool = True, scale_factor: float = 1.0, bias: bool = False):
-        super().__init__()
-        self.weight = nn.Parameter(torch.ones(dim) * scale_factor)
-
-        self.eps = eps
-        self.elementwise_affine = elementwise_affine
-
-        if isinstance(dim, numbers.Integral):
-            dim = (dim,)
-
-        self.dim = torch.Size(dim)
-
-        self.weight = None
-        self.bias = None
-
-        if elementwise_affine:
-            self.weight = nn.Parameter(torch.ones(dim) * scale_factor)
-            if bias:
-                self.bias = nn.Parameter(torch.zeros(dim))
-
-    def forward(self, hidden_states):
-        input_dtype = hidden_states.dtype
-        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
-
-        if self.weight is not None:
-            # convert into half-precision if necessary
-            if self.weight.dtype in [torch.float16, torch.bfloat16]:
-                hidden_states = hidden_states.to(self.weight.dtype)
-            hidden_states = hidden_states * self.weight
-            if self.bias is not None:
-                hidden_states = hidden_states + self.bias
-        else:
-            hidden_states = hidden_states.to(input_dtype)
-
-        return hidden_states
