@@ -16,6 +16,7 @@
 import inspect
 import re
 import urllib.parse as ul
+import warnings
 from typing import Callable, Dict, List, Optional, Tuple, Union
 
 import torch
@@ -41,6 +42,7 @@
     ASPECT_RATIO_1024_BIN,
 )
 from ..pixart_alpha.pipeline_pixart_sigma import ASPECT_RATIO_2048_BIN
+from ..sana.pipeline_sana import ASPECT_RATIO_4096_BIN
 from .pag_utils import PAGMixin
 
 
@@ -639,7 +641,7 @@ def __call__(
         negative_prompt_attention_mask: Optional[torch.Tensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
-        clean_caption: bool = True,
+        clean_caption: bool = False,
         use_resolution_binning: bool = True,
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
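The `clean_caption` default flips to `False`, so prompts skip the caption-cleaning path (which needs the optional `beautifulsoup4` and `ftfy` packages) unless the caller opts in. A minimal sketch of opting back in, assuming `pipe` is an already-loaded instance of this pipeline:

```python
# Restore the previous behavior explicitly; the cleaning path requires the
# optional `beautifulsoup4` and `ftfy` packages to be installed.
image = pipe(
    prompt="A photo of an astronaut riding a horse",
    clean_caption=True,  # was the default before this change
).images[0]
```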
@@ -755,7 +757,9 @@ def __call__(
             callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
 
         if use_resolution_binning:
-            if self.transformer.config.sample_size == 64:
+            if self.transformer.config.sample_size == 128:
+                aspect_ratio_bin = ASPECT_RATIO_4096_BIN
+            elif self.transformer.config.sample_size == 64:
                 aspect_ratio_bin = ASPECT_RATIO_2048_BIN
             elif self.transformer.config.sample_size == 32:
                 aspect_ratio_bin = ASPECT_RATIO_1024_BIN
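With `use_resolution_binning=True`, the requested `height`/`width` are snapped to the closest entry of the bin table selected by the transformer's `sample_size` (128 → 4096 px, 64 → 2048 px, 32 → 1024 px), generation runs at the binned size, and the output is resized back to the original request via `resize_and_crop_tensor` (see the decode hunk below). A standalone sketch of the snapping step, assuming the bin tables keep the PixArt-style `{aspect_ratio_str: [height, width]}` layout:

```python
# Sketch mirroring diffusers' PixArtImageProcessor.classify_height_width_bin:
# pick the bin whose aspect ratio is closest to the requested one.
def classify_height_width_bin(height: int, width: int, ratios: dict) -> tuple:
    aspect_ratio = float(height / width)
    closest_ratio = min(ratios.keys(), key=lambda r: abs(float(r) - aspect_ratio))
    default_hw = ratios[closest_ratio]
    return int(default_hw[0]), int(default_hw[1])

# e.g. a 4000x3000 request on a sample_size-128 model snaps to the nearest
# 4096-px bin; the decoded image is resized back to 4000x3000 afterwards.
```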
@@ -912,7 +916,14 @@ def __call__(
             image = latents
         else:
             latents = latents.to(self.vae.dtype)
-            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+            try:
+                image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+            except torch.cuda.OutOfMemoryError as e:
+                warnings.warn(
+                    f"{e}. \n"
+                    f"Try to use VAE tiling for large images. For example: \n"
+                    f"pipe.vae.enable_tiling(tile_sample_min_width=512, tile_sample_min_height=512)"
+                )
             if use_resolution_binning:
                 image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height)
 
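One caveat visible in this hunk: if the `except` branch fires, `image` is never assigned, so the following `resize_and_crop_tensor` call still fails; the warning is a hint to enable tiling up front rather than a recovery path. A sketch of doing that proactively for a 4K run; the pipeline class and checkpoint id are assumptions, since neither appears in this diff:

```python
import torch
from diffusers import SanaPAGPipeline

# Checkpoint id is a placeholder; substitute an actual 4K checkpoint.
pipe = SanaPAGPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_4K_checkpoint",
    torch_dtype=torch.bfloat16,
).to("cuda")

# Enable VAE tiling before generating, so the 4096-px decode fits in
# memory instead of hitting the OutOfMemoryError warning above.
pipe.vae.enable_tiling(tile_sample_min_width=512, tile_sample_min_height=512)

image = pipe(
    prompt="a cyberpunk cityscape at dusk, ultra detailed",
    height=4096,
    width=4096,
    use_resolution_binning=True,  # snaps to ASPECT_RATIO_4096_BIN
).images[0]
image.save("sana_4k.png")
```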