@@ -479,12 +479,15 @@ def __init__(
         self.use_tiling = False

         # The minimal tile height and width for spatial tiling to be used
-        self.tile_sample_min_height = 512
-        self.tile_sample_min_width = 512
+        self.tile_sample_min_height = 1024
+        self.tile_sample_min_width = 1024

         # The minimal distance between two spatial tiles
-        self.tile_sample_stride_height = 448
-        self.tile_sample_stride_width = 448
+        self.tile_sample_stride_height = 896
+        self.tile_sample_stride_width = 896
+
+        self.tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        self.tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio

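The new defaults double both the tile size and the stride in pixel space, and the two added attributes cache the corresponding latent-space tile size. A quick sketch of the arithmetic, assuming a `spatial_compression_ratio` of 32 (true for the f32c32 DC-AE checkpoints; other variants may differ):

```python
# `spatial_compression_ratio = 32` is an assumption for illustration;
# read the actual value from the loaded model in practice.
spatial_compression_ratio = 32

tile_sample_min = 1024    # new pixel-space tile size
tile_sample_stride = 896  # new pixel-space stride

tile_latent_min = tile_sample_min // spatial_compression_ratio        # 32
tile_latent_stride = tile_sample_stride // spatial_compression_ratio  # 28
overlap = tile_latent_min - tile_latent_stride                        # 4 latent rows/cols blended per seam
```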
     def enable_tiling(
         self,
@@ -515,6 +518,8 @@ def enable_tiling(
         self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
         self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
         self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
+        self.tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        self.tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio

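Recomputing `tile_latent_min_*` here keeps the cached latent sizes consistent when a caller overrides the pixel-space tile sizes. Since each assignment uses `or`, any argument left unset keeps its current value. A usage sketch restoring the previous 512/448 defaults (`vae` is a hypothetical `AutoencoderDC` instance):

```python
# `vae` is assumed to be a loaded AutoencoderDC instance.
vae.enable_tiling(
    tile_sample_min_height=512,
    tile_sample_min_width=512,
    tile_sample_stride_height=448,
    tile_sample_stride_width=448,
)
# tile_latent_min_height/width are recomputed as 512 // spatial_compression_ratio.
```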
     def disable_tiling(self) -> None:
         r"""
@@ -606,11 +611,101 @@ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutp
             return (decoded,)
         return DecoderOutput(sample=decoded)

+    def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+        # Linearly cross-fade the bottom rows of tile `a` into the top rows of tile `b`.
+        blend_extent = min(a.shape[2], b.shape[2], blend_extent)
+        for y in range(blend_extent):
+            b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
+        return b
+
+    def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+        # Linearly cross-fade the right columns of tile `a` into the left columns of tile `b`.
+        blend_extent = min(a.shape[3], b.shape[3], blend_extent)
+        for x in range(blend_extent):
+            b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
+        return b
+
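Both helpers do a linear cross-fade over the overlap: position `k` in the overlap takes weight `k / blend_extent` from the new tile and `1 - k / blend_extent` from the neighbouring one, so the seam ramps smoothly instead of jumping. A self-contained demonstration on toy tensors:

```python
import torch

def blend_v(a, b, blend_extent):
    blend_extent = min(a.shape[2], b.shape[2], blend_extent)
    for y in range(blend_extent):
        b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
    return b

a = torch.zeros(1, 1, 4, 1)  # tile above: all zeros
b = torch.ones(1, 1, 4, 1)   # tile below: all ones
print(blend_v(a, b, 4).flatten().tolist())  # [0.0, 0.25, 0.5, 0.75] -- smooth ramp, no hard seam
```

Note that both helpers mutate `b` in place and return it.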
     def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> torch.Tensor:
-        raise NotImplementedError("`tiled_encode` has not been implemented for AutoencoderDC.")
+        batch_size, num_channels, height, width = x.shape
+        latent_height = height // self.spatial_compression_ratio
+        latent_width = width // self.spatial_compression_ratio
+
+        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+        tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
+        tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+        blend_height = tile_latent_min_height - tile_latent_stride_height
+        blend_width = tile_latent_min_width - tile_latent_stride_width
+
+        # Split x into overlapping tiles and encode them separately.
+        # The tiles have an overlap to avoid seams between tiles.
+        rows = []
+        for i in range(0, x.shape[2], self.tile_sample_stride_height):
+            row = []
+            for j in range(0, x.shape[3], self.tile_sample_stride_width):
+                tile = x[:, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
+                if tile.shape[2] % self.spatial_compression_ratio != 0 or tile.shape[3] % self.spatial_compression_ratio != 0:
+                    # Pad edge tiles on the right/bottom so both spatial dimensions
+                    # are multiples of the spatial compression ratio.
+                    pad_h = (self.spatial_compression_ratio - tile.shape[2]) % self.spatial_compression_ratio
+                    pad_w = (self.spatial_compression_ratio - tile.shape[3]) % self.spatial_compression_ratio
+                    tile = F.pad(tile, (0, pad_w, 0, pad_h))
+                tile = self.encoder(tile)
+                row.append(tile)
+            rows.append(row)
+
+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                # Blend the tile above and the tile to the left into the current
+                # tile, then crop it to the stride and add it to the result row.
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, blend_width)
+                result_row.append(tile[:, :, :tile_latent_stride_height, :tile_latent_stride_width])
+            result_rows.append(torch.cat(result_row, dim=3))
+
+        encoded = torch.cat(result_rows, dim=2)[:, :, :latent_height, :latent_width]
+
+        if not return_dict:
+            return (encoded,)
+        return EncoderOutput(latent=encoded)

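For intuition about the loop bounds: with the new defaults, a 2000-pixel-tall input starts tiles at rows 0, 896 and 1792, and the last, clipped tile is what the `F.pad` branch rounds up to a multiple of the compression ratio. A small sketch of that bookkeeping (ratio of 32 assumed, as above):

```python
height = 2000                   # deliberately not tile-aligned
tile_sample_stride = 896
spatial_compression_ratio = 32  # assumed; checkpoint-dependent

starts = list(range(0, height, tile_sample_stride))
print(starts)                   # [0, 896, 1792]

last_tile = height - starts[-1] # 208-pixel strip left at the bottom
pad = (spatial_compression_ratio - last_tile) % spatial_compression_ratio
print(last_tile, pad)           # 208 16 -> padded to 224, a multiple of 32
```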
     def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
-        raise NotImplementedError("`tiled_decode` has not been implemented for AutoencoderDC.")
+        batch_size, num_channels, height, width = z.shape
+
+        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+        tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
+        tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+
+        blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
+        blend_width = self.tile_sample_min_width - self.tile_sample_stride_width
+
+        # Split z into overlapping tiles and decode them separately.
+        # The tiles have an overlap to avoid seams between tiles.
+        rows = []
+        for i in range(0, height, tile_latent_stride_height):
+            row = []
+            for j in range(0, width, tile_latent_stride_width):
+                tile = z[:, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width]
+                decoded = self.decoder(tile)
+                row.append(decoded)
+            rows.append(row)
+
+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                # Blend the tile above and the tile to the left into the current
+                # tile, then crop it to the stride and add it to the result row.
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, blend_width)
+                result_row.append(tile[:, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
+            result_rows.append(torch.cat(result_row, dim=3))
+
+        decoded = torch.cat(result_rows, dim=2)
+
+        if not return_dict:
+            return (decoded,)
+        return DecoderOutput(sample=decoded)

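Assuming the public `encode`/`decode` entry points dispatch to these tiled paths when tiling is enabled and the input exceeds the tile size (that dispatch is outside this hunk), existing callers get tiling transparently. A hedged end-to-end sketch; the checkpoint name is illustrative:

```python
import torch
from diffusers import AutoencoderDC

# Checkpoint name is an assumption; any AutoencoderDC checkpoint should work.
vae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers")
vae.enable_tiling()

with torch.no_grad():
    image = torch.randn(1, 3, 2048, 2048)
    latent = vae.encode(image).latent    # tiled once the input exceeds tile_sample_min_*
    decoded = vae.decode(latent).sample  # tiled decode with seam blending
print(latent.shape, decoded.shape)
```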
     def forward(self, sample: torch.Tensor, return_dict: bool = True) -> torch.Tensor:
         encoded = self.encode(sample, return_dict=False)[0]