Commit bef3356
Merge branch 'master' into v3-improvements
2 parents: e8d4074 + 6a2678a
4 files changed (+37, -18 lines)

comfy/ldm/wan/vae.py (7 additions, 4 deletions)

@@ -227,6 +227,7 @@ class Encoder3d(nn.Module):
     def __init__(self,
                  dim=128,
                  z_dim=4,
+                 input_channels=3,
                  dim_mult=[1, 2, 4, 4],
                  num_res_blocks=2,
                  attn_scales=[],
@@ -245,7 +246,7 @@ def __init__(self,
         scale = 1.0
 
         # init block
-        self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)
+        self.conv1 = CausalConv3d(input_channels, dims[0], 3, padding=1)
 
         # downsample blocks
         downsamples = []
@@ -331,6 +332,7 @@ class Decoder3d(nn.Module):
     def __init__(self,
                  dim=128,
                  z_dim=4,
+                 output_channels=3,
                  dim_mult=[1, 2, 4, 4],
                  num_res_blocks=2,
                  attn_scales=[],
@@ -378,7 +380,7 @@ def __init__(self,
         # output blocks
         self.head = nn.Sequential(
             RMS_norm(out_dim, images=False), nn.SiLU(),
-            CausalConv3d(out_dim, 3, 3, padding=1))
+            CausalConv3d(out_dim, output_channels, 3, padding=1))
 
     def forward(self, x, feat_cache=None, feat_idx=[0]):
         ## conv1
@@ -449,6 +451,7 @@ def __init__(self,
                  num_res_blocks=2,
                  attn_scales=[],
                  temperal_downsample=[True, True, False],
+                 image_channels=3,
                  dropout=0.0):
         super().__init__()
         self.dim = dim
@@ -460,11 +463,11 @@ def __init__(self,
         self.temperal_upsample = temperal_downsample[::-1]
 
         # modules
-        self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks,
+        self.encoder = Encoder3d(dim, z_dim * 2, image_channels, dim_mult, num_res_blocks,
                                  attn_scales, self.temperal_downsample, dropout)
         self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
         self.conv2 = CausalConv3d(z_dim, z_dim, 1)
-        self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
+        self.decoder = Decoder3d(dim, z_dim, image_channels, dim_mult, num_res_blocks,
                                  attn_scales, self.temperal_upsample, dropout)
 
     def encode(self, x):
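This change threads a single channel count from WanVAE down to the encoder's first convolution and the decoder's output head, with 3 as the backward-compatible default, so a checkpoint with an extra alpha channel can be loaded without touching any other layer. A minimal standalone sketch of the pattern; TinyVAE and its halves are illustrative stand-ins, not the real WAN modules:

import torch
import torch.nn as nn

class TinyEncoder(nn.Module):
    def __init__(self, dim, z_dim, input_channels=3):
        super().__init__()
        # The first conv reads input_channels instead of a hardcoded 3.
        self.conv1 = nn.Conv3d(input_channels, dim, 3, padding=1)
        self.to_z = nn.Conv3d(dim, z_dim, 1)

    def forward(self, x):
        return self.to_z(self.conv1(x))

class TinyDecoder(nn.Module):
    def __init__(self, dim, z_dim, output_channels=3):
        super().__init__()
        self.from_z = nn.Conv3d(z_dim, dim, 1)
        # The head conv writes output_channels instead of a hardcoded 3.
        self.head = nn.Conv3d(dim, output_channels, 3, padding=1)

    def forward(self, z):
        return self.head(self.from_z(z))

class TinyVAE(nn.Module):
    def __init__(self, dim=32, z_dim=4, image_channels=3):
        super().__init__()
        # image_channels is forwarded to both halves, mirroring the diff.
        self.encoder = TinyEncoder(dim, z_dim * 2, image_channels)
        self.decoder = TinyDecoder(dim, z_dim, image_channels)

vae = TinyVAE(image_channels=4)    # e.g. RGB + alpha
x = torch.randn(1, 4, 5, 16, 16)   # (B, C, T, H, W)
z = vae.encoder(x)                 # (1, 8, 5, 16, 16): mean + logvar halves
recon = vae.decoder(z[:, :4])      # decode the mean half
assert recon.shape == x.shape

Because image_channels defaults to 3 at every level, existing WAN checkpoints load exactly as before.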

comfy/sd.py (27 additions, 11 deletions)

@@ -321,6 +321,7 @@ def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None)
         self.latent_channels = 4
         self.latent_dim = 2
         self.output_channels = 3
+        self.pad_channel_value = None
         self.process_input = lambda image: image * 2.0 - 1.0
         self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
         self.working_dtypes = [torch.bfloat16, torch.float32]
@@ -435,6 +436,7 @@ def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None)
         self.memory_used_decode = lambda shape, dtype: (1000 * shape[2] * 2048) * model_management.dtype_size(dtype)
         self.latent_channels = 64
         self.output_channels = 2
+        self.pad_channel_value = "replicate"
         self.upscale_ratio = 2048
         self.downscale_ratio = 2048
         self.latent_dim = 1
@@ -546,7 +548,9 @@ def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None)
         self.downscale_index_formula = (4, 8, 8)
         self.latent_dim = 3
         self.latent_channels = 16
-        ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "dropout": 0.0}
+        self.output_channels = sd["encoder.conv1.weight"].shape[1]
+        self.pad_channel_value = 1.0
+        ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "image_channels": self.output_channels, "dropout": 0.0}
         self.first_stage_model = comfy.ldm.wan.vae.WanVAE(**ddconfig)
         self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
         self.memory_used_encode = lambda shape, dtype: (1500 if shape[2]<=4 else 6000) * shape[3] * shape[4] * model_management.dtype_size(dtype)
@@ -582,6 +586,7 @@ def estimate_memory(shape, dtype, num_layers = 16, kv_cache_multiplier = 2):
         self.memory_used_decode = lambda shape, dtype: (shape[2] * shape[3] * 87000) * model_management.dtype_size(dtype)
         self.latent_channels = 8
         self.output_channels = 2
+        self.pad_channel_value = "replicate"
         self.upscale_ratio = 4096
         self.downscale_ratio = 4096
         self.latent_dim = 2
@@ -690,17 +695,28 @@ def throw_exception_if_invalid(self):
         raise RuntimeError("ERROR: VAE is invalid: None\n\nIf the VAE is from a checkpoint loader node your checkpoint does not contain a valid VAE.")
 
     def vae_encode_crop_pixels(self, pixels):
-        if not self.crop_input:
-            return pixels
-
-        downscale_ratio = self.spacial_compression_encode()
+        if self.crop_input:
+            downscale_ratio = self.spacial_compression_encode()
+
+            dims = pixels.shape[1:-1]
+            for d in range(len(dims)):
+                x = (dims[d] // downscale_ratio) * downscale_ratio
+                x_offset = (dims[d] % downscale_ratio) // 2
+                if x != dims[d]:
+                    pixels = pixels.narrow(d + 1, x_offset, x)
+
+        if pixels.shape[-1] > self.output_channels:
+            pixels = pixels[..., :self.output_channels]
+        elif pixels.shape[-1] < self.output_channels:
+            if self.pad_channel_value is not None:
+                if isinstance(self.pad_channel_value, str):
+                    mode = self.pad_channel_value
+                    value = None
+                else:
+                    mode = "constant"
+                    value = self.pad_channel_value
 
-        dims = pixels.shape[1:-1]
-        for d in range(len(dims)):
-            x = (dims[d] // downscale_ratio) * downscale_ratio
-            x_offset = (dims[d] % downscale_ratio) // 2
-            if x != dims[d]:
-                pixels = pixels.narrow(d + 1, x_offset, x)
+                pixels = torch.nn.functional.pad(pixels, (0, self.output_channels - pixels.shape[-1]), mode=mode, value=value)
         return pixels
 
     def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):
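Each VAE in sd.py now also declares pad_channel_value, which says how an encode-time channel mismatch is resolved: surplus input channels are truncated, and missing channels are padded either with a constant (1.0, i.e. fully opaque, for the alpha WAN VAE, whose output_channels is read from the in_channels of encoder.conv1.weight) or with a torch padding mode string such as "replicate" (the audio VAEs). A standalone sketch of that logic, assuming only torch; match_channels is an illustrative name, not ComfyUI's API:

import torch
import torch.nn.functional as F

def match_channels(pixels, output_channels, pad_channel_value):
    # pixels is channels-last, e.g. (B, H, W, C) images.
    if pixels.shape[-1] > output_channels:
        # Surplus channels (e.g. alpha into an RGB VAE) are dropped.
        pixels = pixels[..., :output_channels]
    elif pixels.shape[-1] < output_channels and pad_channel_value is not None:
        if isinstance(pad_channel_value, str):
            mode, value = pad_channel_value, None        # e.g. "replicate"
        else:
            mode, value = "constant", pad_channel_value  # e.g. 1.0 for alpha
        # Pad tuple (0, n) extends only the last (channel) dimension.
        pixels = F.pad(pixels, (0, output_channels - pixels.shape[-1]),
                       mode=mode, value=value)
    return pixels

rgb = torch.rand(1, 64, 64, 3)
rgba = match_channels(rgb, 4, 1.0)  # pad a fully opaque alpha channel
assert rgba.shape[-1] == 4 and torch.all(rgba[..., 3] == 1.0)

The constant 1.0 means an RGB image fed to the 4-channel model is treated as fully opaque, so plain RGB workflows keep their old behavior.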

nodes.py (2 additions, 2 deletions)

@@ -343,7 +343,7 @@ def INPUT_TYPES(s):
     CATEGORY = "latent"
 
     def encode(self, vae, pixels):
-        t = vae.encode(pixels[:,:,:,:3])
+        t = vae.encode(pixels)
         return ({"samples":t}, )
 
 class VAEEncodeTiled:
@@ -361,7 +361,7 @@ def INPUT_TYPES(s):
     CATEGORY = "_for_testing"
 
     def encode(self, vae, pixels, tile_size, overlap, temporal_size=64, temporal_overlap=8):
-        t = vae.encode_tiled(pixels[:,:,:,:3], tile_x=tile_size, tile_y=tile_size, overlap=overlap, tile_t=temporal_size, overlap_t=temporal_overlap)
+        t = vae.encode_tiled(pixels, tile_x=tile_size, tile_y=tile_size, overlap=overlap, tile_t=temporal_size, overlap_t=temporal_overlap)
         return ({"samples": t}, )
 
 class VAEEncodeForInpaint:
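With channel handling moved into VAE.vae_encode_crop_pixels, the nodes can pass pixels straight through instead of hard-slicing to RGB. A hypothetical stub, not ComfyUI's VAE class, illustrates why the same node code now serves both RGB-only and alpha-capable VAEs:

import torch

class StubVAE:
    def __init__(self, output_channels):
        self.output_channels = output_channels

    def encode(self, pixels):
        # Stand-in for VAE.encode: truncate extra channels the way
        # vae_encode_crop_pixels now does, then pretend to encode.
        pixels = pixels[..., :self.output_channels]
        return pixels.movedim(-1, 1)  # (B, H, W, C) -> (B, C, H, W)

rgba = torch.rand(1, 64, 64, 4)
assert StubVAE(3).encode(rgba).shape[1] == 3  # RGB VAE drops alpha itself
assert StubVAE(4).encode(rgba).shape[1] == 4  # alpha-capable VAE keeps it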

requirements.txt (1 addition, 1 deletion)

@@ -1,5 +1,5 @@
 comfyui-frontend-package==1.34.9
-comfyui-workflow-templates==0.7.59
+comfyui-workflow-templates==0.7.60
 comfyui-embedded-docs==0.3.1
 torch
 torchsde
