diff --git a/docs/source/en/api/pipelines/sana_sprint.md b/docs/source/en/api/pipelines/sana_sprint.md
new file mode 100644
index 000000000000..482ef3c2c99d
--- /dev/null
+++ b/docs/source/en/api/pipelines/sana_sprint.md
@@ -0,0 +1,96 @@
+
+
+# SanaSprintPipeline
+
+
+

+
+
+[SANA-Sprint: One-Step Diffusion with Continuous-Time Consistency Distillation](https://huggingface.co/papers/2503.09641) from NVIDIA and MIT HAN Lab, by Junsong Chen, Shuchen Xue, Yuyang Zhao, Jincheng Yu, Sayak Paul, Junyu Chen, Han Cai, Enze Xie, Song Han
+
+The abstract from the paper is:
+
+*This paper presents SANA-Sprint, an efficient diffusion model for ultra-fast text-to-image (T2I) generation. SANA-Sprint is built on a pre-trained foundation model and augmented with hybrid distillation, dramatically reducing inference steps from 20 to 1-4. We introduce three key innovations: (1) We propose a training-free approach that transforms a pre-trained flow-matching model for continuous-time consistency distillation (sCM), eliminating costly training from scratch and achieving high training efficiency. Our hybrid distillation strategy combines sCM with latent adversarial distillation (LADD): sCM ensures alignment with the teacher model, while LADD enhances single-step generation fidelity. (2) SANA-Sprint is a unified step-adaptive model that achieves high-quality generation in 1-4 steps, eliminating step-specific training and improving efficiency. (3) We integrate ControlNet with SANA-Sprint for real-time interactive image generation, enabling instant visual feedback for user interaction. SANA-Sprint establishes a new Pareto frontier in speed-quality tradeoffs, achieving state-of-the-art performance with 7.59 FID and 0.74 GenEval in only 1 step — outperforming FLUX-schnell (7.94 FID / 0.71 GenEval) while being 10× faster (0.1s vs 1.1s on H100). It also achieves 0.1s (T2I) and 0.25s (ControlNet) latency for 1024×1024 images on H100, and 0.31s (T2I) on an RTX 4090, showcasing its exceptional efficiency and potential for AI-powered consumer applications (AIPC). Code and pre-trained models will be open-sourced.*
+
+
+
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
+
+
+
+This pipeline was contributed by [lawrence-cj](https://github.com/lawrence-cj), [shuchen Xue](https://github.com/scxue) and [Enze Xie](https://github.com/xieenze). The original codebase can be found [here](https://github.com/NVlabs/Sana). The original weights can be found under [hf.co/Efficient-Large-Model](https://huggingface.co/Efficient-Large-Model/).
+
+Available models:
+
+| Model | Recommended dtype |
+|:-------------------------------------------------------------------------------------------------------------------------------------------:|:-----------------:|
+| [`Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers) | `torch.bfloat16` |
+| [`Efficient-Large-Model/Sana_Sprint_0.6B_1024px_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_Sprint_0.6B_1024px_diffusers) | `torch.bfloat16` |
+
+Refer to [this](https://huggingface.co/collections/Efficient-Large-Model/sana-sprint-67d6810d65235085b3b17c76) collection for more information.
+
+Note: The recommended dtype mentioned is for the transformer weights. The text encoder must stay in `torch.bfloat16` and VAE weights must stay in `torch.bfloat16` or `torch.float32` for the model to work correctly. Please refer to the inference example below to see how to load the model with the recommended dtype.
+
+
+## Quantization
+
+Quantization helps reduce the memory requirements of very large models by storing model weights in a lower precision data type. However, quantization may have varying impact on video quality depending on the video model.
+
+Refer to the [Quantization](../../quantization/overview) overview to learn more about supported quantization backends and selecting a quantization backend that supports your use case. The example below demonstrates how to load a quantized [`SanaSprintPipeline`] for inference with bitsandbytes.
+
+```py
+import torch
+from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, SanaTransformer2DModel, SanaSprintPipeline
+from transformers import BitsAndBytesConfig as BitsAndBytesConfig, AutoModel
+
+quant_config = BitsAndBytesConfig(load_in_8bit=True)
+text_encoder_8bit = AutoModel.from_pretrained(
+ "Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers",
+ subfolder="text_encoder",
+ quantization_config=quant_config,
+ torch_dtype=torch.bfloat16,
+)
+
+quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
+transformer_8bit = SanaTransformer2DModel.from_pretrained(
+ "Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers",
+ subfolder="transformer",
+ quantization_config=quant_config,
+ torch_dtype=torch.bfloat16,
+)
+
+pipeline = SanaSprintPipeline.from_pretrained(
+ "Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers",
+ text_encoder=text_encoder_8bit,
+ transformer=transformer_8bit,
+ torch_dtype=torch.bfloat16,
+ device_map="balanced",
+)
+
+prompt = "a tiny astronaut hatching from an egg on the moon"
+image = pipeline(prompt).images[0]
+image.save("sana.png")
+```
+
+## SanaSprintPipeline
+
+[[autodoc]] SanaSprintPipeline
+ - all
+ - __call__
+
+
+## SanaPipelineOutput
+
+[[autodoc]] pipelines.sana.pipeline_output.SanaPipelineOutput
diff --git a/scripts/convert_sana_to_diffusers.py b/scripts/convert_sana_to_diffusers.py
index a8bc1a51c13a..47e932ba5070 100644
--- a/scripts/convert_sana_to_diffusers.py
+++ b/scripts/convert_sana_to_diffusers.py
@@ -27,6 +27,7 @@
CTX = init_empty_weights if is_accelerate_available else nullcontext
ckpt_ids = [
+ "Efficient-Large-Model/SANA1.5_4.8B_1024px/checkpoints/SANA1.5_4.8B_1024px.pth",
"Efficient-Large-Model/Sana_1600M_4Kpx_BF16/checkpoints/Sana_1600M_4Kpx_BF16.pth",
"Efficient-Large-Model/Sana_1600M_2Kpx_BF16/checkpoints/Sana_1600M_2Kpx_BF16.pth",
"Efficient-Large-Model/Sana_1600M_1024px_MultiLing/checkpoints/Sana_1600M_1024px_MultiLing.pth",
@@ -75,7 +76,8 @@ def main(args):
converted_state_dict["caption_projection.linear_2.bias"] = state_dict.pop("y_embedder.y_proj.fc2.bias")
# Handle different time embedding structure based on model type
- if args.model_type == "SanaSprint_1600M_P1_D20":
+
+ if args.model_type in ["SanaSprint_1600M_P1_D20", "SanaSprint_600M_P1_D28"]:
# For Sana Sprint, the time embedding structure is different
converted_state_dict["time_embed.timestep_embedder.linear_1.weight"] = state_dict.pop(
"t_embedder.mlp.0.weight"
@@ -128,10 +130,18 @@ def main(args):
layer_num = 20
elif args.model_type == "SanaMS_600M_P1_D28":
layer_num = 28
+ elif args.model_type == "SanaMS_4800M_P1_D60":
+ layer_num = 60
else:
raise ValueError(f"{args.model_type} is not supported.")
# Positional embedding interpolation scale.
interpolation_scale = {512: None, 1024: None, 2048: 1.0, 4096: 2.0}
+ qk_norm = "rms_norm_across_heads" if args.model_type in [
+ "SanaMS1.5_1600M_P1_D20",
+ "SanaMS1.5_4800M_P1_D60",
+ "SanaSprint_600M_P1_D28",
+ "SanaSprint_1600M_P1_D20"
+ ] else None
for depth in range(layer_num):
# Transformer blocks.
@@ -145,6 +155,14 @@ def main(args):
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_q.weight"] = q
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_k.weight"] = k
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_v.weight"] = v
+ if qk_norm is not None:
+ # Add Q/K normalization for self-attention (attn1) - needed for Sana-Sprint and Sana-1.5
+ converted_state_dict[f"transformer_blocks.{depth}.attn1.norm_q.weight"] = state_dict.pop(
+ f"blocks.{depth}.attn.q_norm.weight"
+ )
+ converted_state_dict[f"transformer_blocks.{depth}.attn1.norm_k.weight"] = state_dict.pop(
+ f"blocks.{depth}.attn.k_norm.weight"
+ )
# Projection.
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.weight"] = state_dict.pop(
f"blocks.{depth}.attn.proj.weight"
@@ -191,6 +209,14 @@ def main(args):
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.bias"] = k_bias
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.weight"] = v
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.bias"] = v_bias
+ if qk_norm is not None:
+ # Add Q/K normalization for cross-attention (attn2) - needed for Sana-Sprint and Sana-1.5
+ converted_state_dict[f"transformer_blocks.{depth}.attn2.norm_q.weight"] = state_dict.pop(
+ f"blocks.{depth}.cross_attn.q_norm.weight"
+ )
+ converted_state_dict[f"transformer_blocks.{depth}.attn2.norm_k.weight"] = state_dict.pop(
+ f"blocks.{depth}.cross_attn.k_norm.weight"
+ )
# Add Q/K normalization for cross-attention (attn2) - needed for Sana Sprint
if args.model_type == "SanaSprint_1600M_P1_D20":
@@ -235,8 +261,7 @@ def main(args):
}
# Add qk_norm parameter for Sana Sprint
- if args.model_type == "SanaSprint_1600M_P1_D20":
- transformer_kwargs["qk_norm"] = "rms_norm_across_heads"
+ if args.model_type in ["SanaSprint_1600M_P1_D20", "SanaSprint_600M_P1_D28"]:
transformer_kwargs["guidance_embeds"] = True
transformer = SanaTransformer2DModel(**transformer_kwargs)
@@ -271,15 +296,15 @@ def main(args):
)
)
transformer.save_pretrained(
- os.path.join(args.dump_path, "transformer"), safe_serialization=True, max_shard_size="5GB", variant=variant
+ os.path.join(args.dump_path, "transformer"), safe_serialization=True, max_shard_size="5GB"
)
else:
print(colored(f"Saving the whole Pipeline containing {args.model_type}", "green", attrs=["bold"]))
# VAE
- ae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers", torch_dtype=torch.float32)
+ ae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers", torch_dtype=torch.float32)
# Text Encoder
- text_encoder_model_path = "google/gemma-2-2b-it"
+ text_encoder_model_path = "Efficient-Large-Model/gemma-2-2b-it"
tokenizer = AutoTokenizer.from_pretrained(text_encoder_model_path)
tokenizer.padding_side = "right"
text_encoder = AutoModelForCausalLM.from_pretrained(
@@ -287,7 +312,8 @@ def main(args):
).get_decoder()
# Choose the appropriate pipeline and scheduler based on model type
- if args.model_type == "SanaSprint_1600M_P1_D20":
+ if args.model_type in ["SanaSprint_1600M_P1_D20", "SanaSprint_600M_P1_D28"]:
+
# Force SCM Scheduler for Sana Sprint regardless of scheduler_type
if args.scheduler_type != "scm":
print(
@@ -335,7 +361,7 @@ def main(args):
scheduler=scheduler,
)
- pipe.save_pretrained(args.dump_path, safe_serialization=True, max_shard_size="5GB", variant=variant)
+ pipe.save_pretrained(args.dump_path, safe_serialization=True, max_shard_size="5GB")
DTYPE_MAPPING = {
@@ -344,12 +370,6 @@ def main(args):
"bf16": torch.bfloat16,
}
-VARIANT_MAPPING = {
- "fp32": None,
- "fp16": "fp16",
- "bf16": "bf16",
-}
-
if __name__ == "__main__":
parser = argparse.ArgumentParser()
@@ -369,7 +389,7 @@ def main(args):
"--model_type",
default="SanaMS_1600M_P1_D20",
type=str,
- choices=["SanaMS_1600M_P1_D20", "SanaMS_600M_P1_D28", "SanaSprint_1600M_P1_D20"],
+ choices=["SanaMS_1600M_P1_D20", "SanaMS_600M_P1_D28", "SanaMS_4800M_P1_D60", "SanaSprint_1600M_P1_D20", "SanaSprint_600M_P1_D28"],
)
parser.add_argument(
"--scheduler_type",
@@ -400,6 +420,30 @@ def main(args):
"cross_attention_head_dim": 72,
"cross_attention_dim": 1152,
"num_layers": 28,
+ },
+ "SanaMS1.5_1600M_P1_D20": {
+ "num_attention_heads": 70,
+ "attention_head_dim": 32,
+ "num_cross_attention_heads": 20,
+ "cross_attention_head_dim": 112,
+ "cross_attention_dim": 2240,
+ "num_layers": 20,
+ },
+ "SanaMS1.5__4800M_P1_D60": {
+ "num_attention_heads": 70,
+ "attention_head_dim": 32,
+ "num_cross_attention_heads": 20,
+ "cross_attention_head_dim": 112,
+ "cross_attention_dim": 2240,
+ "num_layers": 60,
+ },
+ "SanaSprint_600M_P1_D28": {
+ "num_attention_heads": 36,
+ "attention_head_dim": 32,
+ "num_cross_attention_heads": 16,
+ "cross_attention_head_dim": 72,
+ "cross_attention_dim": 1152,
+ "num_layers": 28,
},
"SanaSprint_1600M_P1_D20": {
"num_attention_heads": 70,
@@ -413,6 +457,5 @@ def main(args):
device = "cuda" if torch.cuda.is_available() else "cpu"
weight_dtype = DTYPE_MAPPING[args.dtype]
- variant = VARIANT_MAPPING[args.dtype]
main(args)