|
| 1 | +import json |
1 | 2 | import re |
2 | 3 | import subprocess |
| 4 | +from pathlib import Path |
3 | 5 |
|
4 | 6 | import pytest |
5 | 7 | import torch |
|
16 | 18 | ) |
17 | 19 |
|
18 | 20 |
|
def _ffprobe_audio_frames(path: Path) -> list:
    """Return the ffprobe ``frames`` entries of the first audio stream of *path*."""
    completed = subprocess.run(
        [
            "ffprobe",
            "-v",
            "error",
            "-hide_banner",
            "-select_streams",
            "a:0",
            "-show_frames",
            "-of",
            "json",
            str(path),
        ],
        check=True,
        capture_output=True,
        text=True,
    )
    return json.loads(completed.stdout)["frames"]


def validate_frames_properties(*, actual: Path, expected: Path) -> None:
    """Assert that two audio files expose identical per-frame properties.

    Both files are probed with ``ffprobe -show_frames`` (audio stream
    ``a:0``) and the resulting frames are compared pairwise, property by
    property.

    Args:
        actual: Path of the file under test.
        expected: Path of the reference file.

    Raises:
        AssertionError: If the ffprobe output does not have the expected
            structure, the number of frames differs, a frame exposes a
            different set of properties, or any compared property value
            differs.
        subprocess.CalledProcessError: If ffprobe exits with a non-zero
            status.
    """
    # Properties that are deliberately excluded from the comparison.
    ignored_props = {"pkt_pos"}  # TODO this probably matters

    frames_actual = _ffprobe_audio_frames(actual)
    frames_expected = _ffprobe_audio_frames(expected)

    # frames_actual and frames_expected are both a list of dicts, each dict
    # corresponds to a frame and each key-value pair corresponds to a frame
    # property like pts, nb_samples, etc., similar to the AVFrame fields.
    for frames in (frames_actual, frames_expected):
        assert isinstance(frames, list)
        assert all(isinstance(d, dict) for d in frames)

    assert len(frames_actual) == len(frames_expected)
    for frame_index, (d_actual, d_expected) in enumerate(
        zip(frames_actual, frames_expected)
    ):
        # Compare the property sets first: a property reported for only one of
        # the two files must fail the check instead of being silently skipped
        # (missing from actual) or surfacing as a KeyError (missing from
        # expected).
        props_actual = d_actual.keys() - ignored_props
        props_expected = d_expected.keys() - ignored_props
        assert props_actual == props_expected, (
            f"frame {frame_index} has mismatched properties: "
            f"{sorted(props_actual ^ props_expected)}"
        )
        for prop in d_actual:
            if prop in ignored_props:
                continue
            assert (
                d_actual[prop] == d_expected[prop]
            ), f"{prop} value is different for frame {frame_index}:"
| 62 | + |
| 63 | + |
19 | 64 | class TestAudioEncoder: |
20 | 65 |
|
21 | 66 | def decode(self, source) -> torch.Tensor: |
@@ -162,13 +207,19 @@ def test_against_cli(self, asset, bit_rate, num_channels, format, method, tmp_pa |
162 | 207 | rtol, atol = 0, 1e-3 |
163 | 208 | else: |
164 | 209 | rtol, atol = None, None |
| 210 | + # TODO should validate `.pts_seconds` and `duration_seconds` as well |
165 | 211 | torch.testing.assert_close( |
166 | | - self.decode(encoded_by_ffmpeg), |
167 | 212 | self.decode(encoded_by_us), |
| 213 | + self.decode(encoded_by_ffmpeg), |
168 | 214 | rtol=rtol, |
169 | 215 | atol=atol, |
170 | 216 | ) |
171 | 217 |
|
| 218 | + if method == "to_file": |
| 219 | + validate_frames_properties(actual=encoded_by_us, expected=encoded_by_ffmpeg) |
| 220 | + else: |
| 221 | + assert method == "to_tensor", "wrong test parametrization!" |
| 222 | + |
172 | 223 | @pytest.mark.parametrize("asset", (NASA_AUDIO_MP3, SINE_MONO_S32)) |
173 | 224 | @pytest.mark.parametrize("bit_rate", (None, 0, 44_100, 999_999_999)) |
174 | 225 | @pytest.mark.parametrize("num_channels", (None, 1, 2)) |
|
0 commit comments