From 8232a2f0aae56b98add928227e31ff08ca29aead Mon Sep 17 00:00:00 2001 From: junsong Date: Wed, 8 Jan 2025 00:41:58 -0800 Subject: [PATCH 1/6] [Sana 4K] add 4K support for Sana --- scripts/convert_sana_to_diffusers.py | 12 +++-- src/diffusers/pipelines/sana/pipeline_sana.py | 47 ++++++++++++++++++- 2 files changed, 54 insertions(+), 5 deletions(-) diff --git a/scripts/convert_sana_to_diffusers.py b/scripts/convert_sana_to_diffusers.py index 2f1732817be3..99a9ff322251 100644 --- a/scripts/convert_sana_to_diffusers.py +++ b/scripts/convert_sana_to_diffusers.py @@ -25,6 +25,7 @@ CTX = init_empty_weights if is_accelerate_available else nullcontext ckpt_ids = [ + "Efficient-Large-Model/Sana_1600M_4Kpx_BF16/checkpoints/Sana_1600M_4Kpx_BF16.pth", "Efficient-Large-Model/Sana_1600M_2Kpx_BF16/checkpoints/Sana_1600M_2Kpx_BF16.pth", "Efficient-Large-Model/Sana_1600M_1024px_MultiLing/checkpoints/Sana_1600M_1024px_MultiLing.pth", "Efficient-Large-Model/Sana_1600M_1024px_BF16/checkpoints/Sana_1600M_1024px_BF16.pth", @@ -89,7 +90,10 @@ def main(args): converted_state_dict["caption_norm.weight"] = state_dict.pop("attention_y_norm.weight") # scheduler - flow_shift = 3.0 + if args.image_size == 4096: + flow_shift = 6.0 + else: + flow_shift = 3.0 # model config if args.model_type == "SanaMS_1600M_P1_D20": @@ -99,7 +103,7 @@ def main(args): else: raise ValueError(f"{args.model_type} is not supported.") # Positional embedding interpolation scale. - interpolation_scale = {512: None, 1024: None, 2048: 1.0} + interpolation_scale = {512: None, 1024: None, 2048: 1.0, 4096: 2.0} for depth in range(layer_num): # Transformer blocks. @@ -272,9 +276,9 @@ def main(args): "--image_size", default=1024, type=int, - choices=[512, 1024, 2048], + choices=[512, 1024, 2048, 4096], required=False, - help="Image size of pretrained model, 512, 1024 or 2048.", + help="Image size of pretrained model, 512, 1024, 2048 or 4096.", ) parser.add_argument( "--model_type", default="SanaMS_1600M_P1_D20", type=str, choices=["SanaMS_1600M_P1_D20", "SanaMS_600M_P1_D28"] diff --git a/src/diffusers/pipelines/sana/pipeline_sana.py b/src/diffusers/pipelines/sana/pipeline_sana.py index 895396fae3c4..afc2f74c9e8f 100644 --- a/src/diffusers/pipelines/sana/pipeline_sana.py +++ b/src/diffusers/pipelines/sana/pipeline_sana.py @@ -63,6 +63,49 @@ import ftfy +ASPECT_RATIO_4096_BIN = { + "0.25": [2048.0, 8192.0], + "0.26": [2048.0, 7936.0], + "0.27": [2048.0, 7680.0], + "0.28": [2048.0, 7424.0], + "0.32": [2304.0, 7168.0], + "0.33": [2304.0, 6912.0], + "0.35": [2304.0, 6656.0], + "0.4": [2560.0, 6400.0], + "0.42": [2560.0, 6144.0], + "0.48": [2816.0, 5888.0], + "0.5": [2816.0, 5632.0], + "0.52": [2816.0, 5376.0], + "0.57": [3072.0, 5376.0], + "0.6": [3072.0, 5120.0], + "0.68": [3328.0, 4864.0], + "0.72": [3328.0, 4608.0], + "0.78": [3584.0, 4608.0], + "0.82": [3584.0, 4352.0], + "0.88": [3840.0, 4352.0], + "0.94": [3840.0, 4096.0], + "1.0": [4096.0, 4096.0], + "1.07": [4096.0, 3840.0], + "1.13": [4352.0, 3840.0], + "1.21": [4352.0, 3584.0], + "1.29": [4608.0, 3584.0], + "1.38": [4608.0, 3328.0], + "1.46": [4864.0, 3328.0], + "1.67": [5120.0, 3072.0], + "1.75": [5376.0, 3072.0], + "2.0": [5632.0, 2816.0], + "2.09": [5888.0, 2816.0], + "2.4": [6144.0, 2560.0], + "2.5": [6400.0, 2560.0], + "2.89": [6656.0, 2304.0], + "3.0": [6912.0, 2304.0], + "3.11": [7168.0, 2304.0], + "3.62": [7424.0, 2048.0], + "3.75": [7680.0, 2048.0], + "3.88": [7936.0, 2048.0], + "4.0": [8192.0, 2048.0], +} + EXAMPLE_DOC_STRING = """ Examples: ```py @@ -734,7 +777,9 @@ def __call__( # 1. Check inputs. Raise error if not correct if use_resolution_binning: - if self.transformer.config.sample_size == 64: + if self.transformer.config.sample_size == 128: + aspect_ratio_bin = ASPECT_RATIO_4096_BIN + elif self.transformer.config.sample_size == 64: aspect_ratio_bin = ASPECT_RATIO_2048_BIN elif self.transformer.config.sample_size == 32: aspect_ratio_bin = ASPECT_RATIO_1024_BIN From b00ce800d29ef52db39fca5083bac4efa4427b24 Mon Sep 17 00:00:00 2001 From: junsong Date: Sat, 11 Jan 2025 22:15:22 -0800 Subject: [PATCH 2/6] [Sana-4K] fix SanaPAGPipeline --- src/diffusers/pipelines/pag/pipeline_pag_sana.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sana.py b/src/diffusers/pipelines/pag/pipeline_pag_sana.py index 2cdc1c70cdcc..9b9486c8cdd0 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sana.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sana.py @@ -41,6 +41,7 @@ ASPECT_RATIO_1024_BIN, ) from ..pixart_alpha.pipeline_pixart_sigma import ASPECT_RATIO_2048_BIN +from ..sana.pipeline_sana import ASPECT_RATIO_4096_BIN from .pag_utils import PAGMixin @@ -755,7 +756,9 @@ def __call__( callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs if use_resolution_binning: - if self.transformer.config.sample_size == 64: + if self.transformer.config.sample_size == 128: + aspect_ratio_bin = ASPECT_RATIO_4096_BIN + elif self.transformer.config.sample_size == 64: aspect_ratio_bin = ASPECT_RATIO_2048_BIN elif self.transformer.config.sample_size == 32: aspect_ratio_bin = ASPECT_RATIO_1024_BIN From 4d18b5357e50d34adcd05dbbd73454237e27a34d Mon Sep 17 00:00:00 2001 From: junsong Date: Sat, 11 Jan 2025 22:22:05 -0800 Subject: [PATCH 3/6] add VAE automatically tiling function; --- src/diffusers/pipelines/pag/pipeline_pag_sana.py | 8 +++++++- src/diffusers/pipelines/sana/pipeline_sana.py | 8 +++++++- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sana.py b/src/diffusers/pipelines/pag/pipeline_pag_sana.py index 9b9486c8cdd0..a0a315b9984e 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sana.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sana.py @@ -915,7 +915,13 @@ def __call__( image = latents else: latents = latents.to(self.vae.dtype) - image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + try: + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + except torch.cuda.OutOfMemoryError as e: + print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.") + self.vae.enable_tiling(tile_sample_min_width=1024, tile_sample_min_height=1024) + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + self.vae.disable_tiling() if use_resolution_binning: image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height) diff --git a/src/diffusers/pipelines/sana/pipeline_sana.py b/src/diffusers/pipelines/sana/pipeline_sana.py index 8b318597c12d..45fb2506c06f 100644 --- a/src/diffusers/pipelines/sana/pipeline_sana.py +++ b/src/diffusers/pipelines/sana/pipeline_sana.py @@ -953,7 +953,13 @@ def __call__( image = latents else: latents = latents.to(self.vae.dtype) - image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + try: + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + except torch.cuda.OutOfMemoryError as e: + print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.") + self.vae.enable_tiling(tile_sample_min_width=1024, tile_sample_min_height=1024) + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + self.vae.disable_tiling() if use_resolution_binning: image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height) From 8fe11f4a5522db6c5f47cb9526c03b52eae69809 Mon Sep 17 00:00:00 2001 From: junsong Date: Sat, 11 Jan 2025 23:12:10 -0800 Subject: [PATCH 4/6] set clean_caption to False; --- src/diffusers/pipelines/pag/pipeline_pag_sana.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sana.py b/src/diffusers/pipelines/pag/pipeline_pag_sana.py index a0a315b9984e..69e90992b1be 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sana.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sana.py @@ -640,7 +640,7 @@ def __call__( negative_prompt_attention_mask: Optional[torch.Tensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, - clean_caption: bool = True, + clean_caption: bool = False, use_resolution_binning: bool = True, callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], From 0932a4e5aef41e314e7a8748d3dba270ed434814 Mon Sep 17 00:00:00 2001 From: lawrence-cj Date: Tue, 14 Jan 2025 09:43:47 +0800 Subject: [PATCH 5/6] add warnings for VAE OOM. --- src/diffusers/pipelines/pag/pipeline_pag_sana.py | 8 ++++---- src/diffusers/pipelines/sana/pipeline_sana.py | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sana.py b/src/diffusers/pipelines/pag/pipeline_pag_sana.py index 69e90992b1be..334c752e7870 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sana.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sana.py @@ -16,6 +16,7 @@ import inspect import re import urllib.parse as ul +import warnings from typing import Callable, Dict, List, Optional, Tuple, Union import torch @@ -918,10 +919,9 @@ def __call__( try: image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] except torch.cuda.OutOfMemoryError as e: - print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.") - self.vae.enable_tiling(tile_sample_min_width=1024, tile_sample_min_height=1024) - image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] - self.vae.disable_tiling() + warnings.warn(f"{e}. \n" + f"Try to use VAE tiling for large images. For example: \n" + f"pipe.vae.enable_tiling(tile_sample_min_width=512, tile_sample_min_height=512)") if use_resolution_binning: image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height) diff --git a/src/diffusers/pipelines/sana/pipeline_sana.py b/src/diffusers/pipelines/sana/pipeline_sana.py index 45fb2506c06f..8ffb2c429f77 100644 --- a/src/diffusers/pipelines/sana/pipeline_sana.py +++ b/src/diffusers/pipelines/sana/pipeline_sana.py @@ -16,6 +16,7 @@ import inspect import re import urllib.parse as ul +import warnings from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch @@ -956,10 +957,9 @@ def __call__( try: image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] except torch.cuda.OutOfMemoryError as e: - print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.") - self.vae.enable_tiling(tile_sample_min_width=1024, tile_sample_min_height=1024) - image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] - self.vae.disable_tiling() + warnings.warn(f"{e}. \n" + f"Try to use VAE tiling for large images. For example: \n" + f"pipe.vae.enable_tiling(tile_sample_min_width=512, tile_sample_min_height=512)") if use_resolution_binning: image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height) From 422fd3692a1f4478e0a41c59b75b1c5f720e7638 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 14 Jan 2025 21:30:10 +0100 Subject: [PATCH 6/6] style --- src/diffusers/pipelines/pag/pipeline_pag_sana.py | 8 +++++--- src/diffusers/pipelines/sana/pipeline_sana.py | 8 +++++--- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sana.py b/src/diffusers/pipelines/pag/pipeline_pag_sana.py index 334c752e7870..416b2f7c60f2 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sana.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sana.py @@ -919,9 +919,11 @@ def __call__( try: image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] except torch.cuda.OutOfMemoryError as e: - warnings.warn(f"{e}. \n" - f"Try to use VAE tiling for large images. For example: \n" - f"pipe.vae.enable_tiling(tile_sample_min_width=512, tile_sample_min_height=512)") + warnings.warn( + f"{e}. \n" + f"Try to use VAE tiling for large images. For example: \n" + f"pipe.vae.enable_tiling(tile_sample_min_width=512, tile_sample_min_height=512)" + ) if use_resolution_binning: image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height) diff --git a/src/diffusers/pipelines/sana/pipeline_sana.py b/src/diffusers/pipelines/sana/pipeline_sana.py index 8ffb2c429f77..cca4dfe5e8ba 100644 --- a/src/diffusers/pipelines/sana/pipeline_sana.py +++ b/src/diffusers/pipelines/sana/pipeline_sana.py @@ -957,9 +957,11 @@ def __call__( try: image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] except torch.cuda.OutOfMemoryError as e: - warnings.warn(f"{e}. \n" - f"Try to use VAE tiling for large images. For example: \n" - f"pipe.vae.enable_tiling(tile_sample_min_width=512, tile_sample_min_height=512)") + warnings.warn( + f"{e}. \n" + f"Try to use VAE tiling for large images. For example: \n" + f"pipe.vae.enable_tiling(tile_sample_min_width=512, tile_sample_min_height=512)" + ) if use_resolution_binning: image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height)