diff --git a/src/torchcodec/_core/BetaCudaDeviceInterface.cpp b/src/torchcodec/_core/BetaCudaDeviceInterface.cpp index a502de0de..9190ac4d4 100644 --- a/src/torchcodec/_core/BetaCudaDeviceInterface.cpp +++ b/src/torchcodec/_core/BetaCudaDeviceInterface.cpp @@ -424,8 +424,8 @@ int BetaCudaDeviceInterface::frameReadyInDisplayOrder( int BetaCudaDeviceInterface::receiveFrame(UniqueAVFrame& avFrame) { if (readyFrames_.empty()) { // No frame found, instruct caller to try again later after sending more - // packets. - return AVERROR(EAGAIN); + // packets, or to stop if EOF was already sent. + return eofSent_ ? AVERROR_EOF : AVERROR(EAGAIN); } CUVIDPARSERDISPINFO dispInfo = readyFrames_.front(); readyFrames_.pop(); diff --git a/test/test_decoders.py b/test/test_decoders.py index 260a7baa2..c803bc592 100644 --- a/test/test_decoders.py +++ b/test/test_decoders.py @@ -45,6 +45,7 @@ SINE_MONO_S32_8000, TEST_SRC_2_720P, TEST_SRC_2_720P_H265, + unsplit_device_str, ) @@ -178,6 +179,7 @@ def test_getitem_int(self, num_ffmpeg_threads, device, seek_mode): device=device, seek_mode=seek_mode, ) + device, _ = unsplit_device_str(device) ref_frame0 = NASA_VIDEO.get_frame_data_by_index(0).to(device) ref_frame1 = NASA_VIDEO.get_frame_data_by_index(1).to(device) @@ -223,6 +225,7 @@ def test_getitem_numpy_int(self): @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) def test_getitem_slice(self, device, seek_mode): decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode) + device, _ = unsplit_device_str(device) # ensure that the degenerate case of a range of size 1 works @@ -400,6 +403,7 @@ def test_getitem_fails(self, device, seek_mode): @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) def test_iteration(self, device, seek_mode): decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode) + device, _ = unsplit_device_str(device) ref_frame0 = NASA_VIDEO.get_frame_data_by_index(0).to(device) ref_frame1 = NASA_VIDEO.get_frame_data_by_index(1).to(device) @@ -447,6 +451,7 @@ def test_iteration_slow(self): @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) def test_get_frame_at(self, device, seek_mode): decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode) + device, _ = unsplit_device_str(device) ref_frame9 = NASA_VIDEO.get_frame_data_by_index(9).to(device) frame9 = decoder.get_frame_at(9) @@ -510,6 +515,7 @@ def test_get_frame_at_fails(self, device, seek_mode): @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) def test_get_frames_at(self, device, seek_mode): decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode) + device, _ = unsplit_device_str(device) # test positive and negative frame index frames = decoder.get_frames_at([35, 25, -1, -2]) @@ -585,6 +591,7 @@ def test_get_frame_at_av1(self, device): pytest.skip("AV1 decoding on CUDA is not supported internally") decoder = VideoDecoder(AV1_VIDEO.path, device=device) + device, _ = unsplit_device_str(device) ref_frame10 = AV1_VIDEO.get_frame_data_by_index(10) ref_frame_info10 = AV1_VIDEO.get_frame_info(10) decoded_frame10 = decoder.get_frame_at(10) @@ -596,6 +603,7 @@ def test_get_frame_at_av1(self, device): @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) def test_get_frame_played_at(self, device, seek_mode): decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode) + device, _ = unsplit_device_str(device) ref_frame_played_at_6 = NASA_VIDEO.get_frame_data_by_index(180).to(device) assert_frames_equal( @@ -635,8 +643,8 @@ def test_get_frame_played_at_fails(self, device, seek_mode): @pytest.mark.parametrize("device", all_supported_devices()) @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) def test_get_frames_played_at(self, device, seek_mode): - decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode) + device, _ = unsplit_device_str(device) # Note: We know the frame at ~0.84s has index 25, the one at 1.16s has # index 35. We use those indices as reference to test against. @@ -695,6 +703,7 @@ def test_get_frames_in_range(self, stream_index, device, seek_mode): device=device, seek_mode=seek_mode, ) + device, _ = unsplit_device_str(device) # test degenerate case where we only actually get 1 frame ref_frames9 = NASA_VIDEO.get_frame_data_by_range( @@ -799,6 +808,7 @@ def test_get_frames_in_range_slice_indices_syntax(self, device, seek_mode): device=device, seek_mode=seek_mode, ) + device, _ = unsplit_device_str(device) # high range ends get capped to num_frames frames387_389 = decoder.get_frames_in_range(start=387, stop=1000) @@ -874,6 +884,7 @@ def test_get_frames_with_missing_num_frames_metadata( device=device, seek_mode=seek_mode, ) + device, _ = unsplit_device_str(device) assert decoder.metadata.num_frames_from_header is None assert decoder.metadata.num_frames_from_content is None @@ -942,6 +953,7 @@ def test_get_frames_by_pts_in_range(self, stream_index, device, seek_mode): device=device, seek_mode=seek_mode, ) + device, _ = unsplit_device_str(device) # Note that we are comparing the results of VideoDecoder's method: # get_frames_played_in_range() @@ -1134,6 +1146,7 @@ def test_get_key_frame_indices(self, device): @pytest.mark.parametrize("device", all_supported_devices()) def test_compile(self, device): decoder = VideoDecoder(NASA_VIDEO.path, device=device) + device, _ = unsplit_device_str(device) @contextlib.contextmanager def restore_capture_scalar_outputs(): @@ -1271,6 +1284,19 @@ def test_10bit_videos(self, device, asset): # This just validates that we can decode 10-bit videos. # TODO validate against the ref that the decoded frames are correct + if device == "cuda:0:beta": + # This fails on our BETA interface on asset 0 (only!) with: + # + # RuntimeError: Codec configuration not supported on this GPU. + # Codec: 4, chroma format: 1, bit depth: 10 + # + # I don't remember but I suspect asset 0 is actually the one that + # fallsback to the CPU path on the default CUDA interface (that + # would make sense) + # We should investigate if and how we could make that fallback + # happen for the BETA interface. + pytest.skip("TODONVDEC P2 - investigate and unskip") + decoder = VideoDecoder(asset.path, device=device) decoder.get_frame_at(10) @@ -1316,6 +1342,7 @@ def test_custom_frame_mappings_json_and_bytes( device=device, custom_frame_mappings=custom_frame_mappings, ) + device, _ = unsplit_device_str(device) frame_0 = decoder.get_frame_at(0) frame_5 = decoder.get_frame_at(5) assert_frames_equal( diff --git a/test/test_ops.py b/test/test_ops.py index b50aec88b..f16dd63ad 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -55,6 +55,7 @@ SINE_MONO_S32, SINE_MONO_S32_44100, SINE_MONO_S32_8000, + unsplit_device_str, ) torch._dynamo.config.capture_dynamic_output_shape_ops = True @@ -66,7 +67,8 @@ class TestVideoDecoderOps: @pytest.mark.parametrize("device", all_supported_devices()) def test_seek_and_next(self, device): decoder = create_from_file(str(NASA_VIDEO.path)) - add_video_stream(decoder, device=device) + device, device_variant = unsplit_device_str(device) + add_video_stream(decoder, device=device, device_variant=device_variant) frame0, _, _ = get_next_frame(decoder) reference_frame0 = NASA_VIDEO.get_frame_data_by_index(0) assert_frames_equal(frame0, reference_frame0.to(device)) @@ -83,7 +85,8 @@ def test_seek_and_next(self, device): @pytest.mark.parametrize("device", all_supported_devices()) def test_seek_to_negative_pts(self, device): decoder = create_from_file(str(NASA_VIDEO.path)) - add_video_stream(decoder, device=device) + device, device_variant = unsplit_device_str(device) + add_video_stream(decoder, device=device, device_variant=device_variant) frame0, _, _ = get_next_frame(decoder) reference_frame0 = NASA_VIDEO.get_frame_data_by_index(0) assert_frames_equal(frame0, reference_frame0.to(device)) @@ -95,7 +98,8 @@ def test_seek_to_negative_pts(self, device): @pytest.mark.parametrize("device", all_supported_devices()) def test_get_frame_at_pts(self, device): decoder = create_from_file(str(NASA_VIDEO.path)) - add_video_stream(decoder, device=device) + device, device_variant = unsplit_device_str(device) + add_video_stream(decoder, device=device, device_variant=device_variant) # This frame has pts=6.006 and duration=0.033367, so it should be visible # at timestamps in the range [6.006, 6.039367) (not including the last timestamp). frame6, _, _ = get_frame_at_pts(decoder, 6.006) @@ -119,7 +123,8 @@ def test_get_frame_at_pts(self, device): @pytest.mark.parametrize("device", all_supported_devices()) def test_get_frame_at_index(self, device): decoder = create_from_file(str(NASA_VIDEO.path)) - add_video_stream(decoder, device=device) + device, device_variant = unsplit_device_str(device) + add_video_stream(decoder, device=device, device_variant=device_variant) frame0, _, _ = get_frame_at_index(decoder, frame_index=0) reference_frame0 = NASA_VIDEO.get_frame_data_by_index(0) assert_frames_equal(frame0, reference_frame0.to(device)) @@ -137,7 +142,8 @@ def test_get_frame_at_index(self, device): @pytest.mark.parametrize("device", all_supported_devices()) def test_get_frame_with_info_at_index(self, device): decoder = create_from_file(str(NASA_VIDEO.path)) - add_video_stream(decoder, device=device) + device, device_variant = unsplit_device_str(device) + add_video_stream(decoder, device=device, device_variant=device_variant) frame6, pts, duration = get_frame_at_index(decoder, frame_index=180) reference_frame6 = NASA_VIDEO.get_frame_data_by_index( INDEX_OF_FRAME_AT_6_SECONDS @@ -149,7 +155,8 @@ def test_get_frame_with_info_at_index(self, device): @pytest.mark.parametrize("device", all_supported_devices()) def test_get_frames_at_indices(self, device): decoder = create_from_file(str(NASA_VIDEO.path)) - add_video_stream(decoder, device=device) + device, device_variant = unsplit_device_str(device) + add_video_stream(decoder, device=device, device_variant=device_variant) frames0and180, *_ = get_frames_at_indices(decoder, frame_indices=[0, 180]) reference_frame0 = NASA_VIDEO.get_frame_data_by_index(0) reference_frame180 = NASA_VIDEO.get_frame_data_by_index( @@ -161,7 +168,8 @@ def test_get_frames_at_indices(self, device): @pytest.mark.parametrize("device", all_supported_devices()) def test_get_frames_at_indices_unsorted_indices(self, device): decoder = create_from_file(str(NASA_VIDEO.path)) - _add_video_stream(decoder, device=device) + device, device_variant = unsplit_device_str(device) + add_video_stream(decoder, device=device, device_variant=device_variant) frame_indices = [2, 0, 1, 0, 2] @@ -188,7 +196,8 @@ def test_get_frames_at_indices_unsorted_indices(self, device): @pytest.mark.parametrize("device", all_supported_devices()) def test_get_frames_at_indices_negative_indices(self, device): decoder = create_from_file(str(NASA_VIDEO.path)) - add_video_stream(decoder, device=device) + device, device_variant = unsplit_device_str(device) + add_video_stream(decoder, device=device, device_variant=device_variant) frames389and387and1, *_ = get_frames_at_indices( decoder, frame_indices=[-1, -3, -389] ) @@ -202,7 +211,8 @@ def test_get_frames_at_indices_negative_indices(self, device): @pytest.mark.parametrize("device", all_supported_devices()) def test_get_frames_at_indices_fail_on_invalid_negative_indices(self, device): decoder = create_from_file(str(NASA_VIDEO.path)) - add_video_stream(decoder, device=device) + device, device_variant = unsplit_device_str(device) + add_video_stream(decoder, device=device, device_variant=device_variant) with pytest.raises( IndexError, match="negative indices must have an absolute value less than the number of frames", @@ -214,7 +224,8 @@ def test_get_frames_at_indices_fail_on_invalid_negative_indices(self, device): @pytest.mark.parametrize("device", all_supported_devices()) def test_get_frames_by_pts(self, device): decoder = create_from_file(str(NASA_VIDEO.path)) - _add_video_stream(decoder, device=device) + device, device_variant = unsplit_device_str(device) + add_video_stream(decoder, device=device, device_variant=device_variant) # Note: 13.01 should give the last video frame for the NASA video timestamps = [2, 0, 1, 0 + 1e-3, 13.01, 2 + 1e-3] @@ -246,7 +257,8 @@ def test_pts_apis_against_index_ref(self, device): # APIs exactly where those frames are supposed to start. We assert that # we get the expected frame. decoder = create_from_file(str(NASA_VIDEO.path)) - add_video_stream(decoder, device=device) + device, device_variant = unsplit_device_str(device) + add_video_stream(decoder, device=device, device_variant=device_variant) metadata = get_json_metadata(decoder) metadata_dict = json.loads(metadata) @@ -297,7 +309,8 @@ def test_pts_apis_against_index_ref(self, device): @pytest.mark.parametrize("device", all_supported_devices()) def test_get_frames_in_range(self, device): decoder = create_from_file(str(NASA_VIDEO.path)) - add_video_stream(decoder, device=device) + device, device_variant = unsplit_device_str(device) + add_video_stream(decoder, device=device, device_variant=device_variant) # ensure that the degenerate case of a range of size 1 works ref_frame0 = NASA_VIDEO.get_frame_data_by_range(0, 1) @@ -337,7 +350,8 @@ def test_get_frames_in_range(self, device): @pytest.mark.parametrize("device", all_supported_devices()) def test_throws_exception_at_eof(self, device): decoder = create_from_file(str(NASA_VIDEO.path)) - add_video_stream(decoder, device=device) + device, device_variant = unsplit_device_str(device) + add_video_stream(decoder, device=device, device_variant=device_variant) seek_to_pts(decoder, 12.979633) last_frame, _, _ = get_next_frame(decoder) @@ -352,7 +366,8 @@ def test_throws_exception_at_eof(self, device): @pytest.mark.parametrize("device", all_supported_devices()) def test_throws_exception_if_seek_too_far(self, device): decoder = create_from_file(str(NASA_VIDEO.path)) - add_video_stream(decoder, device=device) + device, device_variant = unsplit_device_str(device) + add_video_stream(decoder, device=device, device_variant=device_variant) # pts=12.979633 is the last frame in the video. seek_to_pts(decoder, 12.979633 + 1.0e-4) with pytest.raises(IndexError, match="no more frames"): @@ -363,9 +378,11 @@ def test_compile_seek_and_next(self, device): # TODO_OPEN_ISSUE Scott (T180277797): Get this to work with the inductor stack. Right now # compilation fails because it can't handle tensors of size unknown at # compile-time. + device, device_variant = unsplit_device_str(device) + @torch.compile(fullgraph=True, backend="eager") def get_frame1_and_frame_time6(decoder): - add_video_stream(decoder, device=device) + add_video_stream(decoder, device=device, device_variant=device_variant) frame0, _, _ = get_next_frame(decoder) seek_to_pts(decoder, 6.0) frame_time6, _, _ = get_next_frame(decoder) @@ -408,7 +425,8 @@ def test_create_decoder(self, create_from, device): else: raise ValueError("Oops, double check the parametrization of this test!") - add_video_stream(decoder, device=device) + device, device_variant = unsplit_device_str(device) + add_video_stream(decoder, device=device, device_variant=device_variant) frame0, _, _ = get_next_frame(decoder) reference_frame0 = NASA_VIDEO.get_frame_data_by_index(0) assert_frames_equal(frame0, reference_frame0.to(device)) @@ -536,9 +554,11 @@ def test_seek_mode_custom_frame_mappings(self, device): decoder = create_from_file( str(NASA_VIDEO.path), seek_mode="custom_frame_mappings" ) + device, device_variant = unsplit_device_str(device) add_video_stream( decoder, device=device, + device_variant=device_variant, stream_index=stream_index, custom_frame_mappings=NASA_VIDEO.get_custom_frame_mappings( stream_index=stream_index @@ -1077,7 +1097,8 @@ def seek(self, offset: int, whence: int) -> int: open(NASA_VIDEO.path, mode="rb", buffering=buffering) ) decoder = create_from_file_like(file_counter, "approximate") - add_video_stream(decoder, device=device) + device, device_variant = unsplit_device_str(device) + add_video_stream(decoder, device=device, device_variant=device_variant) frame0, *_ = get_next_frame(decoder) reference_frame0 = NASA_VIDEO.get_frame_data_by_index(0) diff --git a/test/utils.py b/test/utils.py index 644dc0bce..f26c013a7 100644 --- a/test/utils.py +++ b/test/utils.py @@ -27,7 +27,27 @@ def needs_cuda(test_item): def all_supported_devices(): - return ("cpu", pytest.param("cuda", marks=pytest.mark.needs_cuda)) + return ( + "cpu", + pytest.param("cuda", marks=pytest.mark.needs_cuda), + pytest.param("cuda:0:beta", marks=pytest.mark.needs_cuda), + ) + + +def unsplit_device_str(device_str: str) -> str: + # helper meant to be used as + # device, device_variant = unsplit_device_str(device) + # when `device` comes from all_supported_devices() and may be "cuda:0:beta". + # It is used: + # - before calling `.to(device)` where device can't be "cuda:0:beta" + # - before calling add_video_stream(device=device, device_variant=device_variant) + # + # TODONVDEC P2: Find a less clunky way to test the BETA CUDA interface. It + # will ultimately depend on how we want to publicly expose it. + if device_str == "cuda:0:beta": + return "cuda", "beta" + else: + return device_str, "default" def get_ffmpeg_major_version():