Skip to content

Commit 04bf185

Browse files
committed
Fix some tests
1 parent 6c7e31f commit 04bf185

File tree

2 files changed

+18
-2
lines changed

2 files changed

+18
-2
lines changed

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1017,7 +1017,10 @@ void VideoDecoder::maybeSeekToBeforeDesiredPts() {
10171017
// exactly at a frame start either.
10181018
// - When calling `getFramePlayedAt(pts)`, regardless of the mode, if `pts`
10191019
// doesn't land exactly at a frame's start. We have tests that currently
1020-
// exhibit this behavior: test_get_frame_at_pts_audio_bad().
1020+
// exhibit this behavior: test_get_frame_at_pts_audio_bad(). The "obvious"
1021+
// fix for this is to let `getFramePlayedAt` convert the pts to an index,
1022+
// just like the rest of the APIs.
1023+
//
10211024
// TODO HOW DO WE FIX THIS??
10221025

10231026
// A few notes:

test/utils.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,25 @@ def cpu_and_cuda():
2323
return ("cpu", pytest.param("cuda", marks=pytest.mark.needs_cuda))
2424

2525

26+
def assert_frames_equal(*args, **kwargs):
27+
frame = args[0]
28+
# This heuristic will work until we start returningu int8 audio frames...
29+
if frame.dtype == torch.uint8:
30+
return assert_video_frames_equal(*args, **kwargs)
31+
else:
32+
return assert_audio_frames_equal(*args, **kwargs)
33+
34+
35+
def assert_audio_frames_equal(*args, **kwargs):
36+
torch.testing.assert_close(*args, **kwargs)
37+
38+
2639
# For use with decoded data frames. On CPU Linux, we expect exact, bit-for-bit
2740
# equality. On CUDA Linux, we expect a small tolerance.
2841
# On other platforms (e.g. MacOS), we also allow a small tolerance. FFmpeg does
2942
# not guarantee bit-for-bit equality across systems and architectures, so we
3043
# also cannot. We currently use Linux on x86_64 as our reference system.
31-
def assert_frames_equal(*args, **kwargs):
44+
def assert_video_frames_equal(*args, **kwargs):
3245
if sys.platform == "linux":
3346
if args[0].device.type == "cuda":
3447
atol = 2

0 commit comments

Comments
 (0)