 
 import io
 import os
-import re
 from functools import partial
 
 os.environ["TORCH_LOGS"] = "output_code"
@@ -1120,239 +1119,6 @@ def test_bad_input(self, tmp_path): |
                 wf=torch.rand(2, 10), sample_rate=10, filename="./file.bad_extension"
             )
 
-        with pytest.raises(RuntimeError, match="invalid sample rate=10"):
-            encode_audio_to_file(
-                wf=self.decode(NASA_AUDIO_MP3),
-                sample_rate=10,
-                filename=valid_output_file,
-            )
-        with pytest.raises(RuntimeError, match="bit_rate=-1 must be >= 0"):
-            encode_audio_to_file(
-                wf=self.decode(NASA_AUDIO_MP3),
-                sample_rate=NASA_AUDIO_MP3.sample_rate,
-                filename=valid_output_file,
-                bit_rate=-1,  # bad
-            )
-
-        with pytest.raises(RuntimeError, match="Trying to encode 10 channels"):
-            encode_audio_to_file(
-                wf=torch.rand(10, 20), sample_rate=10, filename="doesnt_matter"
-            )
-
-        for num_channels in (0, 3):
-            with pytest.raises(
-                RuntimeError,
-                match=re.escape(
-                    f"Desired number of channels ({num_channels}) is not supported"
-                ),
-            ):
-                encode_audio_to_file(
-                    wf=torch.rand(2, 10),
-                    sample_rate=16_000,
-                    filename="ok.mp3",
-                    num_channels=num_channels,
-                )
-
-    @pytest.mark.parametrize(
-        "encode_method", (encode_audio_to_file, encode_audio_to_tensor)
-    )
-    @pytest.mark.parametrize("output_format", ("wav", "flac"))
-    def test_round_trip(self, encode_method, output_format, tmp_path):
-        # Check that decode(encode(samples)) == samples on lossless formats
-
-        if get_ffmpeg_major_version() == 4 and output_format == "wav":
-            pytest.skip("Swresample with FFmpeg 4 doesn't work on wav files")
-
-        asset = NASA_AUDIO_MP3
-        source_samples = self.decode(asset)
-
-        if encode_method is encode_audio_to_file:
-            encoded_path = tmp_path / f"output.{output_format}"
-            encode_audio_to_file(
-                wf=source_samples,
-                sample_rate=asset.sample_rate,
-                filename=str(encoded_path),
-            )
-            encoded_source = encoded_path
-        else:
-            encoded_source = encode_audio_to_tensor(
-                wf=source_samples, sample_rate=asset.sample_rate, format=output_format
-            )
-            assert encoded_source.dtype == torch.uint8
-            assert encoded_source.ndim == 1
-
-        rtol, atol = (0, 1e-4) if output_format == "wav" else (None, None)
-        torch.testing.assert_close(
-            self.decode(encoded_source), source_samples, rtol=rtol, atol=atol
-        )
-
-    @pytest.mark.skipif(in_fbcode(), reason="TODO: enable ffmpeg CLI")
-    @pytest.mark.parametrize("asset", (NASA_AUDIO_MP3, SINE_MONO_S32))
-    @pytest.mark.parametrize("bit_rate", (None, 0, 44_100, 999_999_999))
-    @pytest.mark.parametrize("num_channels", (None, 1, 2))
-    @pytest.mark.parametrize("output_format", ("mp3", "wav", "flac"))
-    def test_against_cli(self, asset, bit_rate, num_channels, output_format, tmp_path):
-        # Encodes samples with our encoder and with the FFmpeg CLI, and checks
-        # that both decoded outputs are equal
-
-        if get_ffmpeg_major_version() == 4 and output_format == "wav":
-            pytest.skip("Swresample with FFmpeg 4 doesn't work on wav files")
-
-        encoded_by_ffmpeg = tmp_path / f"ffmpeg_output.{output_format}"
-        subprocess.run(
-            ["ffmpeg", "-i", str(asset.path)]
-            + (["-b:a", f"{bit_rate}"] if bit_rate is not None else [])
-            + (["-ac", f"{num_channels}"] if num_channels is not None else [])
-            + [
-                str(encoded_by_ffmpeg),
-            ],
-            capture_output=True,
-            check=True,
-        )
-
-        encoded_by_us = tmp_path / f"our_output.{output_format}"
-        encode_audio_to_file(
-            wf=self.decode(asset),
-            sample_rate=asset.sample_rate,
-            filename=str(encoded_by_us),
-            bit_rate=bit_rate,
-            num_channels=num_channels,
-        )
-
-        if output_format == "wav":
-            rtol, atol = 0, 1e-4
-        elif output_format == "mp3" and asset is SINE_MONO_S32 and num_channels == 2:
-            # Not sure why, this one needs slightly higher tol. With default
-            # tolerances, the check fails on ~1% of the samples, so that's
-            # probably fine. It might be that the FFmpeg CLI doesn't rely on
-            # libswresample for converting channels?
-            rtol, atol = 0, 1e-3
-        else:
-            rtol, atol = None, None
-        torch.testing.assert_close(
-            self.decode(encoded_by_ffmpeg),
-            self.decode(encoded_by_us),
-            rtol=rtol,
-            atol=atol,
-        )
-
-    @pytest.mark.parametrize("asset", (NASA_AUDIO_MP3, SINE_MONO_S32))
-    @pytest.mark.parametrize("bit_rate", (None, 0, 44_100, 999_999_999))
-    @pytest.mark.parametrize("num_channels", (None, 1, 2))
-    @pytest.mark.parametrize("output_format", ("mp3", "wav", "flac"))
-    def test_tensor_against_file(
-        self, asset, bit_rate, num_channels, output_format, tmp_path
-    ):
-        if get_ffmpeg_major_version() == 4 and output_format == "wav":
-            pytest.skip("Swresample with FFmpeg 4 doesn't work on wav files")
-
-        encoded_file = tmp_path / f"our_output.{output_format}"
-        encode_audio_to_file(
-            wf=self.decode(asset),
-            sample_rate=asset.sample_rate,
-            filename=str(encoded_file),
-            bit_rate=bit_rate,
-            num_channels=num_channels,
-        )
-
-        encoded_tensor = encode_audio_to_tensor(
-            wf=self.decode(asset),
-            sample_rate=asset.sample_rate,
-            format=output_format,
-            bit_rate=bit_rate,
-            num_channels=num_channels,
-        )
-
-        torch.testing.assert_close(
-            self.decode(encoded_file), self.decode(encoded_tensor)
-        )
-
-    def test_encode_to_tensor_long_output(self):
-        # Check that we support re-allocating the output tensor when the encoded
-        # data is large.
-        samples = torch.rand(1, int(1e7))
-        encoded_tensor = encode_audio_to_tensor(
-            wf=samples,
-            sample_rate=16_000,
-            format="flac",
-            bit_rate=44_000,
-        )
-        # Note: this should be in sync with its C++ counterpart for the test to
-        # be meaningful.
-        INITIAL_TENSOR_SIZE = 10_000_000
-        assert encoded_tensor.numel() > INITIAL_TENSOR_SIZE
-
-        torch.testing.assert_close(self.decode(encoded_tensor), samples)
-
-    def test_contiguity(self):
-        # Ensure that 2 waveforms with the same values are encoded in the same
-        # way, regardless of their memory layout. Here we encode 2 equal
-        # waveforms, one is row-aligned while the other is column-aligned.
-
-        num_samples = 10_000  # per channel
-        contiguous_samples = torch.rand(2, num_samples).contiguous()
-        assert contiguous_samples.stride() == (num_samples, 1)
-
-        encoded_from_contiguous = encode_audio_to_tensor(
-            wf=contiguous_samples,
-            sample_rate=16_000,
-            format="flac",
-            bit_rate=44_000,
-        )
-        non_contiguous_samples = contiguous_samples.T.contiguous().T
-        assert non_contiguous_samples.stride() == (1, 2)
-
-        torch.testing.assert_close(
-            contiguous_samples, non_contiguous_samples, rtol=0, atol=0
-        )
-
-        encoded_from_non_contiguous = encode_audio_to_tensor(
-            wf=non_contiguous_samples,
-            sample_rate=16_000,
-            format="flac",
-            bit_rate=44_000,
-        )
-
-        torch.testing.assert_close(
-            encoded_from_contiguous, encoded_from_non_contiguous, rtol=0, atol=0
-        )
-
-    @pytest.mark.parametrize("num_channels_input", (1, 2))
-    @pytest.mark.parametrize("num_channels_output", (1, 2, None))
-    @pytest.mark.parametrize(
-        "encode_method", (encode_audio_to_file, encode_audio_to_tensor)
-    )
-    def test_num_channels(
-        self, num_channels_input, num_channels_output, encode_method, tmp_path
-    ):
-        # We just check that the num_channels parameter is respected.
-        # Correctness is checked in other tests (like test_against_cli())
-
-        sample_rate = 16_000
-        source_samples = torch.rand(num_channels_input, 1_000)
-        format = "mp3"
-
-        if encode_method is encode_audio_to_file:
-            encoded_path = tmp_path / f"output.{format}"
-            encode_audio_to_file(
-                wf=source_samples,
-                sample_rate=sample_rate,
-                filename=str(encoded_path),
-                num_channels=num_channels_output,
-            )
-            encoded_source = encoded_path
-        else:
-            encoded_source = encode_audio_to_tensor(
-                wf=source_samples,
-                sample_rate=sample_rate,
-                format=format,
-                num_channels=num_channels_output,
-            )
-
-        if num_channels_output is None:
-            num_channels_output = num_channels_input
-        assert self.decode(encoded_source).shape[0] == num_channels_output
-
 
 if __name__ == "__main__":
     pytest.main()