|
4 | 4 | # This source code is licensed under the BSD-style license found in the |
5 | 5 | # LICENSE file in the root directory of this source tree. |
6 | 6 |
|
7 | | -import os |
8 | 7 | import json |
9 | | -import pathlib |
| 8 | + |
10 | 9 | import pytest |
11 | | -import numpy as np |
12 | 10 | import torch |
| 11 | +from torchcodec import Frame, FrameBatch |
13 | 12 |
|
14 | 13 | from torchcodec._core import ( |
15 | | - create_from_file, |
16 | 14 | add_video_stream, |
17 | | - get_next_frame, |
18 | | - get_frame_at_index, |
| 15 | + create_from_file, |
19 | 16 | get_frame_at_pts, |
20 | | - get_frames_in_range, |
21 | 17 | get_json_metadata, |
| 18 | + get_next_frame, |
22 | 19 | seek_to_pts, |
23 | 20 | ) |
24 | 21 |
|
25 | | -from torchcodec.decoders import VideoDecoder,VideoStreamMetadata |
26 | | -from torchcodec import Frame, FrameBatch |
| 22 | +from torchcodec.decoders import VideoDecoder, VideoStreamMetadata |
27 | 23 |
|
28 | | -from .utils import ( |
29 | | - assert_frames_equal, |
30 | | - cpu_and_cuda, |
31 | | - TestVideo, |
32 | | - TestVideoStreamInfo, |
33 | | - TestFrameInfo, |
34 | | - _get_file_path, |
35 | | - VAR_FPS_VIDEO, |
36 | | -) |
| 24 | +from .utils import cpu_and_cuda, VAR_FPS_VIDEO |
37 | 25 |
|
class TestVariableFPSVideoDecoder:
    """Decoding tests that run against the ``VAR_FPS_VIDEO`` test asset.

    Every test is parametrized over ``cpu_and_cuda()`` and decodes on the
    device it is handed.
    """

    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_basic_decoding(self, device):
        """Decode the first frame and sanity-check the stream metadata."""
        # Forward the parametrized device; without it every parametrization
        # would exercise the decoder's default device only.
        decoder = VideoDecoder(str(VAR_FPS_VIDEO.path), device=device)

        frame = decoder.get_frame_at(0)
        assert isinstance(frame, Frame)

        metadata = decoder.metadata
        assert isinstance(metadata, VideoStreamMetadata)
        assert metadata.num_frames > 30

    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_exact_seeking_mode(self, device):
        """In exact mode, a frame played at time ``t`` starts within one
        frame duration of ``t``."""
        decoder = VideoDecoder(
            str(VAR_FPS_VIDEO.path), device=device, seek_mode="exact"
        )
        video_duration = decoder.metadata.duration_seconds

        for timestamp in (0.0, 0.5, 1.0, 1.5, 2.0):
            # Timestamps at or past the end of the video are skipped.
            if timestamp >= video_duration:
                continue

            frame_batch = decoder.get_frames_played_at(seconds=[timestamp])
            assert isinstance(frame_batch, FrameBatch)
            assert (
                abs(frame_batch.pts_seconds[0] - timestamp)
                <= frame_batch.duration_seconds[0]
            )

    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_approximate_seeking_mode_behavior(self, device):
        """Test behavior in approximate seeking mode (may fail as expected).

        Seeks to the same timestamps with an exact-mode and an
        approximate-mode decoder and prints how the results differ. This is
        diagnostic only: nothing is asserted about the approximate results.
        """
        # Create two decoders: one with exact mode, one with approximate mode
        decoder_exact = create_from_file(str(VAR_FPS_VIDEO.path), seek_mode="exact")
        add_video_stream(decoder_exact, device=device)

        decoder_approx = create_from_file(
            str(VAR_FPS_VIDEO.path), seek_mode="approximate"
        )
        add_video_stream(decoder_approx, device=device)

        metadata_dict = json.loads(get_json_metadata(decoder_exact))
        video_duration = metadata_dict["durationSeconds"]

        # Compare seeking in both modes
        differences = []
        for pts in (0.0, 0.5, 1.0, 1.5, 2.0):
            if pts >= video_duration:
                continue

            frame_exact, pts_exact, _ = get_frame_at_pts(decoder_exact, pts)

            # Only the approximate-mode seek is allowed to fail. The
            # comparison below sits in the ``else`` branch so that an error
            # while comparing is not misreported as a seek failure.
            try:
                frame_approx, pts_approx, _ = get_frame_at_pts(decoder_approx, pts)
            except Exception as e:
                differences.append(
                    {
                        "seek_pts": pts,
                        "exact_pts": pts_exact.item(),
                        "error": str(e),
                        "approximate_failed": True,
                    }
                )
            else:
                differences.append(
                    {
                        "seek_pts": pts,
                        "exact_pts": pts_exact.item(),
                        "approx_pts": pts_approx.item(),
                        "frames_match": torch.allclose(frame_exact, frame_approx),
                        "pts_difference": abs(pts_exact.item() - pts_approx.item()),
                        "approximate_failed": False,
                    }
                )

        # Print differences (useful for debugging)
        for diff in differences:
            if diff["approximate_failed"]:
                print(
                    f"Seeking to {diff['seek_pts']}s failed in approximate mode: {diff['error']}"
                )
            else:
                print(
                    f"Seeking to {diff['seek_pts']}s: exact={diff['exact_pts']}, "
                    f"approx={diff['approx_pts']}, diff={diff['pts_difference']}, "
                    f"frames {'match' if diff['frames_match'] else 'differ'}"
                )

        # No assertion as approximate mode is expected to potentially fail

    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_frame_timing_pattern(self, device):
        """Check that frame spacing changes after the first 30 frames.

        The average pts interval of the first 30 frames is expected to be
        about half the average interval of the following frames (up to 30).
        """
        decoder = create_from_file(str(VAR_FPS_VIDEO.path))
        add_video_stream(decoder, device=device)

        seek_to_pts(decoder, 0.0)

        frames_info = []
        while True:
            try:
                _, pts, duration = get_next_frame(decoder)
            except IndexError:
                # IndexError is treated as the end-of-stream signal.
                break
            frames_info.append(
                {
                    "index": len(frames_info),
                    "pts": pts.item(),
                    "duration": duration.item(),
                }
            )

        assert len(frames_info) > 30, "Not enough frames to verify variable frame rate"

        # Gap between each pair of consecutive frames, in decode order.
        intervals = [
            later["pts"] - earlier["pts"]
            for earlier, later in zip(frames_info, frames_info[1:])
        ]
        intervals_before = intervals[:30]
        intervals_after = intervals[30:60]

        # The ratio is only checked when enough later frames exist to make
        # the second average meaningful.
        if len(intervals_after) > 5:
            avg_interval_before = sum(intervals_before) / len(intervals_before)
            avg_interval_after = sum(intervals_after) / len(intervals_after)

            print(f"Average interval for first 30 frames: {avg_interval_before:.6f}s")
            print(f"Average interval for subsequent frames: {avg_interval_after:.6f}s")

            expected_ratio = 0.5
            actual_ratio = avg_interval_before / avg_interval_after

            # Allow for some error
            assert (
                abs(actual_ratio - expected_ratio) < 0.2
            ), f"Interval ratio ({actual_ratio:.2f}) differs too much from expected ({expected_ratio})"

    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_sequential_decoding(self, device):
        """Sequentially decoded frames have strictly increasing timestamps."""
        decoder = create_from_file(str(VAR_FPS_VIDEO.path))
        add_video_stream(decoder, device=device)

        seek_to_pts(decoder, 0.0)

        # Decode multiple frames and verify monotonically increasing timestamps
        last_pts = -1.0
        for _ in range(50):
            try:
                _, pts, _ = get_next_frame(decoder)
            except IndexError:
                # Fewer than 50 frames available; stop early.
                break

            current_pts = pts.item()
            assert (
                current_pts > last_pts
            ), f"Frame timestamps not monotonically increasing: current={current_pts}, previous={last_pts}"
            last_pts = current_pts

    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_frames_in_range(self, device):
        """``get_frames_in_range(0, 10)`` returns 10 frames in increasing
        pts order."""
        decoder = VideoDecoder(str(VAR_FPS_VIDEO.path), device=device)

        frame_batch = decoder.get_frames_in_range(0, 10)
        assert isinstance(frame_batch, FrameBatch)
        assert len(frame_batch) == 10

        timestamps = frame_batch.pts_seconds.tolist()
        assert all(
            earlier < later for earlier, later in zip(timestamps, timestamps[1:])
        )
0 commit comments