|
21 | 21 | _test_frame_pts_equality, |
22 | 22 | add_audio_stream, |
23 | 23 | add_video_stream, |
| 24 | + create_encoder, |
24 | 25 | create_from_bytes, |
25 | 26 | create_from_file, |
26 | 27 | create_from_file_like, |
27 | 28 | create_from_tensor, |
| 29 | + encode, |
28 | 30 | get_ffmpeg_library_versions, |
29 | 31 | get_frame_at_index, |
30 | 32 | get_frame_at_pts, |
|
48 | 50 | SINE_MONO_S32, |
49 | 51 | SINE_MONO_S32_44100, |
50 | 52 | SINE_MONO_S32_8000, |
| 53 | + TestContainerFile, |
51 | 54 | ) |
52 | 55 |
|
53 | 56 | torch._dynamo.config.capture_dynamic_output_shape_ops = True |
54 | 57 |
|
55 | 58 | INDEX_OF_FRAME_AT_6_SECONDS = 180 |
56 | 59 |
|
57 | 60 |
|
58 | | -class TestVideoOps: |
| 61 | +class TestVideoDecoderOps: |
59 | 62 | @pytest.mark.parametrize("device", cpu_and_cuda()) |
60 | 63 | def test_seek_and_next(self, device): |
61 | 64 | decoder = create_from_file(str(NASA_VIDEO.path)) |
@@ -632,7 +635,7 @@ def test_cuda_decoder(self): |
632 | 635 | ) |
633 | 636 |
|
634 | 637 |
|
635 | | -class TestAudioOps: |
| 638 | +class TestAudioDecoderOps: |
636 | 639 | @pytest.mark.parametrize( |
637 | 640 | "method", |
638 | 641 | ( |
@@ -923,5 +926,33 @@ def get_all_frames(asset, sample_rate=None, stop_seconds=None): |
923 | 926 | torch.testing.assert_close(frames_downsampled_to_8000, frames_8000_native) |
924 | 927 |
|
925 | 928 |
|
class TestAudioEncoderOps:
    """Round-trip tests for the audio encoding ops (create_encoder / encode)."""

    def decode(self, source) -> torch.Tensor:
        """Decode the audio samples of ``source`` into a tensor.

        Args:
            source: Either a ``TestContainerFile`` asset, in which case its
                ``path`` attribute is used, or any object whose ``str()`` is
                a path to an audio file.

        Returns:
            The first element returned by
            ``get_frames_by_pts_in_range_audio`` when called with
            ``start_seconds=0`` and ``stop_seconds=None``.
        """
        path = source.path if isinstance(source, TestContainerFile) else source
        decoder = create_from_file(str(path), seek_mode="approximate")
        add_audio_stream(decoder)
        frames, *_ = get_frames_by_pts_in_range_audio(
            decoder, start_seconds=0, stop_seconds=None
        )
        return frames

    def test_round_trip(self, tmp_path):
        """Check that decoding freshly encoded samples gives back the source."""
        asset = SINE_MONO_S32
        source_samples = self.decode(asset)

        output_file = tmp_path / "output.mp3"
        encoder = create_encoder(
            sample_rate=asset.sample_rate, filename=str(output_file)
        )
        encode(encoder, source_samples)

        round_trip_samples = self.decode(output_file)
        # assert_close takes (actual, expected): the re-decoded samples are
        # the value under test and the source samples are the reference, so
        # failure messages and the relative tolerance refer to the right side.
        # NOTE(review): the output is written as .mp3, presumably a lossy
        # encode; default tolerances may be too strict -- confirm.
        torch.testing.assert_close(round_trip_samples, source_samples)
| 956 | + |
926 | 957 | if __name__ == "__main__": |
927 | 958 | pytest.main() |
0 commit comments