Skip to content

Commit e534d5d

Browse files
committed
.
1 parent e448666 commit e534d5d

File tree

2 files changed

+8
-39
lines changed

2 files changed

+8
-39
lines changed

test/decoders/test_video_decoder_ops.py

Lines changed: 3 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
assert_tensor_close_on_at_least,
4141
assert_tensor_equal,
4242
cpu_and_cuda,
43-
get_tensor_compare_function,
43+
get_frame_compare_function,
4444
NASA_AUDIO,
4545
NASA_VIDEO,
4646
needs_cuda,
@@ -133,7 +133,7 @@ def test_get_frame_with_info_at_index(self):
133133

134134
@pytest.mark.parametrize("device", cpu_and_cuda())
135135
def test_get_frames_at_indices(self, device):
136-
tensor_compare_function = get_tensor_compare_function(device)
136+
tensor_compare_function = get_frame_compare_function(device)
137137
decoder = create_from_file(str(NASA_VIDEO.path))
138138
scan_all_streams_to_update_metadata(decoder)
139139
add_video_stream(decoder, device=device)
@@ -210,37 +210,6 @@ def test_get_frames_by_pts(self, device):
210210
with pytest.raises(AssertionError):
211211
assert_tensor_equal(frames[0], frames[-1])
212212

213-
@pytest.mark.parametrize("device", cpu_and_cuda())
214-
def test_get_frames_by_pts_with_cuda(self, device):
215-
decoder = create_from_file(str(NASA_VIDEO.path))
216-
_add_video_stream(decoder, device=device)
217-
scan_all_streams_to_update_metadata(decoder)
218-
stream_index = 3
219-
220-
# Note: 13.01 should give the last video frame for the NASA video
221-
timestamps = [2, 0, 1, 0 + 1e-3, 13.01, 2 + 1e-3]
222-
223-
expected_frames = [
224-
get_frame_at_pts(decoder, seconds=pts)[0] for pts in timestamps
225-
]
226-
227-
frames, *_ = get_frames_by_pts(
228-
decoder,
229-
stream_index=stream_index,
230-
timestamps=timestamps,
231-
)
232-
for frame, expected_frame in zip(frames, expected_frames):
233-
assert_tensor_equal(frame, expected_frame)
234-
235-
# first and last frame should be equal, at pts=2 [+ eps]. We then modify
236-
# the first frame and assert that it's now different from the last
237-
# frame. This ensures a copy was properly made during the de-duplication
238-
# logic.
239-
assert_tensor_equal(frames[0], frames[-1])
240-
frames[0] += 20
241-
with pytest.raises(AssertionError):
242-
assert_tensor_equal(frames[0], frames[-1])
243-
244213
@pytest.mark.parametrize("device", cpu_and_cuda())
245214
def test_pts_apis_against_index_ref(self, device):
246215
# Non-regression test for https://github.com/pytorch/torchcodec/pull/287
@@ -304,7 +273,7 @@ def test_pts_apis_against_index_ref(self, device):
304273

305274
@pytest.mark.parametrize("device", cpu_and_cuda())
306275
def test_get_frames_in_range(self, device):
307-
tensor_compare_function = get_tensor_compare_function(device)
276+
tensor_compare_function = get_frame_compare_function(device)
308277
decoder = create_from_file(str(NASA_VIDEO.path))
309278
scan_all_streams_to_update_metadata(decoder)
310279
add_video_stream(decoder, device=device)

test/utils.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def cpu_and_cuda():
2323
return ("cpu", pytest.param("cuda", marks=pytest.mark.needs_cuda))
2424

2525

26-
def get_tensor_compare_function(device):
26+
def get_frame_compare_function(device):
2727
if device == "cpu":
2828
return assert_tensor_equal
2929
else:
@@ -43,10 +43,10 @@ def assert_tensor_equal(*args, **kwargs):
4343

4444

4545
# Asserts that at least `percentage`% of the values are within the absolute tolerance.
46-
def assert_tensor_close_on_at_least(frame1, frame2, percentage=90, abs_tolerance=20):
47-
frame1 = frame1.to("cpu")
48-
frame2 = frame2.to("cpu")
49-
diff = (frame2.float() - frame1.float()).abs()
46+
def assert_tensor_close_on_at_least(tensor1, tensor2, percentage=90, abs_tolerance=20):
47+
tensor1 = tensor1.to("cpu")
48+
tensor2 = tensor2.to("cpu")
49+
diff = (tensor2.float() - tensor1.float()).abs()
5050
max_diff_percentage = 100.0 - percentage
5151
if diff.sum() == 0:
5252
return

0 commit comments

Comments
 (0)