|
6 | 6 |
|
7 | 7 | import io |
8 | 8 | import os |
| 9 | +import re |
9 | 10 | from functools import partial |
10 | 11 |
|
11 | 12 | os.environ["TORCH_LOGS"] = "output_code" |
@@ -1119,6 +1120,239 @@ def test_bad_input(self, tmp_path): |
1119 | 1120 |                 wf=torch.rand(2, 10), sample_rate=10, filename="./file.bad_extension" |
1120 | 1121 |             ) |
1121 | 1122 |

| 1123 | +        with pytest.raises(RuntimeError, match="invalid sample rate=10"): |
| 1124 | +            encode_audio_to_file( |
| 1125 | +                wf=self.decode(NASA_AUDIO_MP3), |
| 1126 | +                sample_rate=10, |
| 1127 | +                filename=valid_output_file, |
| 1128 | +            ) |
| 1129 | +        with pytest.raises(RuntimeError, match="bit_rate=-1 must be >= 0"): |
| 1130 | +            encode_audio_to_file( |
| 1131 | +                wf=self.decode(NASA_AUDIO_MP3), |
| 1132 | +                sample_rate=NASA_AUDIO_MP3.sample_rate, |
| 1133 | +                filename=valid_output_file, |
| 1134 | +                bit_rate=-1,  # bad |
| 1135 | +            ) |
| 1136 | + |
| 1137 | +        with pytest.raises(RuntimeError, match="Trying to encode 10 channels"): |
| 1138 | +            encode_audio_to_file( |
| 1139 | +                wf=torch.rand(10, 20), sample_rate=10, filename="doesnt_matter" |
| 1140 | +            ) |
| 1141 | + |
| 1142 | +        for num_channels in (0, 3): |
| 1143 | +            with pytest.raises( |
| 1144 | +                RuntimeError, |
| 1145 | +                match=re.escape( |
| 1146 | +                    f"Desired number of channels ({num_channels}) is not supported" |
| 1147 | +                ), |
| 1148 | +            ): |
| 1149 | +                encode_audio_to_file( |
| 1150 | +                    wf=torch.rand(2, 10), |
| 1151 | +                    sample_rate=16_000, |
| 1152 | +                    filename="ok.mp3", |
| 1153 | +                    num_channels=num_channels, |
| 1154 | +                ) |
| 1155 | + |
| 1156 | +    @pytest.mark.parametrize( |
| 1157 | +        "encode_method", (encode_audio_to_file, encode_audio_to_tensor) |
| 1158 | +    ) |
| 1159 | +    @pytest.mark.parametrize("output_format", ("wav", "flac")) |
| 1160 | +    def test_round_trip(self, encode_method, output_format, tmp_path): |
| 1161 | +        # Check that decode(encode(samples)) == samples on lossless formats |
| 1162 | + |
| 1163 | +        if get_ffmpeg_major_version() == 4 and output_format == "wav": |
| 1164 | +            pytest.skip("Swresample with FFmpeg 4 doesn't work on wav files") |
| 1165 | + |
| 1166 | +        asset = NASA_AUDIO_MP3 |
| 1167 | +        source_samples = self.decode(asset) |
| 1168 | + |
| 1169 | +        if encode_method is encode_audio_to_file: |
| 1170 | +            encoded_path = tmp_path / f"output.{output_format}" |
| 1171 | +            encode_audio_to_file( |
| 1172 | +                wf=source_samples, |
| 1173 | +                sample_rate=asset.sample_rate, |
| 1174 | +                filename=str(encoded_path), |
| 1175 | +            ) |
| 1176 | +            encoded_source = encoded_path |
| 1177 | +        else: |
| 1178 | +            encoded_source = encode_audio_to_tensor( |
| 1179 | +                wf=source_samples, sample_rate=asset.sample_rate, format=output_format |
| 1180 | +            ) |
| 1181 | +            assert encoded_source.dtype == torch.uint8 |
| 1182 | +            assert encoded_source.ndim == 1 |
| 1183 | + |
| 1184 | +        rtol, atol = (0, 1e-4) if output_format == "wav" else (None, None) |
| 1185 | +        torch.testing.assert_close( |
| 1186 | +            self.decode(encoded_source), source_samples, rtol=rtol, atol=atol |
| 1187 | +        ) |
| 1188 | + |
| 1189 | +    @pytest.mark.skipif(in_fbcode(), reason="TODO: enable ffmpeg CLI") |
| 1190 | +    @pytest.mark.parametrize("asset", (NASA_AUDIO_MP3, SINE_MONO_S32)) |
| 1191 | +    @pytest.mark.parametrize("bit_rate", (None, 0, 44_100, 999_999_999)) |
| 1192 | +    @pytest.mark.parametrize("num_channels", (None, 1, 2)) |
| 1193 | +    @pytest.mark.parametrize("output_format", ("mp3", "wav", "flac")) |
| 1194 | +    def test_against_cli(self, asset, bit_rate, num_channels, output_format, tmp_path): |
| 1195 | +        # Encodes samples with our encoder and with the FFmpeg CLI, and checks |
| 1196 | +        # that both decoded outputs are equal |
| 1197 | + |
| 1198 | +        if get_ffmpeg_major_version() == 4 and output_format == "wav": |
| 1199 | +            pytest.skip("Swresample with FFmpeg 4 doesn't work on wav files") |
| 1200 | + |
| 1201 | +        encoded_by_ffmpeg = tmp_path / f"ffmpeg_output.{output_format}" |
| 1202 | +        subprocess.run( |
| 1203 | +            ["ffmpeg", "-i", str(asset.path)] |
| 1204 | +            + (["-b:a", f"{bit_rate}"] if bit_rate is not None else []) |
| 1205 | +            + (["-ac", f"{num_channels}"] if num_channels is not None else []) |
| 1206 | +            + [ |
| 1207 | +                str(encoded_by_ffmpeg), |
| 1208 | +            ], |
| 1209 | +            capture_output=True, |
| 1210 | +            check=True, |
| 1211 | +        ) |
| 1212 | + |
| 1213 | +        encoded_by_us = tmp_path / f"our_output.{output_format}" |
| 1214 | +        encode_audio_to_file( |
| 1215 | +            wf=self.decode(asset), |
| 1216 | +            sample_rate=asset.sample_rate, |
| 1217 | +            filename=str(encoded_by_us), |
| 1218 | +            bit_rate=bit_rate, |
| 1219 | +            num_channels=num_channels, |
| 1220 | +        ) |
| 1221 | + |
| 1222 | +        if output_format == "wav": |
| 1223 | +            rtol, atol = 0, 1e-4 |
| 1224 | +        elif output_format == "mp3" and asset is SINE_MONO_S32 and num_channels == 2: |
| 1225 | +            # Not sure why, but this one needs a slightly higher tolerance. |
| 1226 | +            # With the default tolerances, the check fails on only ~1% of the |
| 1227 | +            # samples, so this is probably fine. It might be that the FFmpeg |
| 1228 | +            # CLI doesn't rely on libswresample for converting channels? |
| 1229 | +            rtol, atol = 0, 1e-3 |
| 1230 | +        else: |
| 1231 | +            rtol, atol = None, None |
| 1232 | +        torch.testing.assert_close( |
| 1233 | +            self.decode(encoded_by_ffmpeg), |
| 1234 | +            self.decode(encoded_by_us), |
| 1235 | +            rtol=rtol, |
| 1236 | +            atol=atol, |
| 1237 | +        ) |
| 1238 | + |
| 1239 | +    @pytest.mark.parametrize("asset", (NASA_AUDIO_MP3, SINE_MONO_S32)) |
| 1240 | +    @pytest.mark.parametrize("bit_rate", (None, 0, 44_100, 999_999_999)) |
| 1241 | +    @pytest.mark.parametrize("num_channels", (None, 1, 2)) |
| 1242 | +    @pytest.mark.parametrize("output_format", ("mp3", "wav", "flac")) |
| 1243 | +    def test_tensor_against_file( |
| 1244 | +        self, asset, bit_rate, num_channels, output_format, tmp_path |
| 1245 | +    ): |
| 1246 | +        if get_ffmpeg_major_version() == 4 and output_format == "wav": |
| 1247 | +            pytest.skip("Swresample with FFmpeg 4 doesn't work on wav files") |
| 1248 | + |
| 1249 | +        encoded_file = tmp_path / f"our_output.{output_format}" |
| 1250 | +        encode_audio_to_file( |
| 1251 | +            wf=self.decode(asset), |
| 1252 | +            sample_rate=asset.sample_rate, |
| 1253 | +            filename=str(encoded_file), |
| 1254 | +            bit_rate=bit_rate, |
| 1255 | +            num_channels=num_channels, |
| 1256 | +        ) |
| 1257 | + |
| 1258 | +        encoded_tensor = encode_audio_to_tensor( |
| 1259 | +            wf=self.decode(asset), |
| 1260 | +            sample_rate=asset.sample_rate, |
| 1261 | +            format=output_format, |
| 1262 | +            bit_rate=bit_rate, |
| 1263 | +            num_channels=num_channels, |
| 1264 | +        ) |
| 1265 | + |
| 1266 | +        torch.testing.assert_close( |
| 1267 | +            self.decode(encoded_file), self.decode(encoded_tensor) |
| 1268 | +        ) |
| 1269 | + |
| 1270 | +    def test_encode_to_tensor_long_output(self): |
| 1271 | +        # Check that we support re-allocating the output tensor when the encoded |
| 1272 | +        # data is large. |
| 1273 | +        samples = torch.rand(1, int(1e7)) |
| 1274 | +        encoded_tensor = encode_audio_to_tensor( |
| 1275 | +            wf=samples, |
| 1276 | +            sample_rate=16_000, |
| 1277 | +            format="flac", |
| 1278 | +            bit_rate=44_000, |
| 1279 | +        ) |
| 1280 | +        # Note: this should be in sync with its C++ counterpart for the test to |
| 1281 | +        # be meaningful. |
| 1282 | +        INITIAL_TENSOR_SIZE = 10_000_000 |
| 1283 | +        assert encoded_tensor.numel() > INITIAL_TENSOR_SIZE |
| 1284 | + |
| 1285 | +        torch.testing.assert_close(self.decode(encoded_tensor), samples) |
| 1286 | + |
| 1287 | +    def test_contiguity(self): |
| 1288 | +        # Ensure that 2 waveforms with the same values are encoded in the same |
| 1289 | +        # way, regardless of their memory layout. Here we encode 2 equal |
| 1290 | +        # waveforms: one is row-aligned while the other is column-aligned. |
| 1291 | + |
| 1292 | +        num_samples = 10_000  # per channel |
| 1293 | +        contiguous_samples = torch.rand(2, num_samples).contiguous() |
| 1294 | +        assert contiguous_samples.stride() == (num_samples, 1) |
| 1295 | + |
| 1296 | +        encoded_from_contiguous = encode_audio_to_tensor( |
| 1297 | +            wf=contiguous_samples, |
| 1298 | +            sample_rate=16_000, |
| 1299 | +            format="flac", |
| 1300 | +            bit_rate=44_000, |
| 1301 | +        ) |
| 1302 | +        non_contiguous_samples = contiguous_samples.T.contiguous().T |
| 1303 | +        assert non_contiguous_samples.stride() == (1, 2) |
| 1304 | + |
| 1305 | +        torch.testing.assert_close( |
| 1306 | +            contiguous_samples, non_contiguous_samples, rtol=0, atol=0 |
| 1307 | +        ) |
| 1308 | + |
| 1309 | +        encoded_from_non_contiguous = encode_audio_to_tensor( |
| 1310 | +            wf=non_contiguous_samples, |
| 1311 | +            sample_rate=16_000, |
| 1312 | +            format="flac", |
| 1313 | +            bit_rate=44_000, |
| 1314 | +        ) |
| 1315 | + |
| 1316 | +        torch.testing.assert_close( |
| 1317 | +            encoded_from_contiguous, encoded_from_non_contiguous, rtol=0, atol=0 |
| 1318 | +        ) |
| 1319 | + |
| 1320 | +    @pytest.mark.parametrize("num_channels_input", (1, 2)) |
| 1321 | +    @pytest.mark.parametrize("num_channels_output", (1, 2, None)) |
| 1322 | +    @pytest.mark.parametrize( |
| 1323 | +        "encode_method", (encode_audio_to_file, encode_audio_to_tensor) |
| 1324 | +    ) |
| 1325 | +    def test_num_channels( |
| 1326 | +        self, num_channels_input, num_channels_output, encode_method, tmp_path |
| 1327 | +    ): |
| 1328 | +        # We just check that the num_channels parameter is respected. |
| 1329 | +        # Correctness is checked in other tests (like test_against_cli()) |
| 1330 | + |
| 1331 | +        sample_rate = 16_000 |
| 1332 | +        source_samples = torch.rand(num_channels_input, 1_000) |
| 1333 | +        format = "mp3" |
| 1334 | + |
| 1335 | +        if encode_method is encode_audio_to_file: |
| 1336 | +            encoded_path = tmp_path / f"output.{format}" |
| 1337 | +            encode_audio_to_file( |
| 1338 | +                wf=source_samples, |
| 1339 | +                sample_rate=sample_rate, |
| 1340 | +                filename=str(encoded_path), |
| 1341 | +                num_channels=num_channels_output, |
| 1342 | +            ) |
| 1343 | +            encoded_source = encoded_path |
| 1344 | +        else: |
| 1345 | +            encoded_source = encode_audio_to_tensor( |
| 1346 | +                wf=source_samples, |
| 1347 | +                sample_rate=sample_rate, |
| 1348 | +                format=format, |
| 1349 | +                num_channels=num_channels_output, |
| 1350 | +            ) |
| 1351 | + |
| 1352 | +        if num_channels_output is None: |
| 1353 | +            num_channels_output = num_channels_input |
| 1354 | +        assert self.decode(encoded_source).shape[0] == num_channels_output |
| 1355 | + |
1122 | 1356 |
|
1123 | 1357 | if __name__ == "__main__": |
1124 | 1358 |     pytest.main() |
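
For context, the sketch below illustrates the two encoding entry points exercised by this diff, using only the call signatures visible above. The waveform shape, sample rate, and output filename are illustrative assumptions, and encode_audio_to_file / encode_audio_to_tensor are assumed to be imported the same way the rest of this test module imports them (that import is not part of this diff).

# Illustrative sketch only; shapes, rates, and filenames are assumptions.
import torch

samples = torch.rand(2, 16_000)  # (num_channels, num_samples) float waveform

# Encode to a file; the output format follows the file extension
# (an unknown extension is rejected, see test_bad_input above).
encode_audio_to_file(wf=samples, sample_rate=16_000, filename="out.flac")

# Encode to a 1D uint8 tensor holding the encoded bytes.
encoded = encode_audio_to_tensor(wf=samples, sample_rate=16_000, format="flac")
assert encoded.dtype == torch.uint8 and encoded.ndim == 1

# Both entry points also accept optional bit_rate and num_channels arguments,
# as exercised by test_against_cli and test_num_channels above.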