@@ -1182,7 +1182,8 @@ def _encode(self, x: torch.Tensor) -> torch.Tensor:
11821182
11831183 frame_batch_size = self .num_sample_frames_batch_size
11841184 # Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k.
1185- num_batches = num_frames // frame_batch_size if num_frames > 1 else 1
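1185+ # The extra single frame (the `frame_batch_size * k + 1` case) is handled inside the loop, so no rounding up is needed here; `max(..., 1)` only ensures at least one batch when `num_frames < frame_batch_size`.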
1186+ num_batches = max (num_frames // frame_batch_size , 1 )
11861187 conv_cache = None
11871188 enc = []
11881189
@@ -1330,7 +1331,8 @@ def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
13301331 row = []
13311332 for j in range (0 , width , overlap_width ):
13321333 # Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k.
1333- num_batches = num_frames // frame_batch_size if num_frames > 1 else 1
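1334+ # The extra single frame (the `frame_batch_size * k + 1` case) is handled inside the loop, so no rounding up is needed here; `max(..., 1)` only ensures at least one batch when `num_frames < frame_batch_size`.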
1335+ num_batches = max (num_frames // frame_batch_size , 1 )
13341336 conv_cache = None
13351337 time = []
13361338
@@ -1409,7 +1411,7 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[Decod
14091411 for i in range (0 , height , overlap_height ):
14101412 row = []
14111413 for j in range (0 , width , overlap_width ):
1412- num_batches = num_frames // frame_batch_size
1414+ num_batches = max ( num_frames // frame_batch_size , 1 )
14131415 conv_cache = None
14141416 time = []
14151417
0 commit comments