@@ -1015,7 +1015,7 @@ def __init__(
         # The minimal distance between two spatial tiles
         self.tile_sample_stride_height = 448
         self.tile_sample_stride_width = 448
-        self.tile_sample_stride_num_frames = 8  
+        self.tile_sample_stride_num_frames = 8
 
     def _set_gradient_checkpointing(self, module, value=False):
         if isinstance(module, (LTXVideoEncoder3d, LTXVideoDecoder3d)):
@@ -1185,7 +1185,7 @@ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.
                 x / blend_extent
             )
         return b
-        
+
     def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
         blend_extent = min(a.shape[-3], b.shape[-3], blend_extent)
         for x in range(blend_extent):
@@ -1280,9 +1280,7 @@ def tiled_decode(
         for i in range(0, height, tile_latent_stride_height):
             row = []
             for j in range(0, width, tile_latent_stride_width):
-                time = self.decoder(
-                    z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width], temb
-                )
+                time = self.decoder(z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width], temb)
 
                 row.append(time)
             rows.append(row)
@@ -1337,7 +1335,9 @@ def _temporal_tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput:
         enc = torch.cat(result_row, dim=2)[:, :, :latent_num_frames]
         return enc
 
-    def _temporal_tiled_decode(self, z: torch.Tensor, temb: Optional[torch.Tensor], return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+    def _temporal_tiled_decode(
+        self, z: torch.Tensor, temb: Optional[torch.Tensor], return_dict: bool = True
+    ) -> Union[DecoderOutput, torch.Tensor]:
         batch_size, num_channels, num_frames, height, width = z.shape
         num_sample_frames = (num_frames - 1) * self.temporal_compression_ratio + 1
@@ -1365,7 +1365,7 @@ def _temporal_tiled_decode(self, z: torch.Tensor, temb: Optional[torch.Tensor],
                 tile = tile[:, :, : self.tile_sample_stride_num_frames, :, :]
                 result_row.append(tile)
             else:
-                result_row.append(tile[:, :, :self.tile_sample_stride_num_frames + 1, :, :])
+                result_row.append(tile[:, :, : self.tile_sample_stride_num_frames + 1, :, :])
 
         dec = torch.cat(result_row, dim=2)[:, :, :num_sample_frames]
 
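
For reference, the blend_t method partially shown in the second hunk linearly cross-fades the overlapping frames of two adjacent temporal tiles before the tiles are concatenated. Below is a minimal standalone sketch of that blending pattern, not part of this commit; the (batch, channels, frames, height, width) layout and the 4-frame overlap are illustrative assumptions.

import torch


def blend_t(a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
    # Cross-fade the last `blend_extent` frames of tile `a` into the first
    # `blend_extent` frames of tile `b` along the frame axis (dim -3).
    blend_extent = min(a.shape[-3], b.shape[-3], blend_extent)
    for x in range(blend_extent):
        weight = x / blend_extent
        b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (1 - weight) + b[:, :, x, :, :] * weight
    return b


a = torch.randn(1, 3, 16, 64, 64)  # previous temporal tile (assumed shape)
b = torch.randn(1, 3, 16, 64, 64)  # next temporal tile, overlapping by 4 frames (assumed)
blended = blend_t(a, b, blend_extent=4)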