5
5
# LICENSE file in the root directory of this source tree.
6
6
7
7
import io
8
+ import json
8
9
import numbers
9
10
from pathlib import Path
10
11
from typing import Literal , Optional , Tuple , Union
@@ -62,7 +63,25 @@ class VideoDecoder:
62
63
probably is. Default: "exact".
63
64
Read more about this parameter in:
64
65
:ref:`sphx_glr_generated_examples_decoding_approximate_mode.py`
65
-
66
+ custom_frame_mappings (str, bytes, or file-like object, optional):
67
+ Mapping of frames to their metadata, typically generated via ffprobe.
68
+ This enables accurate frame seeking without requiring a full video scan.
69
+ Do not set seek_mode when custom_frame_mappings is provided.
70
+ Expected JSON format:
71
+
72
+ .. code-block:: json
73
+
74
+ {
75
+ "frames": [
76
+ {
77
+ "pts": 0,
78
+ "duration": 1001,
79
+ "key_frame": 1
80
+ }
81
+ ]
82
+ }
83
+
84
+ Alternative field names "pkt_pts" and "pkt_duration" are also supported.
66
85
67
86
Attributes:
68
87
metadata (VideoStreamMetadata): Metadata of the video stream.
@@ -80,6 +99,9 @@ def __init__(
80
99
num_ffmpeg_threads : int = 1 ,
81
100
device : Optional [Union [str , torch_device ]] = "cpu" ,
82
101
seek_mode : Literal ["exact" , "approximate" ] = "exact" ,
102
+ custom_frame_mappings : Optional [
103
+ Union [str , bytes , io .RawIOBase , io .BufferedReader ]
104
+ ] = None ,
83
105
):
84
106
torch ._C ._log_api_usage_once ("torchcodec.decoders.VideoDecoder" )
85
107
allowed_seek_modes = ("exact" , "approximate" )
@@ -89,6 +111,21 @@ def __init__(
89
111
f"Supported values are { ', ' .join (allowed_seek_modes )} ."
90
112
)
91
113
114
+ # Validate seek_mode and custom_frame_mappings are not mismatched
115
+ if custom_frame_mappings is not None and seek_mode == "approximate" :
116
+ raise ValueError (
117
+ "custom_frame_mappings is incompatible with seek_mode='approximate'. "
118
+ "Use seek_mode='custom_frame_mappings' or leave it unspecified to automatically use custom frame mappings."
119
+ )
120
+
121
+ # Auto-select custom_frame_mappings seek_mode and process data when mappings are provided
122
+ custom_frame_mappings_data = None
123
+ if custom_frame_mappings is not None :
124
+ seek_mode = "custom_frame_mappings" # type: ignore[assignment]
125
+ custom_frame_mappings_data = _read_custom_frame_mappings (
126
+ custom_frame_mappings
127
+ )
128
+
92
129
self ._decoder = create_decoder (source = source , seek_mode = seek_mode )
93
130
94
131
allowed_dimension_orders = ("NCHW" , "NHWC" )
@@ -110,6 +147,7 @@ def __init__(
110
147
dimension_order = dimension_order ,
111
148
num_threads = num_ffmpeg_threads ,
112
149
device = device ,
150
+ custom_frame_mappings = custom_frame_mappings_data ,
113
151
)
114
152
115
153
(
@@ -379,3 +417,57 @@ def _get_and_validate_stream_metadata(
379
417
end_stream_seconds ,
380
418
num_frames ,
381
419
)
420
+
421
+
422
+ def _read_custom_frame_mappings (
423
+ custom_frame_mappings : Union [str , bytes , io .RawIOBase , io .BufferedReader ]
424
+ ) -> tuple [Tensor , Tensor , Tensor ]:
425
+ """Parse custom frame mappings from JSON data and extract frame metadata.
426
+
427
+ Args:
428
+ custom_frame_mappings: JSON data containing frame metadata, provided as:
429
+ - A JSON string (str, bytes)
430
+ - A file-like object with a read() method
431
+
432
+ Returns:
433
+ A tuple of three tensors:
434
+ - all_frames (Tensor): Presentation timestamps (PTS) for each frame
435
+ - is_key_frame (Tensor): Boolean tensor indicating which frames are key frames
436
+ - duration (Tensor): Duration of each frame
437
+ """
438
+ try :
439
+ input_data = (
440
+ json .load (custom_frame_mappings )
441
+ if hasattr (custom_frame_mappings , "read" )
442
+ else json .loads (custom_frame_mappings )
443
+ )
444
+ except json .JSONDecodeError as e :
445
+ raise ValueError (
446
+ f"Invalid custom frame mappings: { e } . It should be a valid JSON string or a file-like object."
447
+ ) from e
448
+
449
+ if not input_data or "frames" not in input_data :
450
+ raise ValueError (
451
+ "Invalid custom frame mappings. The input is empty or missing the required 'frames' key."
452
+ )
453
+
454
+ first_frame = input_data ["frames" ][0 ]
455
+ pts_key = next ((key for key in ("pts" , "pkt_pts" ) if key in first_frame ), None )
456
+ duration_key = next (
457
+ (key for key in ("duration" , "pkt_duration" ) if key in first_frame ), None
458
+ )
459
+ key_frame_present = "key_frame" in first_frame
460
+
461
+ if not pts_key or not duration_key or not key_frame_present :
462
+ raise ValueError (
463
+ "Invalid custom frame mappings. The 'pts'/'pkt_pts', 'duration'/'pkt_duration', and 'key_frame' keys are required in the frame metadata."
464
+ )
465
+
466
+ frame_data = [
467
+ (float (frame [pts_key ]), frame ["key_frame" ], float (frame [duration_key ]))
468
+ for frame in input_data ["frames" ]
469
+ ]
470
+ all_frames , is_key_frame , duration = map (torch .tensor , zip (* frame_data ))
471
+ if not (len (all_frames ) == len (is_key_frame ) == len (duration )):
472
+ raise ValueError ("Mismatched lengths in frame index data" )
473
+ return all_frames , is_key_frame , duration
0 commit comments