Skip to content

Commit 9d48ff6

Browse files
committed
Also check fields of AudioSamples
1 parent 30a4754 commit 9d48ff6

File tree

1 file changed

+19
-15
lines changed

1 file changed

+19
-15
lines changed

test/test_encoders.py

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def validate_frames_properties(*, actual: Path, expected: Path):
5353
for frame_index, (d_actual, d_expected) in enumerate(
5454
zip(frames_actual, frames_expected)
5555
):
56-
for prop in d_actual:
56+
for prop in d_expected:
5757
if prop == "pkt_pos":
5858
continue # TODO this probably matters
5959
assert (
@@ -66,7 +66,7 @@ class TestAudioEncoder:
6666
def decode(self, source) -> torch.Tensor:
6767
if isinstance(source, TestContainerFile):
6868
source = str(source.path)
69-
return AudioDecoder(source).get_all_samples().data
69+
return AudioDecoder(source).get_all_samples()
7070

7171
def test_bad_input(self):
7272
with pytest.raises(ValueError, match="Expected samples to be a Tensor"):
@@ -108,12 +108,12 @@ def test_bad_input_parametrized(self, method, tmp_path):
108108
else dict(format="mp3")
109109
)
110110

111-
decoder = AudioEncoder(self.decode(NASA_AUDIO_MP3), sample_rate=10)
111+
decoder = AudioEncoder(self.decode(NASA_AUDIO_MP3).data, sample_rate=10)
112112
with pytest.raises(RuntimeError, match="invalid sample rate=10"):
113113
getattr(decoder, method)(**valid_params)
114114

115115
decoder = AudioEncoder(
116-
self.decode(NASA_AUDIO_MP3), sample_rate=NASA_AUDIO_MP3.sample_rate
116+
self.decode(NASA_AUDIO_MP3).data, sample_rate=NASA_AUDIO_MP3.sample_rate
117117
)
118118
with pytest.raises(RuntimeError, match="bit_rate=-1 must be >= 0"):
119119
getattr(decoder, method)(**valid_params, bit_rate=-1)
@@ -126,7 +126,7 @@ def test_bad_input_parametrized(self, method, tmp_path):
126126
getattr(decoder, method)(**valid_params)
127127

128128
decoder = AudioEncoder(
129-
self.decode(NASA_AUDIO_MP3), sample_rate=NASA_AUDIO_MP3.sample_rate
129+
self.decode(NASA_AUDIO_MP3).data, sample_rate=NASA_AUDIO_MP3.sample_rate
130130
)
131131
for num_channels in (0, 3):
132132
with pytest.raises(
@@ -146,7 +146,7 @@ def test_round_trip(self, method, format, tmp_path):
146146
pytest.skip("Swresample with FFmpeg 4 doesn't work on wav files")
147147

148148
asset = NASA_AUDIO_MP3
149-
source_samples = self.decode(asset)
149+
source_samples = self.decode(asset).data
150150

151151
encoder = AudioEncoder(source_samples, sample_rate=asset.sample_rate)
152152

@@ -161,7 +161,7 @@ def test_round_trip(self, method, format, tmp_path):
161161

162162
rtol, atol = (0, 1e-4) if format == "wav" else (None, None)
163163
torch.testing.assert_close(
164-
self.decode(encoded_source), source_samples, rtol=rtol, atol=atol
164+
self.decode(encoded_source).data, source_samples, rtol=rtol, atol=atol
165165
)
166166

167167
@pytest.mark.skipif(in_fbcode(), reason="TODO: enable ffmpeg CLI")
@@ -189,7 +189,7 @@ def test_against_cli(self, asset, bit_rate, num_channels, format, method, tmp_pa
189189
check=True,
190190
)
191191

192-
encoder = AudioEncoder(self.decode(asset), sample_rate=asset.sample_rate)
192+
encoder = AudioEncoder(self.decode(asset).data, sample_rate=asset.sample_rate)
193193
params = dict(bit_rate=bit_rate, num_channels=num_channels)
194194
if method == "to_file":
195195
encoded_by_us = tmp_path / f"output.{format}"
@@ -207,13 +207,17 @@ def test_against_cli(self, asset, bit_rate, num_channels, format, method, tmp_pa
207207
rtol, atol = 0, 1e-3
208208
else:
209209
rtol, atol = None, None
210-
# TODO should validate `.pts_seconds` and `duration_seconds` as well
210+
samples_by_us = self.decode(encoded_by_us)
211+
samples_by_ffmpeg = self.decode(encoded_by_ffmpeg)
211212
torch.testing.assert_close(
212-
self.decode(encoded_by_us),
213-
self.decode(encoded_by_ffmpeg),
213+
samples_by_us.data,
214+
samples_by_ffmpeg.data,
214215
rtol=rtol,
215216
atol=atol,
216217
)
218+
assert samples_by_us.pts_seconds == samples_by_ffmpeg.pts_seconds
219+
assert samples_by_us.duration_seconds == samples_by_ffmpeg.duration_seconds
220+
assert samples_by_us.sample_rate == samples_by_ffmpeg.sample_rate
217221

218222
if method == "to_file":
219223
validate_frames_properties(actual=encoded_by_us, expected=encoded_by_ffmpeg)
@@ -230,7 +234,7 @@ def test_to_tensor_against_to_file(
230234
if get_ffmpeg_major_version() == 4 and format == "wav":
231235
pytest.skip("Swresample with FFmpeg 4 doesn't work on wav files")
232236

233-
encoder = AudioEncoder(self.decode(asset), sample_rate=asset.sample_rate)
237+
encoder = AudioEncoder(self.decode(asset).data, sample_rate=asset.sample_rate)
234238

235239
params = dict(bit_rate=bit_rate, num_channels=num_channels)
236240
encoded_file = tmp_path / f"output.{format}"
@@ -240,7 +244,7 @@ def test_to_tensor_against_to_file(
240244
)
241245

242246
torch.testing.assert_close(
243-
self.decode(encoded_file), self.decode(encoded_tensor)
247+
self.decode(encoded_file).data, self.decode(encoded_tensor).data
244248
)
245249

246250
def test_encode_to_tensor_long_output(self):
@@ -256,7 +260,7 @@ def test_encode_to_tensor_long_output(self):
256260
INITIAL_TENSOR_SIZE = 10_000_000
257261
assert encoded_tensor.numel() > INITIAL_TENSOR_SIZE
258262

259-
torch.testing.assert_close(self.decode(encoded_tensor), samples)
263+
torch.testing.assert_close(self.decode(encoded_tensor).data, samples)
260264

261265
def test_contiguity(self):
262266
# Ensure that 2 waveforms with the same values are encoded in the same
@@ -313,4 +317,4 @@ def test_num_channels(
313317

314318
if num_channels_output is None:
315319
num_channels_output = num_channels_input
316-
assert self.decode(encoded_source).shape[0] == num_channels_output
320+
assert self.decode(encoded_source).data.shape[0] == num_channels_output

0 commit comments

Comments
 (0)