@@ -282,6 +282,7 @@ def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None)
282282
283283 self .downscale_index_formula = None
284284 self .upscale_index_formula = None
285+ self .extra_1d_channel = None
285286
286287 if config is None :
287288 if "decoder.mid.block_1.mix_factor" in sd :
@@ -445,13 +446,14 @@ def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None)
445446 self .memory_used_decode = lambda shape , dtype : (shape [2 ] * shape [3 ] * 87000 ) * model_management .dtype_size (dtype )
446447 self .latent_channels = 8
447448 self .output_channels = 2
448- # self.upscale_ratio = 2048
449- # self.downscale_ratio = 2048
449+ self .upscale_ratio = 4096
450+ self .downscale_ratio = 4096
450451 self .latent_dim = 2
451452 self .process_output = lambda audio : audio
452453 self .process_input = lambda audio : audio
453454 self .working_dtypes = [torch .bfloat16 , torch .float32 ]
454455 self .disable_offload = True
456+ self .extra_1d_channel = 16
455457 else :
456458 logging .warning ("WARNING: No VAE weights detected, VAE not initalized." )
457459 self .first_stage_model = None
@@ -510,7 +512,13 @@ def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):
510512 return output
511513
def decode_tiled_1d(self, samples, tile_x=128, overlap=32):
    """Decode ``samples`` in tiles through ``comfy.utils.tiled_scale_multidim``.

    When ``samples`` has three dimensions it is handed to the decoder as-is.
    Otherwise dimensions 1 and 2 are merged into one before tiling, and each
    tile is split back into those two dimensions right before it is decoded.

    Args:
        samples: latent tensor to decode.
            NOTE(review): the non-3D branch indexes ``shape[0..2]`` and
            reshapes to 4 dims, so it presumably expects a 4D input -- confirm.
        tile_x: tile length passed as the single entry of ``tile``.
        overlap: overlap between neighbouring tiles.

    Returns:
        ``self.process_output`` applied to the stitched decoder output.
    """
    def run_decoder(latent):
        # Cast to the VAE dtype, move to the VAE device, decode, return float32.
        moved = latent.to(self.vae_dtype).to(self.device)
        return self.first_stage_model.decode(moved).float()

    if samples.ndim == 3:
        decode_fn = run_decoder
    else:
        full_shape = samples.shape
        merged = full_shape[1] * full_shape[2]
        samples = samples.reshape((full_shape[0], merged, -1))

        def decode_fn(tile):
            # Undo the merge of dims 1 and 2 for this tile before decoding.
            unmerged = tile.reshape((-1, full_shape[1], full_shape[2], tile.shape[-1]))
            return run_decoder(unmerged)

    return self.process_output(
        comfy.utils.tiled_scale_multidim(
            samples,
            decode_fn,
            tile=(tile_x,),
            overlap=overlap,
            upscale_amount=self.upscale_ratio,
            out_channels=self.output_channels,
            output_device=self.output_device,
        )
    )
515523
516524 def decode_tiled_3d (self , samples , tile_t = 999 , tile_x = 32 , tile_y = 32 , overlap = (1 , 8 , 8 )):
@@ -530,9 +538,24 @@ def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
530538 samples /= 3.0
531539 return samples
532540
def encode_tiled_1d(self, samples, tile_x=256 * 2048, overlap=64 * 2048):
    """Encode ``samples`` in tiles through ``comfy.utils.tiled_scale_multidim``.

    With ``self.latent_dim == 1`` the encoder output of each tile is used
    directly. Otherwise each tile's encoder output is reshaped to
    ``(1, latent_channels * extra_1d_channel, -1)`` for stitching, the tile
    and overlap sizes are floor-divided by ``self.extra_1d_channel``, and the
    stitched result is reshaped to
    ``(samples.shape[0], latent_channels, extra_1d_channel, -1)``.

    Args:
        samples: input tensor to encode (passed through ``self.process_input``).
        tile_x: tile length passed as the single entry of ``tile``.
        overlap: overlap between neighbouring tiles.

    Returns:
        The stitched latent tensor, reshaped as described above when
        ``self.latent_dim`` is not 1.
    """
    is_plain_1d = self.latent_dim == 1

    def run_encoder(chunk):
        # Preprocess, cast to the VAE dtype, move to the VAE device, encode.
        prepared = self.process_input(chunk).to(self.vae_dtype).to(self.device)
        return self.first_stage_model.encode(prepared)

    if is_plain_1d:
        out_channels = self.latent_channels

        def encode_fn(chunk):
            return run_encoder(chunk).float()
    else:
        extra_channel_size = self.extra_1d_channel
        out_channels = self.latent_channels * extra_channel_size
        tile_x = tile_x // extra_channel_size
        overlap = overlap // extra_channel_size

        def encode_fn(chunk):
            # Fold the extra channel axis into the channel axis for stitching.
            return run_encoder(chunk).reshape(1, out_channels, -1).float()

    upscale_amount = 1 / self.downscale_ratio
    out = comfy.utils.tiled_scale_multidim(
        samples,
        encode_fn,
        tile=(tile_x,),
        overlap=overlap,
        upscale_amount=upscale_amount,
        out_channels=out_channels,
        output_device=self.output_device,
    )
    if is_plain_1d:
        return out
    # Split the folded channel axis back into (latent_channels, extra_channel_size).
    return out.reshape(samples.shape[0], self.latent_channels, extra_channel_size, -1)
536559
537560 def encode_tiled_3d (self , samples , tile_t = 9999 , tile_x = 512 , tile_y = 512 , overlap = (1 , 64 , 64 )):
538561 encode_fn = lambda a : self .first_stage_model .encode ((self .process_input (a )).to (self .vae_dtype ).to (self .device )).float ()
@@ -557,7 +580,7 @@ def decode(self, samples_in, vae_options={}):
557580 except model_management .OOM_EXCEPTION :
558581 logging .warning ("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding." )
559582 dims = samples_in .ndim - 2
560- if dims == 1 :
583+ if dims == 1 or self . extra_1d_channel is not None :
561584 pixel_samples = self .decode_tiled_1d (samples_in )
562585 elif dims == 2 :
563586 pixel_samples = self .decode_tiled_ (samples_in )
@@ -624,7 +647,7 @@ def encode(self, pixel_samples):
624647 tile = 256
625648 overlap = tile // 4
626649 samples = self .encode_tiled_3d (pixel_samples , tile_x = tile , tile_y = tile , overlap = (1 , overlap , overlap ))
627- elif self .latent_dim == 1 :
650+ elif self .latent_dim == 1 or self . extra_1d_channel is not None :
628651 samples = self .encode_tiled_1d (pixel_samples )
629652 else :
630653 samples = self .encode_tiled_ (pixel_samples )
0 commit comments