Commit 911331c

sd: fix VAE tiled fallback VRAM leak (#10139)
When the VAE catches a VRAM OOM, it launches the fallback logic straight from the exception context. Python, however, keeps references to the entire call stack that caused the exception, including any local variables, for the sake of exception reporting and debugging. In the case of tensors, this can hold references to GBs of VRAM and prevent that VRAM from being freed. So drop the except context completely before going back into the VAE via the tiler, by leaving the except block with nothing but a flag. This greatly increases the reliability of the tiler fallback, especially on low-VRAM cards: with the bug, if the leak happened to hold more than the headroom needed for a single tile, the tiler fallback would itself OOM and fail the flow.
1 parent bb32d4e commit 911331c
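
The reference-holding behavior described in the commit message can be reproduced without a GPU. The sketch below is illustrative only and assumes CPython; `FakeTensor`, `regular_path`, and the two `fallback_*` functions are hypothetical names that do not appear in comfy/sd.py, and `MemoryError` stands in for `model_management.OOM_EXCEPTION`. It uses a weak reference to check whether a local from the failing frame is still alive at the point where the fallback would start.

```python
import gc
import weakref


class FakeTensor:
    """Hypothetical stand-in for a large GPU tensor."""


def regular_path(probe):
    big = FakeTensor()              # local variable living in this frame
    probe.append(weakref.ref(big))  # lets the caller observe liveness without keeping `big` alive
    raise MemoryError("simulated OOM")


def fallback_inside_except():
    probe = []
    try:
        regular_path(probe)
    except MemoryError:
        gc.collect()
        # The active exception's traceback still references regular_path's
        # frame, so `big` cannot be freed while this block is running.
        return probe[0]() is not None


def fallback_after_except():
    probe = []
    do_tile = False
    try:
        regular_path(probe)
    except MemoryError:
        do_tile = True  # record the decision only
    if do_tile:
        gc.collect()
        # The exception and its traceback are gone, so the frame's locals were released.
        return probe[0]() is not None
    return None


print(fallback_inside_except())  # True: the local is still alive inside the except block
print(fallback_after_except())   # False: the local was freed once the except block exited
```

With real tensors the first case corresponds to VRAM that stays allocated while the tiled fallback is trying to run.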

1 file changed: +16 -0 lines changed

comfy/sd.py

Lines changed: 16 additions & 0 deletions
@@ -652,6 +652,7 @@ def encode_tiled_3d(self, samples, tile_t=9999, tile_x=512, tile_y=512, overlap=
     def decode(self, samples_in, vae_options={}):
         self.throw_exception_if_invalid()
         pixel_samples = None
+        do_tile = False
         try:
             memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
             model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
@@ -667,6 +668,13 @@ def decode(self, samples_in, vae_options={}):
                 pixel_samples[x:x+batch_number] = out
         except model_management.OOM_EXCEPTION:
             logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
+            #NOTE: We don't know what tensors were allocated to stack variables at the time of the
+            #exception and the exception itself refs them all until we get out of this except block.
+            #So we just set a flag for tiler fallback so that tensor gc can happen once the
+            #exception is fully off the books.
+            do_tile = True
+
+        if do_tile:
             dims = samples_in.ndim - 2
             if dims == 1 or self.extra_1d_channel is not None:
                 pixel_samples = self.decode_tiled_1d(samples_in)
@@ -713,6 +721,7 @@ def encode(self, pixel_samples):
         self.throw_exception_if_invalid()
         pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
         pixel_samples = pixel_samples.movedim(-1, 1)
+        do_tile = False
         if self.latent_dim == 3 and pixel_samples.ndim < 5:
             if not self.not_video:
                 pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
@@ -734,6 +743,13 @@ def encode(self, pixel_samples):

         except model_management.OOM_EXCEPTION:
             logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
+            #NOTE: We don't know what tensors were allocated to stack variables at the time of the
+            #exception and the exception itself refs them all until we get out of this except block.
+            #So we just set a flag for tiler fallback so that tensor gc can happen once the
+            #exception is fully off the books.
+            do_tile = True
+
+        if do_tile:
             if self.latent_dim == 3:
                 tile = 256
                 overlap = tile // 4
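
Both hunks apply the same flag-then-fallback shape. As a rough sketch of that shape in isolation, and not something this commit adds, it could be written as a small helper. `run_with_fallback`, `decode_full`, and `decode_tiled` are hypothetical names, and `MemoryError` again stands in for `model_management.OOM_EXCEPTION`.

```python
def run_with_fallback(primary, fallback, oom_exception=MemoryError):
    """Hypothetical helper: try primary(); on OOM, run fallback() outside the except block."""
    result = None
    use_fallback = False
    try:
        result = primary()
    except oom_exception:
        # Only record the decision here. Any heavy work done inside this block
        # would run while the exception still pins the failed call's locals.
        use_fallback = True
    if use_fallback:
        result = fallback()
    return result


def decode_full():
    raise MemoryError("simulated OOM")


def decode_tiled():
    return "tiled result"


print(run_with_fallback(decode_full, decode_tiled))            # tiled result
print(run_with_fallback(lambda: "full result", decode_tiled))  # full result
```

The commit keeps the logic inline in `decode` and `encode` rather than factoring it out, which keeps the diff to 16 added lines and leaves the existing tiled branches untouched.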
