Skip to content

Commit 45fd0ec

Browse files
committed
Add (failing) round-trip test
1 parent 52d1753 commit 45fd0ec

File tree

2 files changed

+35
-2
lines changed

2 files changed

+35
-2
lines changed

test/decoders/test_ops.py

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,12 @@
2121
_test_frame_pts_equality,
2222
add_audio_stream,
2323
add_video_stream,
24+
create_encoder,
2425
create_from_bytes,
2526
create_from_file,
2627
create_from_file_like,
2728
create_from_tensor,
29+
encode,
2830
get_ffmpeg_library_versions,
2931
get_frame_at_index,
3032
get_frame_at_pts,
@@ -48,14 +50,15 @@
4850
SINE_MONO_S32,
4951
SINE_MONO_S32_44100,
5052
SINE_MONO_S32_8000,
53+
TestContainerFile,
5154
)
5255

5356
torch._dynamo.config.capture_dynamic_output_shape_ops = True
5457

5558
INDEX_OF_FRAME_AT_6_SECONDS = 180
5659

5760

58-
class TestVideoOps:
61+
class TestVideoDecoderOps:
5962
@pytest.mark.parametrize("device", cpu_and_cuda())
6063
def test_seek_and_next(self, device):
6164
decoder = create_from_file(str(NASA_VIDEO.path))
@@ -632,7 +635,7 @@ def test_cuda_decoder(self):
632635
)
633636

634637

635-
class TestAudioOps:
638+
class TestAudioDecoderOps:
636639
@pytest.mark.parametrize(
637640
"method",
638641
(
@@ -923,5 +926,33 @@ def get_all_frames(asset, sample_rate=None, stop_seconds=None):
923926
torch.testing.assert_close(frames_downsampled_to_8000, frames_8000_native)
924927

925928

929+
class TestAudioEncoderOps:
930+
931+
def decode(self, source) -> torch.Tensor:
932+
if isinstance(source, TestContainerFile):
933+
source = str(source.path)
934+
else:
935+
source = str(source)
936+
decoder = create_from_file(source, seek_mode="approximate")
937+
add_audio_stream(decoder)
938+
frames, *_ = get_frames_by_pts_in_range_audio(
939+
decoder, start_seconds=0, stop_seconds=None
940+
)
941+
return frames
942+
943+
def test_round_trip(self, tmp_path):
944+
asset = SINE_MONO_S32
945+
source_samples = self.decode(asset)
946+
947+
output_file = tmp_path / "output.mp3"
948+
encoder = create_encoder(
949+
sample_rate=asset.sample_rate, filename=str(output_file)
950+
)
951+
encode(encoder, source_samples)
952+
953+
round_trip_samples = self.decode(output_file)
954+
torch.testing.assert_close(source_samples, round_trip_samples)
955+
956+
926957
if __name__ == "__main__":
927958
pytest.main()

test/utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,8 @@ class TestAudioStreamInfo:
114114

115115
@dataclass
116116
class TestContainerFile:
117+
__test__ = False # prevents pytest from thinking this is a test class
118+
117119
filename: str
118120

119121
default_stream_index: int

0 commit comments

Comments
 (0)