Skip to content

Commit d87f09e

Browse files
committed
Avoid using hard-coded values in test
1 parent 814d75f commit d87f09e

File tree

2 files changed

+36
-18
lines changed

2 files changed

+36
-18
lines changed

test/decoders/test_video_decoder.py

Lines changed: 34 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -302,9 +302,12 @@ def test_get_frame_at(self):
302302

303303
assert_tensor_equal(ref_frame9, frame9.data)
304304
assert isinstance(frame9.pts_seconds, float)
305-
assert frame9.pts_seconds == pytest.approx(0.3003)
305+
expected_frame_info = NASA_VIDEO.get_frame_info(9)
306+
assert frame9.pts_seconds == pytest.approx(expected_frame_info.pts_seconds)
306307
assert isinstance(frame9.duration_seconds, float)
307-
assert frame9.duration_seconds == pytest.approx(0.03337, rel=1e-3)
308+
assert frame9.duration_seconds == pytest.approx(
309+
expected_frame_info.duration_seconds, rel=1e-3
310+
)
308311

309312
# test numpy.int64
310313
frame9 = decoder.get_frame_at(numpy.int64(9))
@@ -344,22 +347,31 @@ def test_get_frame_at_fails(self):
344347
def test_get_frames_at(self):
345348
decoder = VideoDecoder(NASA_VIDEO.path)
346349

347-
indices = [35, 25]
348-
frames = decoder.get_frames_at(indices)
350+
frames = decoder.get_frames_at([35, 25])
349351

350352
assert isinstance(frames, FrameBatch)
351353

352-
for i in range(len(indices)):
353-
assert_tensor_equal(
354-
frames[i].data, NASA_VIDEO.get_frame_data_by_index(indices[i])
355-
)
354+
assert_tensor_equal(frames[0].data, NASA_VIDEO.get_frame_data_by_index(35))
355+
assert_tensor_equal(frames[1].data, NASA_VIDEO.get_frame_data_by_index(25))
356356

357-
expected_pts_seconds = torch.tensor([1.1678, 0.8342], dtype=torch.float64)
357+
expected_pts_seconds = torch.tensor(
358+
[
359+
NASA_VIDEO.get_frame_info(35).pts_seconds,
360+
NASA_VIDEO.get_frame_info(25).pts_seconds,
361+
],
362+
dtype=torch.float64,
363+
)
358364
torch.testing.assert_close(
359365
frames.pts_seconds, expected_pts_seconds, atol=1e-4, rtol=0
360366
)
361367

362-
expected_duration_seconds = torch.tensor([0.0334, 0.0334], dtype=torch.float64)
368+
expected_duration_seconds = torch.tensor(
369+
[
370+
NASA_VIDEO.get_frame_info(35).duration_seconds,
371+
NASA_VIDEO.get_frame_info(25).duration_seconds,
372+
],
373+
dtype=torch.float64,
374+
)
363375
torch.testing.assert_close(
364376
frames.duration_seconds, expected_duration_seconds, atol=1e-4, rtol=0
365377
)
@@ -404,27 +416,31 @@ def test_get_frame_displayed_at_fails(self):
404416
def test_get_frames_displayed_at(self):
405417

406418
decoder = VideoDecoder(NASA_VIDEO.path)
407-
ref_frame6 = NASA_VIDEO.get_frame_by_name("time6.000000")
408-
ref_frame10 = NASA_VIDEO.get_frame_by_name("time10.000000")
409419

410-
seconds = [6.02, 10.01, 6.01]
420+
# Note: We know the frame at ~0.84s has index 25, the one at 1.16s has
421+
# index 35. We use those indices as reference to test against.
422+
seconds = [0.84, 1.17, 0.85]
423+
reference_indices = [25, 35, 25]
411424
frames = decoder.get_frames_displayed_at(seconds)
412425

413426
assert isinstance(frames, FrameBatch)
414427

415-
assert_tensor_equal(frames.data[0], ref_frame6)
416-
assert_tensor_equal(frames.data[1], ref_frame10)
417-
assert_tensor_equal(frames.data[2], ref_frame6)
428+
for i in range(len(reference_indices)):
429+
assert_tensor_equal(
430+
frames.data[i], NASA_VIDEO.get_frame_data_by_index(reference_indices[i])
431+
)
418432

419433
expected_pts_seconds = torch.tensor(
420-
[6.0060, 10.0100, 6.0060], dtype=torch.float64
434+
[NASA_VIDEO.get_frame_info(i).pts_seconds for i in reference_indices],
435+
dtype=torch.float64,
421436
)
422437
torch.testing.assert_close(
423438
frames.pts_seconds, expected_pts_seconds, atol=1e-4, rtol=0
424439
)
425440

426441
expected_duration_seconds = torch.tensor(
427-
[0.0334, 0.0334, 0.0334], dtype=torch.float64
442+
[NASA_VIDEO.get_frame_info(i).duration_seconds for i in reference_indices],
443+
dtype=torch.float64,
428444
)
429445
torch.testing.assert_close(
430446
frames.duration_seconds, expected_duration_seconds, atol=1e-4, rtol=0

test/utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,8 @@ def get_empty_chw_tensor(self, *, stream_index: int) -> torch.Tensor:
265265
8: TestFrameInfo(pts_seconds=0.266933, duration_seconds=0.033367),
266266
9: TestFrameInfo(pts_seconds=0.300300, duration_seconds=0.033367),
267267
10: TestFrameInfo(pts_seconds=0.333667, duration_seconds=0.033367),
268+
25: TestFrameInfo(pts_seconds=0.8342, duration_seconds=0.033367),
269+
35: TestFrameInfo(pts_seconds=1.1678, duration_seconds=0.033367),
268270
},
269271
},
270272
)

0 commit comments

Comments
 (0)