Skip to content

Commit f9d80d4

Browse files
Dan-Flores and Daniel Flores authored
Add custom_frame_mappings to VideoDecoder init (#799)
Co-authored-by: Daniel Flores <[email protected]>
1 parent c8b6acf commit f9d80d4

File tree

5 files changed

+225
-40
lines changed

5 files changed

+225
-40
lines changed

src/torchcodec/_core/SingleStreamDecoder.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -506,7 +506,7 @@ void SingleStreamDecoder::addVideoStream(
506506
if (seekMode_ == SeekMode::custom_frame_mappings) {
507507
TORCH_CHECK(
508508
customFrameMappings.has_value(),
509-
"Please provide frame mappings when using custom_frame_mappings seek mode.");
509+
"Missing frame mappings when custom_frame_mappings seek mode is set.");
510510
readCustomFrameMappingsUpdateMetadataAndIndex(
511511
streamIndex, customFrameMappings.value());
512512
}

src/torchcodec/decoders/_video_decoder.py

Lines changed: 93 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import io
8+
import json
89
import numbers
910
from pathlib import Path
1011
from typing import Literal, Optional, Tuple, Union
@@ -62,7 +63,25 @@ class VideoDecoder:
6263
probably is. Default: "exact".
6364
Read more about this parameter in:
6465
:ref:`sphx_glr_generated_examples_decoding_approximate_mode.py`
65-
66+
custom_frame_mappings (str, bytes, or file-like object, optional):
67+
Mapping of frames to their metadata, typically generated via ffprobe.
68+
This enables accurate frame seeking without requiring a full video scan.
69+
Do not set seek_mode when custom_frame_mappings is provided.
70+
Expected JSON format:
71+
72+
.. code-block:: json
73+
74+
{
75+
"frames": [
76+
{
77+
"pts": 0,
78+
"duration": 1001,
79+
"key_frame": 1
80+
}
81+
]
82+
}
83+
84+
Alternative field names "pkt_pts" and "pkt_duration" are also supported.
6685
6786
Attributes:
6887
metadata (VideoStreamMetadata): Metadata of the video stream.
@@ -80,6 +99,9 @@ def __init__(
8099
num_ffmpeg_threads: int = 1,
81100
device: Optional[Union[str, torch_device]] = "cpu",
82101
seek_mode: Literal["exact", "approximate"] = "exact",
102+
custom_frame_mappings: Optional[
103+
Union[str, bytes, io.RawIOBase, io.BufferedReader]
104+
] = None,
83105
):
84106
torch._C._log_api_usage_once("torchcodec.decoders.VideoDecoder")
85107
allowed_seek_modes = ("exact", "approximate")
@@ -89,6 +111,21 @@ def __init__(
89111
f"Supported values are {', '.join(allowed_seek_modes)}."
90112
)
91113

114+
# Validate seek_mode and custom_frame_mappings are not mismatched
115+
if custom_frame_mappings is not None and seek_mode == "approximate":
116+
raise ValueError(
117+
"custom_frame_mappings is incompatible with seek_mode='approximate'. "
118+
"Use seek_mode='custom_frame_mappings' or leave it unspecified to automatically use custom frame mappings."
119+
)
120+
121+
# Auto-select custom_frame_mappings seek_mode and process data when mappings are provided
122+
custom_frame_mappings_data = None
123+
if custom_frame_mappings is not None:
124+
seek_mode = "custom_frame_mappings" # type: ignore[assignment]
125+
custom_frame_mappings_data = _read_custom_frame_mappings(
126+
custom_frame_mappings
127+
)
128+
92129
self._decoder = create_decoder(source=source, seek_mode=seek_mode)
93130

94131
allowed_dimension_orders = ("NCHW", "NHWC")
@@ -110,6 +147,7 @@ def __init__(
110147
dimension_order=dimension_order,
111148
num_threads=num_ffmpeg_threads,
112149
device=device,
150+
custom_frame_mappings=custom_frame_mappings_data,
113151
)
114152

115153
(
@@ -379,3 +417,57 @@ def _get_and_validate_stream_metadata(
379417
end_stream_seconds,
380418
num_frames,
381419
)
420+
421+
422+
def _read_custom_frame_mappings(
    custom_frame_mappings: Union[str, bytes, io.RawIOBase, io.BufferedReader]
) -> tuple[Tensor, Tensor, Tensor]:
    """Parse custom frame mappings from JSON data and extract frame metadata.

    Args:
        custom_frame_mappings: JSON data containing frame metadata, provided as:
            - A JSON document (str or bytes)
            - A file-like object with a read() method

    Returns:
        A tuple of three tensors with one entry per frame, in input order:
            - all_frames (Tensor): The "pts" (or "pkt_pts") value of each
              frame, converted to float
            - is_key_frame (Tensor): The "key_frame" value of each frame,
              passed through as found in the JSON
            - duration (Tensor): The "duration" (or "pkt_duration") value of
              each frame, converted to float

    Raises:
        ValueError: If the input is not valid JSON, is not a JSON object with
            a non-empty "frames" list, or if any frame is missing one of the
            required keys.
    """
    try:
        if hasattr(custom_frame_mappings, "read"):
            input_data = json.load(custom_frame_mappings)
        else:
            input_data = json.loads(custom_frame_mappings)
    except json.JSONDecodeError as e:
        raise ValueError(
            f"Invalid custom frame mappings: {e}. It should be a valid JSON string or a file-like object."
        ) from e

    # Valid JSON can still be a list, number, string or null; only a
    # non-empty object carrying a "frames" entry is usable.
    if (
        not isinstance(input_data, dict)
        or not input_data
        or "frames" not in input_data
    ):
        raise ValueError(
            "Invalid custom frame mappings. The input is empty or missing the required 'frames' key."
        )

    frames = input_data["frames"]
    # An empty list would otherwise surface as an IndexError on frames[0].
    if not isinstance(frames, list) or not frames:
        raise ValueError(
            "Invalid custom frame mappings. The 'frames' entry must be a non-empty list of frame metadata."
        )

    missing_keys_message = (
        "Invalid custom frame mappings. The 'pts'/'pkt_pts', 'duration'/'pkt_duration', and 'key_frame' keys are required in the frame metadata."
    )

    first_frame = frames[0]
    if not isinstance(first_frame, dict):
        raise ValueError(missing_keys_message)

    # The field names used by the first frame decide which spelling is read
    # from every frame.
    pts_key = next((key for key in ("pts", "pkt_pts") if key in first_frame), None)
    duration_key = next(
        (key for key in ("duration", "pkt_duration") if key in first_frame), None
    )
    if pts_key is None or duration_key is None or "key_frame" not in first_frame:
        raise ValueError(missing_keys_message)

    required_keys = {pts_key, duration_key, "key_frame"}
    frame_data = []
    for frame in frames:
        # Validate every frame, not just the first, so that a malformed entry
        # is reported as a ValueError rather than a bare KeyError/TypeError.
        if not isinstance(frame, dict) or not required_keys <= frame.keys():
            raise ValueError(missing_keys_message)
        frame_data.append(
            (float(frame[pts_key]), frame["key_frame"], float(frame[duration_key]))
        )

    # zip(*...) transposes the per-frame triples into three equal-length
    # columns, so no separate length check is needed.
    all_frames, is_key_frame, duration = map(torch.tensor, zip(*frame_data))
    return all_frames, is_key_frame, duration

test/test_decoders.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import contextlib
88
import gc
99
import json
10+
from functools import partial
1011
from unittest.mock import patch
1112

1213
import numpy
@@ -1279,6 +1280,112 @@ def test_10bit_videos_cpu(self, asset):
12791280
decoder = VideoDecoder(asset.path)
12801281
decoder.get_frame_at(10)
12811282

1283+
def setup_frame_mappings(tmp_path, file, stream_index):
    """Produce the NASA_VIDEO custom frame mappings for ``stream_index``.

    When ``file`` is truthy the JSON text is written to
    ``tmp_path / "custom_frame_mappings.json"`` and that path is returned;
    otherwise the JSON text itself is returned.
    """
    mappings_json = NASA_VIDEO.generate_custom_frame_mappings(stream_index)
    if not file:
        # Caller wants the raw JSON string.
        return mappings_json

    # Caller wants the mappings persisted to a JSON file on disk.
    destination = tmp_path / "custom_frame_mappings.json"
    with open(destination, "w") as handle:
        handle.write(mappings_json)
    return destination
1294+
1295+
@pytest.mark.parametrize("device", all_supported_devices())
@pytest.mark.parametrize("stream_index", [0, 3])
@pytest.mark.parametrize(
    "method",
    (
        partial(setup_frame_mappings, file=True),
        partial(setup_frame_mappings, file=False),
    ),
)
def test_custom_frame_mappings_json_and_bytes(
    self, tmp_path, device, stream_index, method
):
    """Decode with custom frame mappings given as a JSON string or a file.

    ``method`` returns either the JSON text (str) or the path of a file
    holding it. Frames 0 and 5, and the range between them, are compared
    against the NASA_VIDEO reference data.
    """
    mappings_source = method(tmp_path=tmp_path, stream_index=stream_index)
    # A JSON string is passed through unchanged; anything else is a path that
    # must be opened so the decoder receives a file-like object.
    # nullcontext needs the value as its argument: a bare nullcontext()
    # yields None, which would silently disable custom frame mappings.
    with (
        contextlib.nullcontext(mappings_source)
        if isinstance(mappings_source, str)
        else open(mappings_source, "r")
    ) as custom_frame_mappings:
        decoder = VideoDecoder(
            NASA_VIDEO.path,
            stream_index=stream_index,
            device=device,
            custom_frame_mappings=custom_frame_mappings,
        )
    frame_0 = decoder.get_frame_at(0)
    frame_5 = decoder.get_frame_at(5)
    assert_frames_equal(
        frame_0.data,
        NASA_VIDEO.get_frame_data_by_index(0, stream_index=stream_index).to(device),
    )
    assert_frames_equal(
        frame_5.data,
        NASA_VIDEO.get_frame_data_by_index(5, stream_index=stream_index).to(device),
    )
    frames0_5 = decoder.get_frames_played_in_range(
        frame_0.pts_seconds, frame_5.pts_seconds
    )
    assert_frames_equal(
        frames0_5.data,
        NASA_VIDEO.get_frame_data_by_range(0, 5, stream_index=stream_index).to(
            device
        ),
    )
1340+
1341+
@pytest.mark.parametrize("device", all_supported_devices())
@pytest.mark.parametrize(
    "custom_frame_mappings,expected_match",
    [
        (NASA_VIDEO.generate_custom_frame_mappings(0), "seek_mode"),
        ("{}", "The input is empty or missing the required 'frames' key."),
        (
            '{"valid": "json"}',
            "The input is empty or missing the required 'frames' key.",
        ),
        (
            '{"frames": [{"missing": "keys"}]}',
            "keys are required in the frame metadata.",
        ),
    ],
)
def test_custom_frame_mappings_init_fails(
    self, device, custom_frame_mappings, expected_match
):
    """Check that VideoDecoder rejects bad custom_frame_mappings setups.

    The "seek_mode" case pairs valid mappings with seek_mode="approximate";
    every other case keeps seek_mode="exact" and supplies malformed mappings.
    """
    if expected_match == "seek_mode":
        seek_mode = "approximate"
    else:
        seek_mode = "exact"

    decoder_kwargs = dict(
        stream_index=0,
        device=device,
        custom_frame_mappings=custom_frame_mappings,
        seek_mode=seek_mode,
    )
    with pytest.raises(ValueError, match=expected_match):
        VideoDecoder(NASA_VIDEO.path, **decoder_kwargs)
1368+
1369+
@pytest.mark.parametrize("device", all_supported_devices())
def test_custom_frame_mappings_init_fails_invalid_json(self, tmp_path, device):
    """Check that non-JSON mappings are rejected, as a string and as a file."""
    invalid_content = "invalid input"
    invalid_json_path = tmp_path / "invalid_json"
    with open(invalid_json_path, "w") as f:
        f.write(invalid_content)

    # Test both string and file object. The string comes from the constant
    # rather than from file_obj.read(), so the file object is still positioned
    # at the start of the invalid content when the decoder reads it.
    with open(invalid_json_path, "r") as file_obj:
        for custom_frame_mappings in (invalid_content, file_obj):
            with pytest.raises(ValueError, match="Invalid custom frame mappings"):
                VideoDecoder(
                    NASA_VIDEO.path,
                    stream_index=0,
                    device=device,
                    custom_frame_mappings=custom_frame_mappings,
                )
1388+
12821389

12831390
class TestAudioDecoder:
12841391
@pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3, SINE_MONO_S32))

