|
40 | 40 | assert_tensor_close_on_at_least, |
41 | 41 | assert_tensor_equal, |
42 | 42 | cpu_and_cuda, |
43 | | - get_tensor_compare_function, |
| 43 | + get_frame_compare_function, |
44 | 44 | NASA_AUDIO, |
45 | 45 | NASA_VIDEO, |
46 | 46 | needs_cuda, |
@@ -133,7 +133,7 @@ def test_get_frame_with_info_at_index(self): |
133 | 133 |
|
134 | 134 | @pytest.mark.parametrize("device", cpu_and_cuda()) |
135 | 135 | def test_get_frames_at_indices(self, device): |
136 | | - tensor_compare_function = get_tensor_compare_function(device) |
| 136 | + tensor_compare_function = get_frame_compare_function(device) |
137 | 137 | decoder = create_from_file(str(NASA_VIDEO.path)) |
138 | 138 | scan_all_streams_to_update_metadata(decoder) |
139 | 139 | add_video_stream(decoder, device=device) |
@@ -210,37 +210,6 @@ def test_get_frames_by_pts(self, device): |
210 | 210 | with pytest.raises(AssertionError): |
211 | 211 | assert_tensor_equal(frames[0], frames[-1]) |
212 | 212 |
|
213 | | - @pytest.mark.parametrize("device", cpu_and_cuda()) |
214 | | - def test_get_frames_by_pts_with_cuda(self, device): |
215 | | - decoder = create_from_file(str(NASA_VIDEO.path)) |
216 | | - _add_video_stream(decoder, device=device) |
217 | | - scan_all_streams_to_update_metadata(decoder) |
218 | | - stream_index = 3 |
219 | | - |
220 | | - # Note: 13.01 should give the last video frame for the NASA video |
221 | | - timestamps = [2, 0, 1, 0 + 1e-3, 13.01, 2 + 1e-3] |
222 | | - |
223 | | - expected_frames = [ |
224 | | - get_frame_at_pts(decoder, seconds=pts)[0] for pts in timestamps |
225 | | - ] |
226 | | - |
227 | | - frames, *_ = get_frames_by_pts( |
228 | | - decoder, |
229 | | - stream_index=stream_index, |
230 | | - timestamps=timestamps, |
231 | | - ) |
232 | | - for frame, expected_frame in zip(frames, expected_frames): |
233 | | - assert_tensor_equal(frame, expected_frame) |
234 | | - |
235 | | - # first and last frame should be equal, at pts=2 [+ eps]. We then modify |
236 | | - # the first frame and assert that it's now different from the last |
237 | | - # frame. This ensures a copy was properly made during the de-duplication |
238 | | - # logic. |
239 | | - assert_tensor_equal(frames[0], frames[-1]) |
240 | | - frames[0] += 20 |
241 | | - with pytest.raises(AssertionError): |
242 | | - assert_tensor_equal(frames[0], frames[-1]) |
243 | | - |
244 | 213 | @pytest.mark.parametrize("device", cpu_and_cuda()) |
245 | 214 | def test_pts_apis_against_index_ref(self, device): |
246 | 215 | # Non-regression test for https://github.com/pytorch/torchcodec/pull/287 |
@@ -304,7 +273,7 @@ def test_pts_apis_against_index_ref(self, device): |
304 | 273 |
|
305 | 274 | @pytest.mark.parametrize("device", cpu_and_cuda()) |
306 | 275 | def test_get_frames_in_range(self, device): |
307 | | - tensor_compare_function = get_tensor_compare_function(device) |
| 276 | + tensor_compare_function = get_frame_compare_function(device) |
308 | 277 | decoder = create_from_file(str(NASA_VIDEO.path)) |
309 | 278 | scan_all_streams_to_update_metadata(decoder) |
310 | 279 | add_video_stream(decoder, device=device) |
|
0 commit comments