@@ -486,6 +486,9 @@ def __init__(
         self.tile_sample_stride_height = 448
         self.tile_sample_stride_width = 448
 
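+        # The latent-space tile extents follow from the sample-space extents;
+        # e.g., assuming the common DC-AE spatial_compression_ratio of 32, a
+        # 512 px sample tile corresponds to a 512 // 32 = 16 px latent tile.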
+        self.tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        self.tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+
     def enable_tiling(
         self,
         tile_sample_min_height: Optional[int] = None,
@@ -515,6 +518,8 @@ def enable_tiling(
         self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
         self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
         self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
+        self.tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        self.tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
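+        # Usage sketch (values hypothetical): enable_tiling(tile_sample_min_height=1024,
+        # tile_sample_min_width=1024) re-derives the latent tile extents as 1024 // ratio,
+        # keeping the sample-space and latent-space tiling parameters in sync.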
 
     def disable_tiling(self) -> None:
         r"""
@@ -606,11 +611,106 @@ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutp
             return (decoded,)
         return DecoderOutput(sample=decoded)
 
+    def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+        blend_extent = min(a.shape[2], b.shape[2], blend_extent)
+        for y in range(blend_extent):
+            b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
+        return b
+
+    def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+        blend_extent = min(a.shape[3], b.shape[3], blend_extent)
+        for x in range(blend_extent):
+            b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
+        return b
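+    # Both helpers linearly cross-fade the overlapping region: at offset k in
+    # [0, blend_extent), the output is a * (1 - k / blend_extent) + b * (k / blend_extent).
+    # For blend_extent == 4, for instance, b's weights across the seam are
+    # 0.0, 0.25, 0.5, 0.75, so each tile fades smoothly into its neighbor.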
+
     def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> torch.Tensor:
-        raise NotImplementedError("`tiled_encode` has not been implemented for AutoencoderDC.")
+        batch_size, num_channels, height, width = x.shape
+        latent_height = height // self.spatial_compression_ratio
+        latent_width = width // self.spatial_compression_ratio
+
+        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+        tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
+        tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+        blend_height = tile_latent_min_height - tile_latent_stride_height
+        blend_width = tile_latent_min_width - tile_latent_stride_width
+
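+        # The blend extents are the tile overlap in latent units (min tile size
+        # minus stride); e.g., assuming 512 px sample tiles with the 448 px
+        # stride above at 32x compression, latent tiles overlap by 16 - 14 = 2.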
+        # Split x into overlapping tiles and encode them separately.
+        # The tiles have an overlap to avoid seams between tiles.
+        rows = []
+        for i in range(0, x.shape[2], self.tile_sample_stride_height):
+            row = []
+            for j in range(0, x.shape[3], self.tile_sample_stride_width):
+                tile = x[:, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
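+                # Edge tiles can be smaller than the nominal tile size and need
+                # not divide evenly by the compression ratio; pad on the right
+                # and bottom so the encoder's downsampling stays exact.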
+                if (
+                    tile.shape[2] % self.spatial_compression_ratio != 0
+                    or tile.shape[3] % self.spatial_compression_ratio != 0
+                ):
+                    pad_h = (self.spatial_compression_ratio - tile.shape[2]) % self.spatial_compression_ratio
+                    pad_w = (self.spatial_compression_ratio - tile.shape[3]) % self.spatial_compression_ratio
+                    tile = F.pad(tile, (0, pad_w, 0, pad_h))
+                tile = self.encoder(tile)
+                row.append(tile)
+            rows.append(row)
+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                # Blend the tile above and the tile to the left into the current
+                # tile, then crop it to the stride and add it to the result row.
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, blend_width)
+                result_row.append(tile[:, :, :tile_latent_stride_height, :tile_latent_stride_width])
+            result_rows.append(torch.cat(result_row, dim=3))
+
+        encoded = torch.cat(result_rows, dim=2)[:, :, :latent_height, :latent_width]
+
+        if not return_dict:
+            return (encoded,)
+        return EncoderOutput(latent=encoded)
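+
+    # Minimal usage sketch (the checkpoint name is an assumption, not part of
+    # this diff):
+    #   vae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers")
+    #   vae.enable_tiling()
+    #   latent = vae.encode(image).latent  # inputs larger than the tile size
+    #                                      # are routed through tiled_encode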
 
     def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
-        raise NotImplementedError("`tiled_decode` has not been implemented for AutoencoderDC.")
+        batch_size, num_channels, height, width = z.shape
+
+        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+        tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
+        tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+
+        blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
+        blend_width = self.tile_sample_min_width - self.tile_sample_stride_width
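+        # Note the asymmetry with tiled_encode: blending happens in sample space
+        # here, since the tiles are decoded before being stitched; e.g., 512 px
+        # tiles (an assumed default) with the 448 px stride blend over 64 px.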
+
+        # Split z into overlapping tiles and decode them separately.
+        # The tiles have an overlap to avoid seams between tiles.
+        rows = []
+        for i in range(0, height, tile_latent_stride_height):
+            row = []
+            for j in range(0, width, tile_latent_stride_width):
+                tile = z[:, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width]
+                decoded = self.decoder(tile)
+                row.append(decoded)
+            rows.append(row)
+
+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                # Blend the tile above and the tile to the left into the current
+                # tile, then crop it to the stride and add it to the result row.
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, blend_width)
+                result_row.append(tile[:, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
+            result_rows.append(torch.cat(result_row, dim=3))
+
+        decoded = torch.cat(result_rows, dim=2)
+
+        if not return_dict:
+            return (decoded,)
+        return DecoderOutput(sample=decoded)
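+
+    # Round-trip sketch (hypothetical): with tiling enabled, decode() dispatches
+    # here whenever the latent exceeds the latent tile extents, so
+    #   image = vae.decode(latent).sample
+    # stitches the per-tile decodes into one blended output image.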
 
     def forward(self, sample: torch.Tensor, return_dict: bool = True) -> torch.Tensor:
         encoded = self.encode(sample, return_dict=False)[0]