test/test_ops.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@
4444
from .utils import (
4545
all_supported_devices,
4646
assert_frames_equal,
47-
get_ffmpeg_major_version,
4847
NASA_AUDIO,
4948
NASA_AUDIO_MP3,
5049
NASA_VIDEO,
@@ -485,7 +484,7 @@ def test_seek_mode_custom_frame_mappings_fails(self):
485484
)
486485
with pytest.raises(
487486
RuntimeError,
488-
match="Please provide frame mappings when using custom_frame_mappings seek mode.",
487+
match="Missing frame mappings when custom_frame_mappings seek mode is set.",
489488
):
490489
add_video_stream(decoder, stream_index=0, custom_frame_mappings=None)
491490

@@ -505,10 +504,6 @@ def test_seek_mode_custom_frame_mappings_fails(self):
505504
decoder, stream_index=0, custom_frame_mappings=different_lengths
506505
)
507506

508-
@pytest.mark.skipif(
509-
get_ffmpeg_major_version() in (4, 5),
510-
reason="ffprobe isn't accurate on ffmpeg 4 and 5",
511-
)
512507
@pytest.mark.parametrize("device", all_supported_devices())
513508
def test_seek_mode_custom_frame_mappings(self, device):
514509
stream_index = 3 # custom_frame_index seek mode requires a stream index

