Skip to content

Commit 3881586

Browse files
committed
Cleanups
1 parent da40954 commit 3881586

File tree

2 files changed

+67
-25
lines changed

2 files changed

+67
-25
lines changed

test/decoders/test_ops.py

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@
5151
INDEX_OF_FRAME_AT_6_SECONDS = 180
5252

5353

54-
class TestOps:
54+
class TestVideoOps:
5555
@pytest.mark.parametrize("device", cpu_and_cuda())
5656
def test_seek_and_next(self, device):
5757
decoder = create_from_file(str(NASA_VIDEO.path))
@@ -616,6 +616,8 @@ def test_cuda_decoder(self):
616616
duration, torch.tensor(0.0334).double(), atol=0, rtol=1e-3
617617
)
618618

619+
620+
class TestAudioOps:
619621
@pytest.mark.parametrize(
620622
"method",
621623
(
@@ -664,19 +666,15 @@ def test_audio_decode_all_samples_with_next(self, asset):
664666
@pytest.mark.parametrize(
665667
"range", ("begin_to_end", "at_frame_boundaries", "not_at_frame_boundaries")
666668
)
667-
# @pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3))
668-
@pytest.mark.parametrize("asset", (NASA_AUDIO,))
669-
def test_audio_get_frames_by_pts_in_range_audio(self, range, asset):
669+
@pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3))
670+
def test_get_frames_by_pts_in_range_audio(self, range, asset):
670671
if range == "begin_to_end":
671672
start_seconds, stop_seconds = 0, asset.duration_seconds
672673
elif range == "at_frame_boundaries":
673674
start_seconds = asset.frames[asset.default_stream_index][10].pts_seconds
674-
# need -1e-5 because the upper bound in open. If we don't do this
675-
# then our test util returns one frame too much.
676-
stop_seconds = (
677-
asset.frames[asset.default_stream_index][40].pts_seconds - 1e-5
678-
)
675+
stop_seconds = asset.frames[asset.default_stream_index][40].pts_seconds
679676
else:
677+
assert range == "not_at_frame_boundaries"
680678
start_frame_info = asset.frames[asset.default_stream_index][10]
681679
stop_frame_info = asset.frames[asset.default_stream_index][40]
682680
start_seconds = start_frame_info.pts_seconds + (
@@ -689,20 +687,32 @@ def test_audio_get_frames_by_pts_in_range_audio(self, range, asset):
689687
decoder = create_from_file(str(asset.path), seek_mode="approximate")
690688
add_audio_stream(decoder)
691689

690+
stop_offset = 0 if range == "at_frame_boundaries" else 1
692691
reference_frames = asset.get_frame_data_by_range(
693692
start=asset.get_frame_index(pts_seconds=start_seconds),
694-
stop=asset.get_frame_index(pts_seconds=stop_seconds) + 1,
693+
stop=asset.get_frame_index(pts_seconds=stop_seconds) + stop_offset,
695694
)
696-
reference_frames = torch.cat(reference_frames.unbind(), dim=-1)
697695

698696
frames = get_frames_by_pts_in_range_audio(
699697
decoder, start_seconds=start_seconds, stop_seconds=stop_seconds
700698
)
701699

702700
assert_frames_equal(frames, reference_frames)
703701

702+
@pytest.mark.parametrize(
703+
"asset, expected_shape", ((NASA_AUDIO, (2, 1024)), (NASA_AUDIO_MP3, (2, 576)))
704+
)
705+
def test_decode_epsilon_range(self, asset, expected_shape):
706+
decoder = create_from_file(str(asset.path), seek_mode="approximate")
707+
add_audio_stream(decoder)
708+
709+
frames = get_frames_by_pts_in_range_audio(
710+
decoder, start_seconds=5, stop_seconds=5 + 1e-5
711+
)
712+
assert frames.shape == expected_shape
713+
704714
@pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3))
705-
def test_audio_seek_and_next(self, asset):
715+
def test_seek_and_next_audio(self, asset):
706716
decoder = create_from_file(str(asset.path), seek_mode="approximate")
707717
add_audio_stream(decoder)
708718

test/utils.py

Lines changed: 45 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -190,11 +190,7 @@ def get_frame_data_by_range(
190190
*,
191191
stream_index: Optional[int] = None,
192192
) -> torch.Tensor:
193-
tensors = [
194-
self.get_frame_data_by_index(i, stream_index=stream_index)
195-
for i in range(start, stop, step)
196-
]
197-
return torch.stack(tensors)
193+
raise NotImplementedError("Override in child classes")
198194

199195
def get_pts_seconds_by_range(
200196
self,
@@ -261,6 +257,20 @@ def get_frame_data_by_index(
261257
)
262258
return torch.load(file_path, weights_only=True).permute(2, 0, 1)
263259

260+
def get_frame_data_by_range(
261+
self,
262+
start: int,
263+
stop: int,
264+
step: int = 1,
265+
*,
266+
stream_index: Optional[int] = None,
267+
) -> torch.Tensor:
268+
tensors = [
269+
self.get_frame_data_by_index(i, stream_index=stream_index)
270+
for i in range(start, stop, step)
271+
]
272+
return torch.stack(tensors)
273+
264274
@property
265275
def width(self) -> int:
266276
return self.stream_infos[self.default_stream_index].width
@@ -327,6 +337,7 @@ class TestAudio(TestContainerFile):
327337

328338
stream_infos: Dict[int, TestAudioStreamInfo]
329339
# stream_index -> list of 2D frame tensors of shape (num_channels, num_samples_in_that_frame)
340+
# num_samples_in_that_frame isn't necessarily constant for a given stream.
330341
_reference_frames: Dict[int, List[torch.Tensor]] = field(default_factory=dict)
331342

332343
# Storing each individual frame is too expensive for audio, because there's
@@ -354,19 +365,40 @@ def get_frame_data_by_index(
354365

355366
return self._reference_frames[stream_index][idx]
356367

368+
def get_frame_data_by_range(
369+
self,
370+
start: int,
371+
stop: int,
372+
step: int = 1,
373+
*,
374+
stream_index: Optional[int] = None,
375+
) -> torch.Tensor:
376+
tensors = [
377+
self.get_frame_data_by_index(i, stream_index=stream_index)
378+
for i in range(start, stop, step)
379+
]
380+
return torch.cat(tensors, dim=-1)
381+
357382
def get_frame_index(
358383
self, *, pts_seconds: float, stream_index: Optional[int] = None
359384
) -> int:
360385
if stream_index is None:
361386
stream_index = self.default_stream_index
362-
out = next(
363-
frame_index
364-
for (frame_index, frame_info) in self.frames[stream_index].items()
365-
if frame_info.pts_seconds
366-
<= pts_seconds
367-
< frame_info.pts_seconds + frame_info.duration_seconds
368-
)
369-
return out
387+
388+
if pts_seconds <= self.frames[stream_index][0].pts_seconds:
389+
# Special case for e.g. NASA_AUDIO_MP3 whose first frame's pts is
390+
# 0.13~, not 0.
391+
return 0
392+
try:
393+
return next(
394+
frame_index
395+
for (frame_index, frame_info) in self.frames[stream_index].items()
396+
if frame_info.pts_seconds
397+
<= pts_seconds
398+
< frame_info.pts_seconds + frame_info.duration_seconds
399+
)
400+
except StopIteration:
401+
return len(self.frames[stream_index]) - 1
370402

371403
@property
372404
def sample_rate(self) -> int:

0 commit comments

Comments
 (0)