Commit a692c3c

Make ACE VAE tiling work. (#8004)
1 parent 5d3cc85 commit a692c3c

1 file changed: +31 -8 lines changed


comfy/sd.py

Lines changed: 31 additions & 8 deletions
@@ -282,6 +282,7 @@ def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None)
 
         self.downscale_index_formula = None
         self.upscale_index_formula = None
+        self.extra_1d_channel = None
 
         if config is None:
             if "decoder.mid.block_1.mix_factor" in sd:
@@ -445,13 +446,14 @@ def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None)
                 self.memory_used_decode = lambda shape, dtype: (shape[2] * shape[3] * 87000) * model_management.dtype_size(dtype)
                 self.latent_channels = 8
                 self.output_channels = 2
-                # self.upscale_ratio = 2048
-                # self.downscale_ratio = 2048
+                self.upscale_ratio = 4096
+                self.downscale_ratio = 4096
                 self.latent_dim = 2
                 self.process_output = lambda audio: audio
                 self.process_input = lambda audio: audio
                 self.working_dtypes = [torch.bfloat16, torch.float32]
                 self.disable_offload = True
+                self.extra_1d_channel = 16
             else:
                 logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
                 self.first_stage_model = None
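
The new extra_1d_channel attribute records that the ACE Step audio VAE's latent carries an extra channel-like axis of size 16 on top of its 8 latent channels, so the latent is 4D even though the underlying data is 1D audio. A rough sketch of the shapes this implies, inferred from the reshape calls later in this diff rather than from the ACE model code itself, with the sample count made up purely for illustration:

import torch

# Rough sketch of the latent layout implied by this diff (assumed shapes).
batch, downscale_ratio = 1, 4096
latent_channels, extra_1d_channel = 8, 16

num_samples = 4096 * 256            # arbitrary audio length for illustration
latent_frames = num_samples // downscale_ratio
latent = torch.zeros(batch, latent_channels, extra_1d_channel, latent_frames)
print(latent.shape)                 # torch.Size([1, 8, 16, 256]) -- a 4D "1D" latent
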
@@ -510,7 +512,13 @@ def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):
         return output
 
     def decode_tiled_1d(self, samples, tile_x=128, overlap=32):
-        decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
+        if samples.ndim == 3:
+            decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
+        else:
+            og_shape = samples.shape
+            samples = samples.reshape((og_shape[0], og_shape[1] * og_shape[2], -1))
+            decode_fn = lambda a: self.first_stage_model.decode(a.reshape((-1, og_shape[1], og_shape[2], a.shape[-1])).to(self.vae_dtype).to(self.device)).float()
+
         return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device))
 
     def decode_tiled_3d(self, samples, tile_t=999, tile_x=32, tile_y=32, overlap=(1, 8, 8)):
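
In decode_tiled_1d, a 4D ACE-style latent is folded into 3D so the existing 1D tiling can run along the time axis, and each tile is unfolded back to 4D inside the decode callback before it reaches first_stage_model.decode. A minimal sketch of that fold/unfold round trip on dummy tensors, with shapes assumed as in the note above and no model involved:

import torch

# Dummy 4D ACE-style latent: (batch, latent_channels, extra_1d_channel, frames).
og_shape = (1, 8, 16, 256)
samples = torch.randn(og_shape)

# Fold the extra axis into the channel dim so the 1D tiler only sees
# (batch, channels, time), as the diff does before calling tiled_scale_multidim.
folded = samples.reshape((og_shape[0], og_shape[1] * og_shape[2], -1))
print(folded.shape)  # torch.Size([1, 128, 256])

# Inside the per-tile callback the diff unfolds each tile back to 4D before
# handing it to the model; the last dim is whatever length the tile has.
tile = folded[:, :, :64]
unfolded = tile.reshape((-1, og_shape[1], og_shape[2], tile.shape[-1]))
print(unfolded.shape)  # torch.Size([1, 8, 16, 64])
assert torch.equal(unfolded, samples[:, :, :, :64])  # the round trip is lossless
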
@@ -530,9 +538,24 @@ def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
         samples /= 3.0
         return samples
 
-    def encode_tiled_1d(self, samples, tile_x=128 * 2048, overlap=32 * 2048):
-        encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
-        return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=(1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device)
+    def encode_tiled_1d(self, samples, tile_x=256 * 2048, overlap=64 * 2048):
+        if self.latent_dim == 1:
+            encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
+            out_channels = self.latent_channels
+            upscale_amount = 1 / self.downscale_ratio
+        else:
+            extra_channel_size = self.extra_1d_channel
+            out_channels = self.latent_channels * extra_channel_size
+            tile_x = tile_x // extra_channel_size
+            overlap = overlap // extra_channel_size
+            upscale_amount = 1 / self.downscale_ratio
+            encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).reshape(1, out_channels, -1).float()
+
+        out = comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=self.output_device)
+        if self.latent_dim == 1:
+            return out
+        else:
+            return out.reshape(samples.shape[0], self.latent_channels, extra_channel_size, -1)
 
     def encode_tiled_3d(self, samples, tile_t=9999, tile_x=512, tile_y=512, overlap=(1, 64, 64)):
         encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
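
On the encode side, tiled_scale_multidim is given a flat channel count, so encode_fn flattens the model's 4D latent to (1, latent_channels * extra_1d_channel, -1); tile_x and overlap, which are expressed in raw audio samples, are shrunk by the same factor, presumably to keep per-tile memory comparable, and the stitched result is reshaped back to the 4D layout at the end. A small sketch of that final un-flattening step, shapes assumed and no model involved:

import torch

latent_channels, extra_channel_size = 8, 16
out_channels = latent_channels * extra_channel_size   # 128 flattened channels

# Pretend this is what tiled_scale_multidim stitched together from the
# flattened per-tile encodes: (batch, flattened channels, latent frames).
batch, frames = 1, 256
out = torch.randn(batch, out_channels, frames)

# The diff's final step: recover the 4D ACE latent layout from the flat output.
latent = out.reshape(batch, latent_channels, extra_channel_size, -1)
print(latent.shape)  # torch.Size([1, 8, 16, 256])
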
@@ -557,7 +580,7 @@ def decode(self, samples_in, vae_options={}):
         except model_management.OOM_EXCEPTION:
             logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
             dims = samples_in.ndim - 2
-            if dims == 1:
+            if dims == 1 or self.extra_1d_channel is not None:
                 pixel_samples = self.decode_tiled_1d(samples_in)
             elif dims == 2:
                 pixel_samples = self.decode_tiled_(samples_in)
@@ -624,7 +647,7 @@ def encode(self, pixel_samples):
                 tile = 256
                 overlap = tile // 4
                 samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
-            elif self.latent_dim == 1:
+            elif self.latent_dim == 1 or self.extra_1d_channel is not None:
                 samples = self.encode_tiled_1d(pixel_samples)
             else:
                 samples = self.encode_tiled_(pixel_samples)
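
Finally, the dispatch changes in decode() and encode() send any VAE with extra_1d_channel set down the 1D tiled paths, even though its latent is 4D. A minimal sketch of the decode-side rule; the helper name here is hypothetical, the real check lives inline in VAE.decode:

def pick_tiled_decoder(latent_ndim, extra_1d_channel):
    # dims counts the spatial/temporal axes, excluding batch and channels.
    dims = latent_ndim - 2
    if dims == 1 or extra_1d_channel is not None:
        return "decode_tiled_1d"
    elif dims == 2:
        return "decode_tiled_"
    return "decode_tiled_3d"

assert pick_tiled_decoder(4, 16) == "decode_tiled_1d"   # ACE Step audio latent (B, 8, 16, T)
assert pick_tiled_decoder(4, None) == "decode_tiled_"   # image latent (B, C, H, W): dims == 2
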

0 commit comments