test/utils.py

Lines changed: 23 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import torch
1515

1616
from torchcodec._core import get_ffmpeg_library_versions
17+
from torchcodec.decoders._video_decoder import _read_custom_frame_mappings
1718

1819

1920
# Decorator for skipping CUDA tests when CUDA isn't available. The tests are
@@ -267,40 +268,30 @@ def get_custom_frame_mappings(
267268
if stream_index is None:
268269
stream_index = self.default_stream_index
269270
if self._custom_frame_mappings_data.get(stream_index) is None:
270-
self.generate_custom_frame_mappings(stream_index)
271+
self._custom_frame_mappings_data[stream_index] = (
272+
_read_custom_frame_mappings(
273+
self.generate_custom_frame_mappings(stream_index)
274+
)
275+
)
271276
return self._custom_frame_mappings_data[stream_index]
272277

273-
def generate_custom_frame_mappings(self, stream_index: int) -> None:
274-
result = json.loads(
275-
subprocess.run(
276-
[
277-
"ffprobe",
278-
"-i",
279-
f"{self.path}",
280-
"-select_streams",
281-
f"{stream_index}",
282-
"-show_frames",
283-
"-of",
284-
"json",
285-
],
286-
check=True,
287-
capture_output=True,
288-
text=True,
289-
).stdout
290-
)
291-
all_frames = torch.tensor([float(frame["pts"]) for frame in result["frames"]])
292-
is_key_frame = torch.tensor([frame["key_frame"] for frame in result["frames"]])
293-
duration = torch.tensor(
294-
[float(frame["duration"]) for frame in result["frames"]]
295-
)
296-
assert (
297-
len(all_frames) == len(is_key_frame) == len(duration)
298-
), "Mismatched lengths in frame index data"
299-
self._custom_frame_mappings_data[stream_index] = (
300-
all_frames,
301-
is_key_frame,
302-
duration,
303-
)
278+
def generate_custom_frame_mappings(self, stream_index: int) -> str:
    """Run ffprobe on ``self.path`` and return its captured stdout.

    ffprobe is asked to show the frames of the selected stream in JSON
    format. ``check=True`` makes a non-zero ffprobe exit status raise
    ``subprocess.CalledProcessError``.
    """
    command = [
        "ffprobe",
        "-i",
        f"{self.path}",
        "-select_streams",
        f"{stream_index}",
        "-show_frames",
        "-of",
        "json",
    ]
    completed = subprocess.run(
        command,
        check=True,
        capture_output=True,
        text=True,
    )
    return completed.stdout
304295

305296
@property
306297
def empty_pts_seconds(self) -> torch.Tensor:

0 commit comments

Comments
 (0)