@@ -321,6 +321,7 @@ def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None)
         self.latent_channels = 4
         self.latent_dim = 2
         self.output_channels = 3
+        self.pad_channel_value = None
         self.process_input = lambda image: image * 2.0 - 1.0
         self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
         self.working_dtypes = [torch.bfloat16, torch.float32]
@@ -435,6 +436,7 @@ def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None)
                 self.memory_used_decode = lambda shape, dtype: (1000 * shape[2] * 2048) * model_management.dtype_size(dtype)
                 self.latent_channels = 64
                 self.output_channels = 2
+                self.pad_channel_value = "replicate"
                 self.upscale_ratio = 2048
                 self.downscale_ratio = 2048
                 self.latent_dim = 1
@@ -547,6 +549,7 @@ def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None)
                 self.latent_dim = 3
                 self.latent_channels = 16
                 self.output_channels = sd["encoder.conv1.weight"].shape[1]
+                self.pad_channel_value = 1.0
                 ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "image_channels": self.output_channels, "dropout": 0.0}
                 self.first_stage_model = comfy.ldm.wan.vae.WanVAE(**ddconfig)
                 self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
@@ -583,6 +586,7 @@ def estimate_memory(shape, dtype, num_layers = 16, kv_cache_multiplier = 2):
                 self.memory_used_decode = lambda shape, dtype: (shape[2] * shape[3] * 87000) * model_management.dtype_size(dtype)
                 self.latent_channels = 8
                 self.output_channels = 2
+                self.pad_channel_value = "replicate"
                 self.upscale_ratio = 4096
                 self.downscale_ratio = 4096
                 self.latent_dim = 2
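Taken together, these hunks give every VAE variant a `pad_channel_value` consumed by the rewritten `vae_encode_crop_pixels` in the next hunk when the input carries fewer channels than `output_channels`: `None` keeps the old behavior (no padding), `"replicate"` duplicates the existing channel (e.g. mono audio fed to a stereo audio VAE), and a number such as `1.0` becomes a constant fill. A minimal standalone sketch of the two padding modes, assuming channels-last tensors and illustrative channel counts (this is not ComfyUI's public API):

```python
import torch
import torch.nn.functional as F

# Mono waveform [B, T, 1] into a 2-channel audio VAE
# (pad_channel_value="replicate"): the single channel is duplicated.
mono = torch.randn(1, 48000, 1)
stereo = F.pad(mono, (0, 2 - mono.shape[-1]), mode="replicate")
assert torch.equal(stereo[..., 0], stereo[..., 1])

# RGB image [B, H, W, 3] into a hypothetical 4-channel VAE
# (pad_channel_value=1.0): the missing channel is filled with 1.0.
rgb = torch.rand(1, 64, 64, 3)
rgba = F.pad(rgb, (0, 4 - rgb.shape[-1]), mode="constant", value=1.0)
assert torch.all(rgba[..., 3] == 1.0)
```

Note that `mode="replicate"` only supports padding the trailing dimensions `F.pad` allows for the input's rank (the 3D audio tensor above qualifies), while `mode="constant"` works for tensors of any rank.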
@@ -691,17 +695,28 @@ def throw_exception_if_invalid(self):
         raise RuntimeError("ERROR: VAE is invalid: None\n\nIf the VAE is from a checkpoint loader node your checkpoint does not contain a valid VAE.")
 
     def vae_encode_crop_pixels(self, pixels):
-        if not self.crop_input:
-            return pixels
-
-        downscale_ratio = self.spacial_compression_encode()
+        if self.crop_input:
+            downscale_ratio = self.spacial_compression_encode()
+
+            dims = pixels.shape[1:-1]
+            for d in range(len(dims)):
+                x = (dims[d] // downscale_ratio) * downscale_ratio
+                x_offset = (dims[d] % downscale_ratio) // 2
+                if x != dims[d]:
+                    pixels = pixels.narrow(d + 1, x_offset, x)
+
+        if pixels.shape[-1] > self.output_channels:
+            pixels = pixels[..., :self.output_channels]
+        elif pixels.shape[-1] < self.output_channels:
+            if self.pad_channel_value is not None:
+                if isinstance(self.pad_channel_value, str):
+                    mode = self.pad_channel_value
+                    value = None
+                else:
+                    mode = "constant"
+                    value = self.pad_channel_value
 
-        dims = pixels.shape[1:-1]
-        for d in range(len(dims)):
-            x = (dims[d] // downscale_ratio) * downscale_ratio
-            x_offset = (dims[d] % downscale_ratio) // 2
-            if x != dims[d]:
-                pixels = pixels.narrow(d + 1, x_offset, x)
+                pixels = torch.nn.functional.pad(pixels, (0, self.output_channels - pixels.shape[-1]), mode=mode, value=value)
         return pixels
 
     def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap=16):
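The cropping half of the rewrite keeps its original effect: each spatial dimension is truncated to the nearest multiple of the encoder's spatial compression, with the remainder split across both edges so the crop stays centered. A standalone worked sketch, assuming a hypothetical downscale ratio of 8 and NHWC pixels:

```python
import torch

ratio = 8                               # stand-in for spacial_compression_encode()
pixels = torch.rand(1, 517, 512, 3)     # NHWC; height is not a multiple of 8
for d, size in enumerate(pixels.shape[1:-1]):
    cropped = (size // ratio) * ratio   # 517 -> 512
    offset = (size % ratio) // 2        # 5 spare rows: drop 2 on top, 3 at the bottom
    if cropped != size:
        pixels = pixels.narrow(d + 1, offset, cropped)
assert pixels.shape == (1, 512, 512, 3)
```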