|
27 | 27 | get_frame_at_index, |
28 | 28 | get_frame_at_pts, |
29 | 29 | get_frames_at_indices, |
| 30 | + get_frames_by_pts_in_range, |
30 | 31 | get_frames_in_range, |
31 | 32 | get_json_metadata, |
32 | 33 | get_next_frame, |
@@ -383,6 +384,48 @@ def test_color_conversion_library_with_scaling( |
383 | 384 | swscale_frame0, _, _ = get_next_frame(swscale_decoder) |
384 | 385 | assert_tensor_equal(filtergraph_frame0, swscale_frame0) |
385 | 386 |
|
| 387 | + @pytest.mark.parametrize("dimension_order", ("NHWC", "NCHW")) |
| 388 | + @pytest.mark.parametrize("color_conversion_library", ("filtergraph", "swscale")) |
| 389 | + def test_color_conversion_library_with_dimension_order( |
| 390 | + self, dimension_order, color_conversion_library |
| 391 | + ): |
| 392 | + decoder = create_from_file(str(NASA_VIDEO.path)) |
| 393 | + _add_video_stream( |
| 394 | + decoder, |
| 395 | + color_conversion_library=color_conversion_library, |
| 396 | + dimension_order=dimension_order, |
| 397 | + ) |
| 398 | + scan_all_streams_to_update_metadata(decoder) |
| 399 | + |
| 400 | + frame0_ref = NASA_VIDEO.get_frame_data_by_index(0) |
| 401 | + C, H, W = frame0_ref.shape |
| 402 | + if dimension_order == "NHWC": |
| 403 | + frame0_ref = frame0_ref.permute(1, 2, 0) |
| 404 | + expected_shape = frame0_ref.shape |
| 405 | + |
| 406 | + stream_index = 3 |
| 407 | + frame0, *_ = get_frame_at_index( |
| 408 | + decoder, stream_index=stream_index, frame_index=0 |
| 409 | + ) |
| 410 | + assert frame0.shape == expected_shape |
| 411 | + assert_tensor_equal(frame0, frame0_ref) |
| 412 | + |
| 413 | + frame0, *_ = get_frame_at_pts(decoder, seconds=0.0) |
| 414 | + assert frame0.shape == expected_shape |
| 415 | + assert_tensor_equal(frame0, frame0_ref) |
| 416 | + |
| 417 | + frames, *_ = get_frames_in_range( |
| 418 | + decoder, stream_index=stream_index, start=0, stop=3 |
| 419 | + ) |
| 420 | + assert frames.shape[1:] == expected_shape |
| 421 | + assert_tensor_equal(frames[0], frame0_ref) |
| 422 | + |
| 423 | + frames, *_ = get_frames_by_pts_in_range( |
| 424 | + decoder, stream_index=stream_index, start_seconds=0, stop_seconds=1 |
| 425 | + ) |
| 426 | + assert frames.shape[1:] == expected_shape |
| 427 | + assert_tensor_equal(frames[0], frame0_ref) |
| 428 | + |
386 | 429 | @pytest.mark.parametrize( |
387 | 430 | "width_scaling_factor,height_scaling_factor", |
388 | 431 | ((1.31, 1.5), (0.71, 0.5), (1.31, 0.7), (0.71, 1.5), (1.0, 1.0)), |
|
0 commit comments