Skip to content

Commit f4bdf5f

Browse files
authored
sd: revise hy VAE VRAM (#11105)
This was recently changed to rolling the VAE through the temporal dimension. Clamp the time dimension accordingly.
1 parent 6be85c7 commit f4bdf5f

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

comfy/sd.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -483,8 +483,10 @@ def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None)
483483
self.latent_dim = 3
484484
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1]
485485
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1])
486-
self.memory_used_decode = lambda shape, dtype: (1500 * shape[2] * shape[3] * shape[4] * (4 * 8 * 8)) * model_management.dtype_size(dtype)
487-
self.memory_used_encode = lambda shape, dtype: (900 * max(shape[2], 2) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
486+
#This is likely to significantly over-estimate with single image or low frame counts as the
487+
#implementation is able to completely skip caching. Rework if used as an image only VAE
488+
self.memory_used_decode = lambda shape, dtype: (2800 * min(8, ((shape[2] - 1) * 4) + 1) * shape[3] * shape[4] * (8 * 8)) * model_management.dtype_size(dtype)
489+
self.memory_used_encode = lambda shape, dtype: (1400 * min(9, shape[2]) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
488490
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
489491
elif "decoder.unpatcher3d.wavelets" in sd:
490492
self.upscale_ratio = (lambda a: max(0, a * 8 - 7), 8, 8)

0 commit comments

Comments
 (0)