|
4 | 4 | # This source code is licensed under the BSD-style license found in the |
5 | 5 | # LICENSE file in the root directory of this source tree. |
6 | 6 |
|
7 | | -import contextlib |
8 | | - |
9 | 7 | import numpy |
10 | 8 | import pytest |
11 | 9 | import torch |
@@ -875,39 +873,3 @@ def test_get_key_frame_indices(self, device): |
875 | 873 | torch.testing.assert_close( |
876 | 874 | key_frame_indices, h265_reference_key_frame_indices, atol=0, rtol=0 |
877 | 875 | ) |
878 | | - |
879 | | - @pytest.mark.parametrize("device", cpu_and_cuda()) |
880 | | - def test_compile(self, device): |
881 | | - decoder = VideoDecoder(NASA_VIDEO.path, device=device) |
882 | | - |
883 | | - @contextlib.contextmanager |
884 | | - def restore_capture_scalar_outputs(): |
885 | | - try: |
886 | | - original = torch._dynamo.config.capture_scalar_outputs |
887 | | - yield |
888 | | - finally: |
889 | | - torch._dynamo.config.capture_scalar_outputs = original |
890 | | - |
891 | | - # TODO: We get a graph break because we call Tensor.item() to turn the |
892 | | - # tensors in FrameBatch into scalars. When we work on compilation and exportability, |
893 | | - # we should investigate. |
894 | | - with restore_capture_scalar_outputs(): |
895 | | - torch._dynamo.config.capture_scalar_outputs = True |
896 | | - |
897 | | - @torch.compile(fullgraph=True, backend="eager") |
898 | | - def get_some_frames(decoder): |
899 | | - frames = [] |
900 | | - frames.append(decoder.get_frame_at(1)) |
901 | | - frames.append(decoder.get_frame_at(3)) |
902 | | - frames.append(decoder.get_frame_at(5)) |
903 | | - return frames |
904 | | - |
905 | | - frames = get_some_frames(decoder) |
906 | | - |
907 | | - ref_frame1 = NASA_VIDEO.get_frame_data_by_index(1).to(device) |
908 | | - ref_frame3 = NASA_VIDEO.get_frame_data_by_index(3).to(device) |
909 | | - ref_frame5 = NASA_VIDEO.get_frame_data_by_index(5).to(device) |
910 | | - |
911 | | - assert_frames_equal(ref_frame1, frames[0].data) |
912 | | - assert_frames_equal(ref_frame3, frames[1].data) |
913 | | - assert_frames_equal(ref_frame5, frames[2].data) |
0 commit comments