Commit 2643c74: autoencoder_dc tiling

1 parent: 95c5ce4
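In short: this commit replaces the `tiled_encode`/`tiled_decode` `NotImplementedError` stubs in `AutoencoderDC` with working spatial-tiling implementations, adds `blend_v`/`blend_h` helpers that crossfade overlapping tile edges to avoid seams, roughly doubles the default tile size (512 to 1024 pixels) and stride (448 to 896), and derives the latent-space tile minima from the sample-space settings.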


src/diffusers/models/autoencoders/autoencoder_dc.py

Lines changed: 101 additions & 6 deletions
```diff
@@ -479,12 +479,15 @@ def __init__(
         self.use_tiling = False
 
         # The minimal tile height and width for spatial tiling to be used
-        self.tile_sample_min_height = 512
-        self.tile_sample_min_width = 512
+        self.tile_sample_min_height = 1024
+        self.tile_sample_min_width = 1024
 
         # The minimal distance between two spatial tiles
-        self.tile_sample_stride_height = 448
-        self.tile_sample_stride_width = 448
+        self.tile_sample_stride_height = 896
+        self.tile_sample_stride_width = 896
+
+        self.tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        self.tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
 
     def enable_tiling(
         self,
```
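With the new defaults, each spatial tile is at most 1024x1024 pixels with a stride of 896, so adjacent tiles overlap by 128 pixels. A quick sketch of the derived latent-space sizes, assuming a `spatial_compression_ratio` of 32 (as in the DC-AE f32 variants; other checkpoints differ):

```python
# Hypothetical numbers: spatial_compression_ratio depends on the checkpoint config.
spatial_compression_ratio = 32  # e.g. a DC-AE f32 model

tile_sample_min = 1024    # new default tile size (pixels)
tile_sample_stride = 896  # new default stride (pixels)

tile_latent_min = tile_sample_min // spatial_compression_ratio        # 32
tile_latent_stride = tile_sample_stride // spatial_compression_ratio  # 28
blend_extent = tile_latent_min - tile_latent_stride                   # 4 latent rows/cols

print(tile_latent_min, tile_latent_stride, blend_extent)  # 32 28 4
```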
```diff
@@ -515,6 +518,8 @@ def enable_tiling(
         self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
         self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
         self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
+        self.tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        self.tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
 
     def disable_tiling(self) -> None:
         r"""
```
```diff
@@ -606,11 +611,101 @@ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
             return (decoded,)
         return DecoderOutput(sample=decoded)
 
+    def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+        blend_extent = min(a.shape[2], b.shape[2], blend_extent)
+        for y in range(blend_extent):
+            b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
+        return b
+
+    def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+        blend_extent = min(a.shape[3], b.shape[3], blend_extent)
+        for x in range(blend_extent):
+            b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
+        return b
+
```
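The helpers do a linear crossfade over the overlap region: row `y` of the lower tile is weighted `y / blend_extent`, and the matching bottom row of the upper tile gets the complementary weight. A standalone sketch of `blend_v` on toy tensors, to make the ramp visible:

```python
import torch

def blend_v(a, b, blend_extent):
    # Crossfade the bottom `blend_extent` rows of `a` into the top rows of `b`.
    blend_extent = min(a.shape[2], b.shape[2], blend_extent)
    for y in range(blend_extent):
        b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
    return b

a = torch.zeros(1, 1, 4, 1)  # upper tile: all zeros
b = torch.ones(1, 1, 4, 1)   # lower tile: all ones
out = blend_v(a, b, blend_extent=4)
print(out.flatten().tolist())  # [0.0, 0.25, 0.5, 0.75] -> smooth 0-to-1 ramp
```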
```diff
     def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> torch.Tensor:
-        raise NotImplementedError("`tiled_encode` has not been implemented for AutoencoderDC.")
+        batch_size, num_channels, height, width = x.shape
+        latent_height = height // self.spatial_compression_ratio
+        latent_width = width // self.spatial_compression_ratio
+
+        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+        tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
+        tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+        blend_height = tile_latent_min_height - tile_latent_stride_height
+        blend_width = tile_latent_min_width - tile_latent_stride_width
+
+        # Split x into overlapping tiles and encode them separately.
+        # The tiles have an overlap to avoid seams between tiles.
+        rows = []
+        for i in range(0, x.shape[2], self.tile_sample_stride_height):
+            row = []
+            for j in range(0, x.shape[3], self.tile_sample_stride_width):
+                tile = x[:, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
+                if tile.shape[2] % self.spatial_compression_ratio != 0 or tile.shape[3] % self.spatial_compression_ratio != 0:
+                    tile = F.pad(tile, (0, (self.spatial_compression_ratio - tile.shape[3]) % self.spatial_compression_ratio, 0, (self.spatial_compression_ratio - tile.shape[2]) % self.spatial_compression_ratio))
+                tile = self.encoder(tile)
+                row.append(tile)
+            rows.append(row)
+
+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                # Blend the tile above and the tile to the left
+                # into the current tile, then crop it and add it to the result row.
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, blend_width)
+                result_row.append(tile[:, :, :tile_latent_stride_height, :tile_latent_stride_width])
+            result_rows.append(torch.cat(result_row, dim=3))
+
+        encoded = torch.cat(result_rows, dim=2)[:, :, :latent_height, :latent_width]
+
+        if not return_dict:
+            return (encoded,)
+        return EncoderOutput(latent=encoded)
```
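To see how the encode loop covers a large input: with the default 1024-pixel tiles and 896-pixel stride, a 2048x2048 image is split into a 3x3 grid of overlapping tiles, and edge tiles that are not a multiple of the compression ratio get padded by the `F.pad` branch above. A small sketch of the tile coordinates (pure arithmetic, no model needed):

```python
height = width = 2048
tile_min, tile_stride = 1024, 896

tiles = [
    (i, j, min(i + tile_min, height), min(j + tile_min, width))
    for i in range(0, height, tile_stride)
    for j in range(0, width, tile_stride)
]
print(len(tiles))  # 9 -> a 3x3 grid; tiles starting at 1792 are cropped at the border
print(tiles[0])    # (0, 0, 1024, 1024)
print(tiles[-1])   # (1792, 1792, 2048, 2048)
```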
```diff
 
     def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
-        raise NotImplementedError("`tiled_decode` has not been implemented for AutoencoderDC.")
+        batch_size, num_channels, height, width = z.shape
+
+        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+        tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
+        tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+
+        blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
+        blend_width = self.tile_sample_min_width - self.tile_sample_stride_width
+
+        # Split z into overlapping tiles and decode them separately.
+        # The tiles have an overlap to avoid seams between tiles.
+        rows = []
+        for i in range(0, height, tile_latent_stride_height):
+            row = []
+            for j in range(0, width, tile_latent_stride_width):
+                tile = z[:, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width]
+                decoded = self.decoder(tile)
+                row.append(decoded)
+            rows.append(row)
+
+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                # Blend the tile above and the tile to the left
+                # into the current tile, then crop it and add it to the result row.
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, blend_width)
+                result_row.append(tile[:, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
+            result_rows.append(torch.cat(result_row, dim=3))
+
+        decoded = torch.cat(result_rows, dim=2)
+
+        if not return_dict:
+            return (decoded,)
+        return DecoderOutput(sample=decoded)
 
     def forward(self, sample: torch.Tensor, return_dict: bool = True) -> torch.Tensor:
         encoded = self.encode(sample, return_dict=False)[0]
```
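Putting it together, a round-trip sketch with tiling enabled. This assumes `encode`/`decode` dispatch to the tiled paths once the input exceeds the tile minima, as the other diffusers autoencoders do (that dispatch is not shown in this hunk), and the checkpoint name is again illustrative:

```python
import torch
from diffusers import AutoencoderDC

# Illustrative checkpoint; any AutoencoderDC weights work the same way.
vae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers")
vae.enable_tiling()

image = torch.randn(1, 3, 2048, 2048)  # larger than the 1024px tile minimum
with torch.no_grad():
    latent = vae.encode(image).latent           # tiled_encode under the hood
    reconstruction = vae.decode(latent).sample  # tiled_decode under the hood
print(latent.shape, reconstruction.shape)
```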
