@@ -1182,7 +1182,8 @@ def _encode(self, x: torch.Tensor) -> torch.Tensor:
 
         frame_batch_size = self.num_sample_frames_batch_size
         # Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k.
-        num_batches = num_frames // frame_batch_size if num_frames > 1 else 1
+        # So, it is okay to not round up as the extra remaining frame is handled in the loop
+        num_batches = max(num_frames // frame_batch_size, 1)
         conv_cache = None
         enc = []
 
@@ -1330,7 +1331,8 @@ def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
             row = []
             for j in range(0, width, overlap_width):
                 # Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k.
-                num_batches = num_frames // frame_batch_size if num_frames > 1 else 1
+                # So, it is okay to not round up as the extra remaining frame is handled in the loop
+                num_batches = max(num_frames // frame_batch_size, 1)
                 conv_cache = None
                 time = []
 
@@ -1409,7 +1411,7 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[Decod
         for i in range(0, height, overlap_height):
             row = []
             for j in range(0, width, overlap_width):
-                num_batches = num_frames // frame_batch_size
+                num_batches = max(num_frames // frame_batch_size, 1)
                 conv_cache = None
                 time = []
 
0 commit comments