Skip to content

Commit 0932a4e

Browse files
committed
add warnings for VAE OOM.
1 parent 45fadaf commit 0932a4e

File tree

2 files changed

+8
-8
lines changed

2 files changed

+8
-8
lines changed

src/diffusers/pipelines/pag/pipeline_pag_sana.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import inspect
1717
import re
1818
import urllib.parse as ul
19+
import warnings
1920
from typing import Callable, Dict, List, Optional, Tuple, Union
2021

2122
import torch
@@ -918,10 +919,9 @@ def __call__(
918919
try:
919920
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
920921
except torch.cuda.OutOfMemoryError as e:
921-
print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
922-
self.vae.enable_tiling(tile_sample_min_width=1024, tile_sample_min_height=1024)
923-
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
924-
self.vae.disable_tiling()
922+
warnings.warn(f"{e}. \n"
923+
f"Try to use VAE tiling for large images. For example: \n"
924+
f"pipe.vae.enable_tiling(tile_sample_min_width=512, tile_sample_min_height=512)")
925925
if use_resolution_binning:
926926
image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height)
927927

src/diffusers/pipelines/sana/pipeline_sana.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import inspect
1717
import re
1818
import urllib.parse as ul
19+
import warnings
1920
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
2021

2122
import torch
@@ -956,10 +957,9 @@ def __call__(
956957
try:
957958
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
958959
except torch.cuda.OutOfMemoryError as e:
959-
print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
960-
self.vae.enable_tiling(tile_sample_min_width=1024, tile_sample_min_height=1024)
961-
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
962-
self.vae.disable_tiling()
960+
warnings.warn(f"{e}. \n"
961+
f"Try to use VAE tiling for large images. For example: \n"
962+
f"pipe.vae.enable_tiling(tile_sample_min_width=512, tile_sample_min_height=512)")
963963
if use_resolution_binning:
964964
image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height)
965965

0 commit comments

Comments
 (0)