Skip to content

Commit 4d18b53

Browse files
committed
add VAE automatically tiling function;
1 parent b00ce80 commit 4d18b53

File tree

2 files changed

+14
-2
lines changed

2 files changed

+14
-2
lines changed

src/diffusers/pipelines/pag/pipeline_pag_sana.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -915,7 +915,13 @@ def __call__(
915915
image = latents
916916
else:
917917
latents = latents.to(self.vae.dtype)
918-
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
918+
try:
919+
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
920+
except torch.cuda.OutOfMemoryError as e:
921+
print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
922+
self.vae.enable_tiling(tile_sample_min_width=1024, tile_sample_min_height=1024)
923+
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
924+
self.vae.disable_tiling()
919925
if use_resolution_binning:
920926
image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height)
921927

src/diffusers/pipelines/sana/pipeline_sana.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -953,7 +953,13 @@ def __call__(
953953
image = latents
954954
else:
955955
latents = latents.to(self.vae.dtype)
956-
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
956+
try:
957+
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
958+
except torch.cuda.OutOfMemoryError as e:
959+
print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
960+
self.vae.enable_tiling(tile_sample_min_width=1024, tile_sample_min_height=1024)
961+
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
962+
self.vae.disable_tiling()
957963
if use_resolution_binning:
958964
image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height)
959965

0 commit comments

Comments
 (0)