@@ -1162,12 +1162,9 @@ def test_round_trip(self, encode_method, output_format, tmp_path):
11621162
11631163 @pytest .mark .skipif (in_fbcode (), reason = "TODO: enable ffmpeg CLI" )
11641164 @pytest .mark .parametrize ("asset" , (NASA_AUDIO_MP3 , SINE_MONO_S32 ))
1165- @pytest .mark .parametrize (
1166- "encode_method" , (encode_audio_to_file , encode_audio_to_tensor )
1167- )
11681165 @pytest .mark .parametrize ("bit_rate" , (None , 0 , 44_100 , 999_999_999 ))
11691166 @pytest .mark .parametrize ("output_format" , ("mp3" , "wav" , "flac" ))
1170- def test_against_cli (self , asset , encode_method , bit_rate , output_format , tmp_path ):
1167+ def test_against_cli (self , asset , bit_rate , output_format , tmp_path ):
11711168 # Encodes samples with our encoder and with the FFmpeg CLI, and checks
11721169 # that both decoded outputs are equal
11731170
@@ -1186,24 +1183,14 @@ def test_against_cli(self, asset, encode_method, bit_rate, output_format, tmp_pa
11861183 check = True ,
11871184 )
11881185
1189- if encode_method is encode_audio_to_file :
1190- encoded_by_us = tmp_path / f"our_output.{ output_format } "
1191- encode_audio_to_file (
1192- wf = self .decode (asset ),
1193- sample_rate = asset .sample_rate ,
1194- filename = str (encoded_by_us ),
1195- bit_rate = bit_rate ,
1196- )
1197- else :
1198- encoded_by_us = encode_audio_to_tensor (
1199- wf = self .decode (asset ),
1200- sample_rate = asset .sample_rate ,
1201- format = output_format ,
1202- bit_rate = bit_rate ,
1203- )
1186+ encoded_by_us = tmp_path / f"our_output.{ output_format } "
1187+ encode_audio_to_file (
1188+ wf = self .decode (asset ),
1189+ sample_rate = asset .sample_rate ,
1190+ filename = str (encoded_by_us ),
1191+ bit_rate = bit_rate ,
1192+ )
12041193
1205- if output_format == "mp3" and encode_method is encode_audio_to_tensor :
1206- pytest .skip ("TODO-ENCODING investigate, decoded lengths are slightly different" )
12071194 rtol , atol = (0 , 1e-4 ) if output_format == "wav" else (None , None )
12081195 torch .testing .assert_close (
12091196 self .decode (encoded_by_ffmpeg ),
@@ -1212,6 +1199,32 @@ def test_against_cli(self, asset, encode_method, bit_rate, output_format, tmp_pa
12121199 atol = atol ,
12131200 )
12141201
1202+ @pytest .mark .parametrize ("asset" , (NASA_AUDIO_MP3 , SINE_MONO_S32 ))
1203+ @pytest .mark .parametrize ("bit_rate" , (None , 0 , 44_100 , 999_999_999 ))
1204+ @pytest .mark .parametrize ("output_format" , ("mp3" , "wav" , "flac" ))
1205+ def test_tensor_against_file (self , asset , bit_rate , output_format , tmp_path ):
1206+ if get_ffmpeg_major_version () == 4 and output_format == "wav" :
1207+ pytest .skip ("Swresample with FFmpeg 4 doesn't work on wav files" )
1208+
1209+ encoded_file = tmp_path / f"our_output.{ output_format } "
1210+ encode_audio_to_file (
1211+ wf = self .decode (asset ),
1212+ sample_rate = asset .sample_rate ,
1213+ filename = str (encoded_file ),
1214+ bit_rate = bit_rate ,
1215+ )
1216+
1217+ encoded_tensor = encode_audio_to_tensor (
1218+ wf = self .decode (asset ),
1219+ sample_rate = asset .sample_rate ,
1220+ format = output_format ,
1221+ bit_rate = bit_rate ,
1222+ )
1223+
1224+ torch .testing .assert_close (
1225+ self .decode (encoded_file ), self .decode (encoded_tensor )
1226+ )
1227+
12151228
12161229if __name__ == "__main__" :
12171230 pytest .main ()
0 commit comments