@@ -1010,10 +1010,12 @@ def __init__(
10101010        # The minimal tile height and width for spatial tiling to be used 
10111011        self .tile_sample_min_height  =  512 
10121012        self .tile_sample_min_width  =  512 
1013+         self .tile_sample_min_num_frames  =  16 
10131014
10141015        # The minimal distance between two spatial tiles 
10151016        self .tile_sample_stride_height  =  448 
10161017        self .tile_sample_stride_width  =  448 
1018+         self .tile_sample_stride_num_frames  =  8  
10171019
10181020    def  _set_gradient_checkpointing (self , module , value = False ):
10191021        if  isinstance (module , (LTXVideoEncoder3d , LTXVideoDecoder3d )):
@@ -1114,6 +1116,53 @@ def encode(
11141116        if  not  return_dict :
11151117            return  (posterior ,)
11161118        return  AutoencoderKLOutput (latent_dist = posterior )
1119+         
1120+     def  blend_t (self , a : torch .Tensor , b : torch .Tensor , blend_extent : int ) ->  torch .Tensor :
1121+         blend_extent  =  min (a .shape [- 3 ], b .shape [- 3 ], blend_extent )
1122+         for  x  in  range (blend_extent ):
1123+             b [:, :, x , :, :] =  a [:, :, - blend_extent  +  x , :, :] *  (1  -  x  /  blend_extent ) +  b [:, :, x , :, :] *  (
1124+                 x  /  blend_extent 
1125+             )
1126+         return  b 
1127+ 
1128+     def  _temporal_tiled_decode (self , z : torch .Tensor , temb : Optional [torch .Tensor ], return_dict : bool  =  True ) ->  Union [DecoderOutput , torch .Tensor ]:
1129+         batch_size , num_channels , num_frames , height , width  =  z .shape 
1130+         num_sample_frames  =  (num_frames  -  1 ) *  self .temporal_compression_ratio  +  1 
1131+ 
1132+         tile_latent_min_height  =  self .tile_sample_min_height  //  self .spatial_compression_ratio 
1133+         tile_latent_min_width  =  self .tile_sample_min_width  //  self .spatial_compression_ratio 
1134+         tile_latent_min_num_frames  =  self .tile_sample_min_num_frames  //  self .temporal_compression_ratio 
1135+         tile_latent_stride_num_frames  =  self .tile_sample_stride_num_frames  //  self .temporal_compression_ratio 
1136+         blend_num_frames  =  self .tile_sample_min_num_frames  -  self .tile_sample_stride_num_frames 
1137+ 
1138+         row  =  []
1139+         for  i  in  range (0 , num_frames , tile_latent_stride_num_frames ):
1140+             tile  =  z [:, :, i  : i  +  tile_latent_min_num_frames  +  1 , :, :]
1141+             if  self .use_tiling  and  (tile .shape [- 1 ] >  tile_latent_min_width  or  tile .shape [- 2 ] >  tile_latent_min_height ):
1142+                 decoded  =  self .tiled_decode (tile , temb , return_dict = True ).sample 
1143+             else :
1144+                 print ("NOT Use tile decode" )
1145+                 print (f"input tile: { tile .size ()}  )
1146+                 decoded  =  self .decoder (tile , temb )
1147+                 print (f"output tile: { decoded .size ()}  )
1148+             if  i  >  0 :
1149+                 decoded  =  decoded [:, :, :- 1 , :, :]
1150+             row .append (decoded )
1151+ 
1152+         result_row  =  []
1153+         for  i , tile  in  enumerate (row ):
1154+             if  i  >  0 :
1155+                 tile  =  self .blend_t (row [i  -  1 ], tile , blend_num_frames )
1156+                 tile  =  tile [:, :, : self .tile_sample_stride_num_frames , :, :]
1157+                 result_row .append (tile )
1158+             else :
1159+                 result_row .append (tile [:, :, :self .tile_sample_stride_num_frames  +  1 , :, :])
1160+ 
1161+         dec  =  torch .cat (result_row , dim = 2 )[:, :, :num_sample_frames ]
1162+ 
1163+         if  not  return_dict :
1164+             return  (dec ,)
1165+         return  DecoderOutput (sample = dec )
11171166
11181167    def  _decode (
11191168        self , z : torch .Tensor , temb : Optional [torch .Tensor ] =  None , return_dict : bool  =  True 
@@ -1125,13 +1174,8 @@ def _decode(
11251174        if  self .use_tiling  and  (width  >  tile_latent_min_width  or  height  >  tile_latent_min_height ):
11261175            return  self .tiled_decode (z , temb , return_dict = return_dict )
11271176
1128-         if  self .use_framewise_decoding :
1129-             # TODO(aryan): requires investigation 
1130-             raise  NotImplementedError (
1131-                 "Frame-wise decoding has not been implemented for AutoencoderKLLTXVideo, at the moment, due to " 
1132-                 "quality issues caused by splitting inference across frame dimension. If you believe this " 
1133-                 "should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls." 
1134-             )
1177+         if  self .use_framewise_decoding  and  num_frames  >  tile_latent_min_num_frames :
1178+             dec  =  self ._temporal_tiled_decode (z , temb , return_dict = False )[0 ]
11351179        else :
11361180            dec  =  self .decoder (z , temb )
11371181
0 commit comments