@@ -486,6 +486,9 @@ def __init__(
486486 self .tile_sample_stride_height = 448
487487 self .tile_sample_stride_width = 448
488488
489+ self .tile_latent_min_height = self .tile_sample_min_height // self .spatial_compression_ratio
490+ self .tile_latent_min_width = self .tile_sample_min_width // self .spatial_compression_ratio
491+
489492 def enable_tiling (
490493 self ,
491494 tile_sample_min_height : Optional [int ] = None ,
@@ -515,6 +518,8 @@ def enable_tiling(
515518 self .tile_sample_min_width = tile_sample_min_width or self .tile_sample_min_width
516519 self .tile_sample_stride_height = tile_sample_stride_height or self .tile_sample_stride_height
517520 self .tile_sample_stride_width = tile_sample_stride_width or self .tile_sample_stride_width
521+ self .tile_latent_min_height = self .tile_sample_min_height // self .spatial_compression_ratio
522+ self .tile_latent_min_width = self .tile_sample_min_width // self .spatial_compression_ratio
518523
519524 def disable_tiling (self ) -> None :
520525 r"""
@@ -606,11 +611,106 @@ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutp
606611 return (decoded ,)
607612 return DecoderOutput (sample = decoded )
608613
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
    """Cross-fade the bottom rows of ``a`` into the top rows of ``b`` (in place on ``b``).

    Row ``y`` of ``b`` becomes a linear mix that starts fully from ``a`` at the
    seam and fades to fully ``b`` over ``blend_extent`` rows. Returns ``b``.
    """
    # Never blend more rows than either tensor actually has.
    blend_extent = min(a.shape[2], b.shape[2], blend_extent)
    for row in range(blend_extent):
        weight = row / blend_extent
        b[:, :, row, :] = a[:, :, -blend_extent + row, :] * (1 - weight) + b[:, :, row, :] * weight
    return b
619+
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
    """Cross-fade the rightmost columns of ``a`` into the leftmost columns of ``b`` (in place on ``b``).

    Column ``x`` of ``b`` becomes a linear mix that starts fully from ``a`` at
    the seam and fades to fully ``b`` over ``blend_extent`` columns. Returns ``b``.
    """
    # Never blend more columns than either tensor actually has.
    blend_extent = min(a.shape[3], b.shape[3], blend_extent)
    for col in range(blend_extent):
        weight = col / blend_extent
        b[:, :, :, col] = a[:, :, :, -blend_extent + col] * (1 - weight) + b[:, :, :, col] * weight
    return b
625+
def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> torch.Tensor:
    """Encode ``x`` in overlapping spatial tiles and stitch the latents together.

    Tiles of ``tile_sample_min_height x tile_sample_min_width`` are taken every
    ``tile_sample_stride_*`` pixels, encoded independently, cross-faded over
    their overlap with `blend_v`/`blend_h`, and concatenated so each latent
    pixel is covered exactly once.

    Args:
        x: Input image batch of shape ``(batch, channels, height, width)``.
        return_dict: When ``True`` wrap the result in an ``EncoderOutput``;
            otherwise return a one-element tuple.
    """
    _, _, height, width = x.shape
    ratio = self.spatial_compression_ratio
    latent_height = height // ratio
    latent_width = width // ratio

    tile_latent_min_height = self.tile_sample_min_height // ratio
    tile_latent_min_width = self.tile_sample_min_width // ratio
    tile_latent_stride_height = self.tile_sample_stride_height // ratio
    tile_latent_stride_width = self.tile_sample_stride_width // ratio
    # Overlap between consecutive latent tiles; this many rows/cols are cross-faded.
    blend_height = tile_latent_min_height - tile_latent_stride_height
    blend_width = tile_latent_min_width - tile_latent_stride_width

    # Encode each overlapping tile independently.
    rows = []
    for top in range(0, height, self.tile_sample_stride_height):
        row = []
        for left in range(0, width, self.tile_sample_stride_width):
            tile = x[:, :, top : top + self.tile_sample_min_height, left : left + self.tile_sample_min_width]
            # Edge tiles may not be divisible by the compression ratio; pad them out.
            if tile.shape[2] % ratio != 0 or tile.shape[3] % ratio != 0:
                pad_h = (ratio - tile.shape[2]) % ratio
                pad_w = (ratio - tile.shape[3]) % ratio
                tile = F.pad(tile, (0, pad_w, 0, pad_h))
            row.append(self.encoder(tile))
        rows.append(row)

    # Cross-fade every tile with its top and left neighbours, then keep only the
    # stride-sized portion so the concatenation tiles the latent plane exactly once.
    result_rows = []
    for i, row in enumerate(rows):
        result_row = []
        for j, tile in enumerate(row):
            if i > 0:
                tile = self.blend_v(rows[i - 1][j], tile, blend_height)
            if j > 0:
                tile = self.blend_h(row[j - 1], tile, blend_width)
            result_row.append(tile[:, :, :tile_latent_stride_height, :tile_latent_stride_width])
        result_rows.append(torch.cat(result_row, dim=3))

    encoded = torch.cat(result_rows, dim=2)[:, :, :latent_height, :latent_width]

    if not return_dict:
        return (encoded,)
    return EncoderOutput(latent=encoded)
611673
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
    """Decode latent ``z`` in overlapping tiles and stitch the samples together.

    Latent tiles of ``tile_latent_min_*`` are taken every ``tile_latent_stride_*``
    positions, decoded independently, cross-faded over their overlap with
    `blend_v`/`blend_h`, and concatenated so each output pixel is covered exactly once.

    Args:
        z: Latent batch of shape ``(batch, channels, height, width)``.
        return_dict: When ``True`` wrap the result in a ``DecoderOutput``;
            otherwise return a one-element tuple.
    """
    _, _, height, width = z.shape

    tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
    tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
    tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
    tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio

    # Overlap measured in *sample* pixels, since blending happens after decoding.
    blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
    blend_width = self.tile_sample_min_width - self.tile_sample_stride_width

    # Decode each overlapping latent tile on its own.
    rows = []
    for top in range(0, height, tile_latent_stride_height):
        row = [
            self.decoder(z[:, :, top : top + tile_latent_min_height, left : left + tile_latent_min_width])
            for left in range(0, width, tile_latent_stride_width)
        ]
        rows.append(row)

    # Cross-fade against the tile above and the tile to the left, then crop to
    # the stride so the final concatenation covers each pixel exactly once.
    result_rows = []
    for i, row in enumerate(rows):
        result_row = []
        for j, tile in enumerate(row):
            if i > 0:
                tile = self.blend_v(rows[i - 1][j], tile, blend_height)
            if j > 0:
                tile = self.blend_h(row[j - 1], tile, blend_width)
            result_row.append(tile[:, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
        result_rows.append(torch.cat(result_row, dim=3))

    decoded = torch.cat(result_rows, dim=2)

    if not return_dict:
        return (decoded,)
    return DecoderOutput(sample=decoded)
614714
615715 def forward (self , sample : torch .Tensor , return_dict : bool = True ) -> torch .Tensor :
616716 encoded = self .encode (sample , return_dict = False )[0 ]
0 commit comments