Skip to content

Commit d2357fe

Browse files
committed
Add test
1 parent 98fee85 commit d2357fe

File tree

3 files changed

+46
-17
lines changed

3 files changed

+46
-17
lines changed

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1324,7 +1324,8 @@ void VideoDecoder::convertAudioAVFrameToFrameOutputOnCPU(
13241324

13251325
auto numSamples = avFrame->nb_samples; // per channel
13261326
auto numChannels = getNumChannels(avFrame);
1327-
torch::Tensor outputData = torch::empty({numChannels, numSamples}, torch::kFloat32);
1327+
torch::Tensor outputData =
1328+
torch::empty({numChannels, numSamples}, torch::kFloat32);
13281329

13291330
AVSampleFormat format = static_cast<AVSampleFormat>(avFrame->format);
13301331
// TODO-AUDIO Implement all formats.

test/decoders/test_ops.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -695,7 +695,7 @@ def test_get_frames_by_pts_in_range_audio(self, range, asset):
695695
decoder, start_seconds=start_seconds, stop_seconds=stop_seconds
696696
)
697697

698-
assert_frames_equal(frames, reference_frames)
698+
torch.testing.assert_close(frames, reference_frames)
699699

700700
@pytest.mark.parametrize(
701701
"asset, expected_shape", ((NASA_AUDIO, (2, 1024)), (NASA_AUDIO_MP3, (2, 576)))
@@ -723,6 +723,46 @@ def test_decode_just_one_frame_at_boundaries(self, asset, expected_shape):
723723
)
724724
assert frames.shape == expected_shape
725725

726+
@pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3))
727+
def test_multiple_calls(self, asset):
728+
729+
def decode_stateless(start_seconds, stop_seconds):
730+
decoder = create_from_file(str(asset.path), seek_mode="approximate")
731+
add_audio_stream(decoder)
732+
733+
return get_frames_by_pts_in_range_audio(
734+
decoder, start_seconds=start_seconds, stop_seconds=stop_seconds
735+
)
736+
737+
decoder = create_from_file(str(asset.path), seek_mode="approximate")
738+
add_audio_stream(decoder)
739+
740+
start_seconds, stop_seconds = 0, 2
741+
frames = get_frames_by_pts_in_range_audio(
742+
decoder, start_seconds=start_seconds, stop_seconds=stop_seconds
743+
)
744+
torch.testing.assert_close(
745+
frames, decode_stateless(start_seconds, stop_seconds)
746+
)
747+
748+
start_seconds, stop_seconds = 3, 4
749+
frames = get_frames_by_pts_in_range_audio(
750+
decoder, start_seconds=start_seconds, stop_seconds=stop_seconds
751+
)
752+
torch.testing.assert_close(
753+
frames, decode_stateless(start_seconds, stop_seconds)
754+
)
755+
756+
# TODO-AUDIO
757+
start_seconds, stop_seconds = 0, 2
758+
frames = get_frames_by_pts_in_range_audio(
759+
decoder, start_seconds=start_seconds, stop_seconds=stop_seconds
760+
)
761+
with pytest.raises(AssertionError):
762+
torch.testing.assert_close(
763+
frames, decode_stateless(start_seconds, stop_seconds)
764+
)
765+
726766

727767
if __name__ == "__main__":
728768
pytest.main()

test/utils.py

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -26,25 +26,16 @@ def cpu_and_cuda():
2626
return ("cpu", pytest.param("cuda", marks=pytest.mark.needs_cuda))
2727

2828

29-
def assert_frames_equal(*args, **kwargs):
30-
frame = args[0]
31-
# This heuristic will work until we start returning uint8 audio frames...
32-
if frame.dtype == torch.uint8:
33-
return assert_video_frames_equal(*args, **kwargs)
34-
else:
35-
return assert_audio_frames_equal(*args, **kwargs)
36-
37-
38-
def assert_audio_frames_equal(*args, **kwargs):
39-
torch.testing.assert_close(*args, **kwargs)
29+
def get_ffmpeg_major_version():
30+
return int(get_ffmpeg_library_versions()["ffmpeg_version"].split(".")[0])
4031

4132

4233
# For use with decoded data frames. On CPU Linux, we expect exact, bit-for-bit
4334
# equality. On CUDA Linux, we expect a small tolerance.
4435
# On other platforms (e.g. MacOS), we also allow a small tolerance. FFmpeg does
4536
# not guarantee bit-for-bit equality across systems and architectures, so we
4637
# also cannot. We currently use Linux on x86_64 as our reference system.
47-
def assert_video_frames_equal(*args, **kwargs):
38+
def assert_frames_equal(*args, **kwargs):
4839
if sys.platform == "linux":
4940
if args[0].device.type == "cuda":
5041
atol = 2
@@ -82,9 +73,6 @@ def assert_tensor_close_on_at_least(actual_tensor, ref_tensor, *, percentage, at
8273
)
8374

8475

85-
def get_ffmpeg_major_version():
86-
return int(get_ffmpeg_library_versions()["ffmpeg_version"].split(".")[0])
87-
8876

8977
def in_fbcode() -> bool:
9078
return os.environ.get("IN_FBCODE_TORCHCODEC") == "1"

0 commit comments

Comments
 (0)