
Commit 6a2678a

Trim/pad channels in VAE code. (#11406)
1 parent e4fb3a3 commit 6a2678a

File tree

2 files changed (+27, -12 lines):

  comfy/sd.py
  nodes.py


comfy/sd.py

Lines changed: 25 additions & 10 deletions
@@ -321,6 +321,7 @@ def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None)
         self.latent_channels = 4
         self.latent_dim = 2
         self.output_channels = 3
+        self.pad_channel_value = None
         self.process_input = lambda image: image * 2.0 - 1.0
         self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
         self.working_dtypes = [torch.bfloat16, torch.float32]
@@ -435,6 +436,7 @@ def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None)
         self.memory_used_decode = lambda shape, dtype: (1000 * shape[2] * 2048) * model_management.dtype_size(dtype)
         self.latent_channels = 64
         self.output_channels = 2
+        self.pad_channel_value = "replicate"
         self.upscale_ratio = 2048
         self.downscale_ratio = 2048
         self.latent_dim = 1
@@ -547,6 +549,7 @@ def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None)
         self.latent_dim = 3
         self.latent_channels = 16
         self.output_channels = sd["encoder.conv1.weight"].shape[1]
+        self.pad_channel_value = 1.0
         ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "image_channels": self.output_channels, "dropout": 0.0}
         self.first_stage_model = comfy.ldm.wan.vae.WanVAE(**ddconfig)
         self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
@@ -583,6 +586,7 @@ def estimate_memory(shape, dtype, num_layers = 16, kv_cache_multiplier = 2):
         self.memory_used_decode = lambda shape, dtype: (shape[2] * shape[3] * 87000) * model_management.dtype_size(dtype)
         self.latent_channels = 8
         self.output_channels = 2
+        self.pad_channel_value = "replicate"
         self.upscale_ratio = 4096
         self.downscale_ratio = 4096
         self.latent_dim = 2
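Note on the two kinds of policy above: torch.nn.functional.pad is rank-restricted for non-constant modes, so padding only the last dimension with mode="replicate" needs a 2D or 3D input tensor, while "constant" padding works for tensors of any rank. A quick illustration of that behavior (the shapes are made up for the demo, not taken from this commit):

import torch

# Non-constant modes dispatch on pad length: a (left, right) pad with
# mode="replicate" is 1D padding and requires a 2D or 3D input tensor.
x3d = torch.rand(1, 1024, 1)                  # e.g. (batch, samples, channels)
y = torch.nn.functional.pad(x3d, (0, 1), mode="replicate")
print(y.shape)                                # torch.Size([1, 1024, 2])

# Constant padding has no rank restriction, so it also covers
# 5D video-style tensors.
x5d = torch.rand(1, 4, 32, 32, 3)             # e.g. (batch, frames, H, W, C)
z = torch.nn.functional.pad(x5d, (0, 1), mode="constant", value=1.0)
print(z.shape)                                # torch.Size([1, 4, 32, 32, 4])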
@@ -691,17 +695,28 @@ def throw_exception_if_invalid(self):
         raise RuntimeError("ERROR: VAE is invalid: None\n\nIf the VAE is from a checkpoint loader node your checkpoint does not contain a valid VAE.")
 
     def vae_encode_crop_pixels(self, pixels):
-        if not self.crop_input:
-            return pixels
-
-        downscale_ratio = self.spacial_compression_encode()
+        if self.crop_input:
+            downscale_ratio = self.spacial_compression_encode()
+
+            dims = pixels.shape[1:-1]
+            for d in range(len(dims)):
+                x = (dims[d] // downscale_ratio) * downscale_ratio
+                x_offset = (dims[d] % downscale_ratio) // 2
+                if x != dims[d]:
+                    pixels = pixels.narrow(d + 1, x_offset, x)
+
+        if pixels.shape[-1] > self.output_channels:
+            pixels = pixels[..., :self.output_channels]
+        elif pixels.shape[-1] < self.output_channels:
+            if self.pad_channel_value is not None:
+                if isinstance(self.pad_channel_value, str):
+                    mode = self.pad_channel_value
+                    value = None
+                else:
+                    mode = "constant"
+                    value = self.pad_channel_value
 
-        dims = pixels.shape[1:-1]
-        for d in range(len(dims)):
-            x = (dims[d] // downscale_ratio) * downscale_ratio
-            x_offset = (dims[d] % downscale_ratio) // 2
-            if x != dims[d]:
-                pixels = pixels.narrow(d + 1, x_offset, x)
+                pixels = torch.nn.functional.pad(pixels, (0, self.output_channels - pixels.shape[-1]), mode=mode, value=value)
         return pixels
 
     def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):
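The new channel logic in vae_encode_crop_pixels is self-contained enough to try in isolation. Below is a minimal standalone sketch mirroring the hunk above; the trim_or_pad_channels name and the demo shapes are illustrative, not part of the commit (ComfyUI IMAGE tensors are channels-last):

import torch

def trim_or_pad_channels(pixels, output_channels, pad_channel_value):
    # Mirrors the channel handling added above: trim extra channels,
    # or pad missing ones according to the model's pad_channel_value.
    if pixels.shape[-1] > output_channels:
        pixels = pixels[..., :output_channels]
    elif pixels.shape[-1] < output_channels and pad_channel_value is not None:
        if isinstance(pad_channel_value, str):
            mode, value = pad_channel_value, None        # e.g. "replicate"
        else:
            mode, value = "constant", pad_channel_value  # e.g. 1.0
        pixels = torch.nn.functional.pad(
            pixels, (0, output_channels - pixels.shape[-1]), mode=mode, value=value)
    return pixels

# RGBA input into a 3-channel VAE: the extra channel is trimmed.
rgba = torch.rand(1, 64, 64, 4)
print(trim_or_pad_channels(rgba, 3, None).shape)         # [1, 64, 64, 3]

# RGB input into a VAE expecting 4 channels with pad_channel_value = 1.0:
# the missing channel is filled with ones (constant mode).
rgb = torch.rand(1, 64, 64, 3)
print(trim_or_pad_channels(rgb, 4, 1.0).shape)           # [1, 64, 64, 4]

# 1-channel input (batch, samples, channels) into a 2-channel VAE with
# pad_channel_value = "replicate": the existing channel is duplicated.
mono = torch.rand(1, 1024, 1)
print(trim_or_pad_channels(mono, 2, "replicate").shape)  # [1, 1024, 2]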

nodes.py

Lines changed: 2 additions & 2 deletions
@@ -343,7 +343,7 @@ def INPUT_TYPES(s):
     CATEGORY = "latent"
 
     def encode(self, vae, pixels):
-        t = vae.encode(pixels[:,:,:,:3])
+        t = vae.encode(pixels)
         return ({"samples":t}, )
 
 class VAEEncodeTiled:
@@ -361,7 +361,7 @@ def INPUT_TYPES(s):
     CATEGORY = "_for_testing"
 
     def encode(self, vae, pixels, tile_size, overlap, temporal_size=64, temporal_overlap=8):
-        t = vae.encode_tiled(pixels[:,:,:,:3], tile_x=tile_size, tile_y=tile_size, overlap=overlap, tile_t=temporal_size, overlap_t=temporal_overlap)
+        t = vae.encode_tiled(pixels, tile_x=tile_size, tile_y=tile_size, overlap=overlap, tile_t=temporal_size, overlap_t=temporal_overlap)
         return ({"samples": t}, )
 
 class VAEEncodeForInpaint:
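With the trim/pad handled inside the VAE object, these nodes no longer force every input down to 3 channels before encoding. A small illustration of what the old hard-coded slice did (shape is an assumption for the demo):

import torch

rgba = torch.rand(1, 64, 64, 4)   # ComfyUI IMAGE tensor: (batch, H, W, C)

# Old behavior: the node itself stripped everything past channel 3,
# so a VAE whose output_channels is 4 could never see the 4th channel.
print(rgba[:, :, :, :3].shape)    # torch.Size([1, 64, 64, 3])

# New behavior: the full tensor reaches vae.encode(), and
# vae_encode_crop_pixels trims or pads to vae.output_channels per model.

Moving the decision into the VAE lets each model declare its own channel policy (None, "replicate", or a constant fill) instead of the node guessing.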
