diff --git a/docs/source/en/api/pipelines/sana_sprint.md b/docs/source/en/api/pipelines/sana_sprint.md
new file mode 100644
index 000000000000..482ef3c2c99d
--- /dev/null
+++ b/docs/source/en/api/pipelines/sana_sprint.md
@@ -0,0 +1,96 @@
+<!--Copyright 2025 The HuggingFace Team. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
+the License. You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+specific language governing permissions and limitations under the License.
+-->
+
+# SanaSprintPipeline
+
+<div class="flex flex-wrap space-x-1">
+  <img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
+</div>
+
+[SANA-Sprint: One-Step Diffusion with Continuous-Time Consistency Distillation](https://huggingface.co/papers/2503.09641) from NVIDIA and MIT HAN Lab, by Junsong Chen, Shuchen Xue, Yuyang Zhao, Jincheng Yu, Sayak Paul, Junyu Chen, Han Cai, Enze Xie, Song Han
+
+The abstract from the paper is:
+
+*This paper presents SANA-Sprint, an efficient diffusion model for ultra-fast text-to-image (T2I) generation. SANA-Sprint is built on a pre-trained foundation model and augmented with hybrid distillation, dramatically reducing inference steps from 20 to 1-4. We introduce three key innovations: (1) We propose a training-free approach that transforms a pre-trained flow-matching model for continuous-time consistency distillation (sCM), eliminating costly training from scratch and achieving high training efficiency. Our hybrid distillation strategy combines sCM with latent adversarial distillation (LADD): sCM ensures alignment with the teacher model, while LADD enhances single-step generation fidelity. (2) SANA-Sprint is a unified step-adaptive model that achieves high-quality generation in 1-4 steps, eliminating step-specific training and improving efficiency. (3) We integrate ControlNet with SANA-Sprint for real-time interactive image generation, enabling instant visual feedback for user interaction. SANA-Sprint establishes a new Pareto frontier in speed-quality tradeoffs, achieving state-of-the-art performance with 7.59 FID and 0.74 GenEval in only 1 step — outperforming FLUX-schnell (7.94 FID / 0.71 GenEval) while being 10× faster (0.1s vs 1.1s on H100). It also achieves 0.1s (T2I) and 0.25s (ControlNet) latency for 1024×1024 images on H100, and 0.31s (T2I) on an RTX 4090, showcasing its exceptional efficiency and potential for AI-powered consumer applications (AIPC). Code and pre-trained models will be open-sourced.*
+
+<Tip>
+
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
+
+</Tip>
+
+This pipeline was contributed by [lawrence-cj](https://github.com/lawrence-cj), [Shuchen Xue](https://github.com/scxue) and [Enze Xie](https://github.com/xieenze). The original codebase can be found [here](https://github.com/NVlabs/Sana). The original weights can be found under [hf.co/Efficient-Large-Model](https://huggingface.co/Efficient-Large-Model/).
+
+Available models:
+
+| Model                                                                                                                                        | Recommended dtype |
+|:--------------------------------------------------------------------------------------------------------------------------------------------:|:-----------------:|
+| [`Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers) | `torch.bfloat16`  |
+| [`Efficient-Large-Model/Sana_Sprint_0.6B_1024px_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_Sprint_0.6B_1024px_diffusers) | `torch.bfloat16`  |
+
+Refer to [this](https://huggingface.co/collections/Efficient-Large-Model/sana-sprint-67d6810d65235085b3b17c76) collection for more information.
+
+Note: The recommended dtype above applies to the transformer weights. The text encoder must stay in `torch.bfloat16`, and the VAE weights must stay in `torch.bfloat16` or `torch.float32`, for the model to work correctly. Please refer to the inference example below to see how to load the model with the recommended dtype.
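+
+The following is a minimal sketch (it assumes a CUDA device; since SANA-Sprint is step-adaptive, `num_inference_steps` can be set anywhere from 1 to 4):
+
+```py
+import torch
+from diffusers import SanaSprintPipeline
+
+# Load the pipeline with the recommended dtype for the transformer weights.
+pipeline = SanaSprintPipeline.from_pretrained(
+    "Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers",
+    torch_dtype=torch.bfloat16,
+)
+pipeline.to("cuda")
+
+prompt = "a tiny astronaut hatching from an egg on the moon"
+image = pipeline(prompt=prompt, num_inference_steps=2).images[0]
+image.save("sana_sprint.png")
+```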
+
+## Quantization
+
+Quantization helps reduce the memory requirements of very large models by storing model weights in a lower precision data type. However, quantization may have a varying impact on image quality depending on the model.
+
+Refer to the [Quantization](../../quantization/overview) overview to learn more about supported quantization backends and selecting a quantization backend that supports your use case. The example below demonstrates how to load a quantized [`SanaSprintPipeline`] for inference with bitsandbytes.
+
+```py
+import torch
+from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, SanaTransformer2DModel, SanaSprintPipeline
+from transformers import BitsAndBytesConfig, AutoModel
+
+quant_config = BitsAndBytesConfig(load_in_8bit=True)
+text_encoder_8bit = AutoModel.from_pretrained(
+    "Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers",
+    subfolder="text_encoder",
+    quantization_config=quant_config,
+    torch_dtype=torch.bfloat16,
+)
+
+quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
+transformer_8bit = SanaTransformer2DModel.from_pretrained(
+    "Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers",
+    subfolder="transformer",
+    quantization_config=quant_config,
+    torch_dtype=torch.bfloat16,
+)
+
+pipeline = SanaSprintPipeline.from_pretrained(
+    "Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers",
+    text_encoder=text_encoder_8bit,
+    transformer=transformer_8bit,
+    torch_dtype=torch.bfloat16,
+    device_map="balanced",
+)
+
+prompt = "a tiny astronaut hatching from an egg on the moon"
+image = pipeline(prompt).images[0]
+image.save("sana.png")
+```
+
+## SanaSprintPipeline
+
+[[autodoc]] SanaSprintPipeline
+  - all
+  - __call__
+
+## SanaPipelineOutput
+
+[[autodoc]] pipelines.sana.pipeline_output.SanaPipelineOutput
diff --git a/scripts/convert_sana_to_diffusers.py b/scripts/convert_sana_to_diffusers.py
index a8bc1a51c13a..47e932ba5070 100644
--- a/scripts/convert_sana_to_diffusers.py
+++ b/scripts/convert_sana_to_diffusers.py
@@ -27,6 +27,7 @@
 CTX = init_empty_weights if is_accelerate_available else nullcontext

 ckpt_ids = [
+    "Efficient-Large-Model/SANA1.5_4.8B_1024px/checkpoints/SANA1.5_4.8B_1024px.pth",
     "Efficient-Large-Model/Sana_1600M_4Kpx_BF16/checkpoints/Sana_1600M_4Kpx_BF16.pth",
     "Efficient-Large-Model/Sana_1600M_2Kpx_BF16/checkpoints/Sana_1600M_2Kpx_BF16.pth",
     "Efficient-Large-Model/Sana_1600M_1024px_MultiLing/checkpoints/Sana_1600M_1024px_MultiLing.pth",
@@ -75,7 +76,8 @@ def main(args):
     converted_state_dict["caption_projection.linear_2.bias"] = state_dict.pop("y_embedder.y_proj.fc2.bias")

     # Handle different time embedding structure based on model type
-    if args.model_type == "SanaSprint_1600M_P1_D20":
+
+    if args.model_type in ["SanaSprint_1600M_P1_D20", "SanaSprint_600M_P1_D28"]:
         # For Sana Sprint, the time embedding structure is different
         converted_state_dict["time_embed.timestep_embedder.linear_1.weight"] = state_dict.pop(
             "t_embedder.mlp.0.weight"
@@ -128,10 +130,18 @@ def main(args):
         layer_num = 20
     elif args.model_type == "SanaMS_600M_P1_D28":
         layer_num = 28
+    elif args.model_type == "SanaMS1.5_4800M_P1_D60":
+        layer_num = 60
     else:
         raise ValueError(f"{args.model_type} is not supported.")

     # Positional embedding interpolation scale.
     interpolation_scale = {512: None, 1024: None, 2048: 1.0, 4096: 2.0}
+    qk_norm = (
+        "rms_norm_across_heads"
+        if args.model_type
+        in ["SanaMS1.5_1600M_P1_D20", "SanaMS1.5_4800M_P1_D60", "SanaSprint_600M_P1_D28", "SanaSprint_1600M_P1_D20"]
+        else None
+    )

     for depth in range(layer_num):
         # Transformer blocks.
@@ -145,6 +155,14 @@ def main(args):
         converted_state_dict[f"transformer_blocks.{depth}.attn1.to_q.weight"] = q
         converted_state_dict[f"transformer_blocks.{depth}.attn1.to_k.weight"] = k
         converted_state_dict[f"transformer_blocks.{depth}.attn1.to_v.weight"] = v
+        if qk_norm is not None:
+            # Add Q/K normalization for self-attention (attn1) - needed for Sana-Sprint and Sana-1.5
+            converted_state_dict[f"transformer_blocks.{depth}.attn1.norm_q.weight"] = state_dict.pop(
+                f"blocks.{depth}.attn.q_norm.weight"
+            )
+            converted_state_dict[f"transformer_blocks.{depth}.attn1.norm_k.weight"] = state_dict.pop(
+                f"blocks.{depth}.attn.k_norm.weight"
+            )
         # Projection.
         converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.weight"] = state_dict.pop(
             f"blocks.{depth}.attn.proj.weight"
@@ -191,6 +209,14 @@ def main(args):
         converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.bias"] = k_bias
         converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.weight"] = v
         converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.bias"] = v_bias
+        if qk_norm is not None:
+            # Add Q/K normalization for cross-attention (attn2) - needed for Sana-Sprint and Sana-1.5
+            converted_state_dict[f"transformer_blocks.{depth}.attn2.norm_q.weight"] = state_dict.pop(
+                f"blocks.{depth}.cross_attn.q_norm.weight"
+            )
+            converted_state_dict[f"transformer_blocks.{depth}.attn2.norm_k.weight"] = state_dict.pop(
+                f"blocks.{depth}.cross_attn.k_norm.weight"
+            )

         # Add Q/K normalization for cross-attention (attn2) - needed for Sana Sprint
         if args.model_type == "SanaSprint_1600M_P1_D20":
@@ -235,8 +261,7 @@ def main(args):
     }

     # Add qk_norm parameter for Sana Sprint
-    if args.model_type == "SanaSprint_1600M_P1_D20":
-        transformer_kwargs["qk_norm"] = "rms_norm_across_heads"
+    if args.model_type in ["SanaSprint_1600M_P1_D20", "SanaSprint_600M_P1_D28"]:
         transformer_kwargs["guidance_embeds"] = True

     transformer = SanaTransformer2DModel(**transformer_kwargs)
@@ -271,15 +296,15 @@ def main(args):
             )
         )
         transformer.save_pretrained(
-            os.path.join(args.dump_path, "transformer"), safe_serialization=True, max_shard_size="5GB", variant=variant
+            os.path.join(args.dump_path, "transformer"), safe_serialization=True, max_shard_size="5GB"
         )
     else:
         print(colored(f"Saving the whole Pipeline containing {args.model_type}", "green", attrs=["bold"]))

         # VAE
-        ae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers", torch_dtype=torch.float32)
+        ae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers", torch_dtype=torch.float32)

         # Text Encoder
-        text_encoder_model_path = "google/gemma-2-2b-it"
+        text_encoder_model_path = "Efficient-Large-Model/gemma-2-2b-it"
         tokenizer = AutoTokenizer.from_pretrained(text_encoder_model_path)
         tokenizer.padding_side = "right"
         text_encoder = AutoModelForCausalLM.from_pretrained(
@@ -287,7 +312,8 @@ def main(args):
         ).get_decoder()

         # Choose the appropriate pipeline and scheduler based on model type
-        if args.model_type == "SanaSprint_1600M_P1_D20":
+        if args.model_type in ["SanaSprint_1600M_P1_D20", "SanaSprint_600M_P1_D28"]:
+
             # Force SCM Scheduler for Sana Sprint regardless of scheduler_type
             if args.scheduler_type != "scm":
                 print(
@@ -335,7 +361,7 @@ def main(args):
             scheduler=scheduler,
         )

-        pipe.save_pretrained(args.dump_path, safe_serialization=True, max_shard_size="5GB", variant=variant)
+        pipe.save_pretrained(args.dump_path, safe_serialization=True, max_shard_size="5GB")


 DTYPE_MAPPING = {
@@ -344,12 +370,6 @@
     "bf16": torch.bfloat16,
 }

-VARIANT_MAPPING = {
-    "fp32": None,
-    "fp16": "fp16",
-    "bf16": "bf16",
-}
-

 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
@@ -369,7 +389,7 @@
         "--model_type",
         default="SanaMS_1600M_P1_D20",
         type=str,
-        choices=["SanaMS_1600M_P1_D20", "SanaMS_600M_P1_D28", "SanaSprint_1600M_P1_D20"],
+        choices=["SanaMS_1600M_P1_D20", "SanaMS_600M_P1_D28", "SanaMS1.5_4800M_P1_D60", "SanaSprint_1600M_P1_D20", "SanaSprint_600M_P1_D28"],
     )
     parser.add_argument(
         "--scheduler_type",
@@ -400,6 +420,30 @@
         "cross_attention_head_dim": 72,
         "cross_attention_dim": 1152,
         "num_layers": 28,
+    },
+    "SanaMS1.5_1600M_P1_D20": {
+        "num_attention_heads": 70,
+        "attention_head_dim": 32,
+        "num_cross_attention_heads": 20,
+        "cross_attention_head_dim": 112,
+        "cross_attention_dim": 2240,
+        "num_layers": 20,
+    },
+    "SanaMS1.5_4800M_P1_D60": {
+        "num_attention_heads": 70,
+        "attention_head_dim": 32,
+        "num_cross_attention_heads": 20,
+        "cross_attention_head_dim": 112,
+        "cross_attention_dim": 2240,
+        "num_layers": 60,
+    },
+    "SanaSprint_600M_P1_D28": {
+        "num_attention_heads": 36,
+        "attention_head_dim": 32,
+        "num_cross_attention_heads": 16,
+        "cross_attention_head_dim": 72,
+        "cross_attention_dim": 1152,
+        "num_layers": 28,
     },
     "SanaSprint_1600M_P1_D20": {
         "num_attention_heads": 70,
@@ -413,6 +457,5 @@

     device = "cuda" if torch.cuda.is_available() else "cpu"
     weight_dtype = DTYPE_MAPPING[args.dtype]
-    variant = VARIANT_MAPPING[args.dtype]

     main(args)
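
With these changes, a Sana-Sprint checkpoint could be converted along these lines. This is a sketch: `--model_type`, `--scheduler_type`, `--dtype`, and `--dump_path` are the script's arguments shown in the diff, while `--orig_ckpt_path` and the `.pth` path are assumptions based on the script's existing checkpoint handling.

```bash
# Convert an original Sana-Sprint 0.6B checkpoint into diffusers format.
# The checkpoint path is illustrative; point it at your downloaded .pth file.
python scripts/convert_sana_to_diffusers.py \
    --orig_ckpt_path path/to/Sana_Sprint_0.6B_1024px.pth \
    --model_type SanaSprint_600M_P1_D28 \
    --scheduler_type scm \
    --dtype bf16 \
    --dump_path Sana_Sprint_0.6B_1024px_diffusers
```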