@@ -53,7 +53,7 @@ def validate_frames_properties(*, actual: Path, expected: Path):
5353 for frame_index , (d_actual , d_expected ) in enumerate (
5454 zip (frames_actual , frames_expected )
5555 ):
56- for prop in d_actual :
56+ for prop in d_expected :
5757 if prop == "pkt_pos" :
5858 continue # TODO this probably matters
5959 assert (
@@ -66,7 +66,7 @@ class TestAudioEncoder:
6666 def decode (self , source ) -> torch .Tensor :
6767 if isinstance (source , TestContainerFile ):
6868 source = str (source .path )
69- return AudioDecoder (source ).get_all_samples (). data
69+ return AudioDecoder (source ).get_all_samples ()
7070
7171 def test_bad_input (self ):
7272 with pytest .raises (ValueError , match = "Expected samples to be a Tensor" ):
@@ -108,12 +108,12 @@ def test_bad_input_parametrized(self, method, tmp_path):
108108 else dict (format = "mp3" )
109109 )
110110
111- decoder = AudioEncoder (self .decode (NASA_AUDIO_MP3 ), sample_rate = 10 )
111+ decoder = AudioEncoder (self .decode (NASA_AUDIO_MP3 ). data , sample_rate = 10 )
112112 with pytest .raises (RuntimeError , match = "invalid sample rate=10" ):
113113 getattr (decoder , method )(** valid_params )
114114
115115 decoder = AudioEncoder (
116- self .decode (NASA_AUDIO_MP3 ), sample_rate = NASA_AUDIO_MP3 .sample_rate
116+ self .decode (NASA_AUDIO_MP3 ). data , sample_rate = NASA_AUDIO_MP3 .sample_rate
117117 )
118118 with pytest .raises (RuntimeError , match = "bit_rate=-1 must be >= 0" ):
119119 getattr (decoder , method )(** valid_params , bit_rate = - 1 )
@@ -126,7 +126,7 @@ def test_bad_input_parametrized(self, method, tmp_path):
126126 getattr (decoder , method )(** valid_params )
127127
128128 decoder = AudioEncoder (
129- self .decode (NASA_AUDIO_MP3 ), sample_rate = NASA_AUDIO_MP3 .sample_rate
129+ self .decode (NASA_AUDIO_MP3 ). data , sample_rate = NASA_AUDIO_MP3 .sample_rate
130130 )
131131 for num_channels in (0 , 3 ):
132132 with pytest .raises (
@@ -146,7 +146,7 @@ def test_round_trip(self, method, format, tmp_path):
146146 pytest .skip ("Swresample with FFmpeg 4 doesn't work on wav files" )
147147
148148 asset = NASA_AUDIO_MP3
149- source_samples = self .decode (asset )
149+ source_samples = self .decode (asset ). data
150150
151151 encoder = AudioEncoder (source_samples , sample_rate = asset .sample_rate )
152152
@@ -161,7 +161,7 @@ def test_round_trip(self, method, format, tmp_path):
161161
162162 rtol , atol = (0 , 1e-4 ) if format == "wav" else (None , None )
163163 torch .testing .assert_close (
164- self .decode (encoded_source ), source_samples , rtol = rtol , atol = atol
164+ self .decode (encoded_source ). data , source_samples , rtol = rtol , atol = atol
165165 )
166166
167167 @pytest .mark .skipif (in_fbcode (), reason = "TODO: enable ffmpeg CLI" )
@@ -189,7 +189,7 @@ def test_against_cli(self, asset, bit_rate, num_channels, format, method, tmp_pa
189189 check = True ,
190190 )
191191
192- encoder = AudioEncoder (self .decode (asset ), sample_rate = asset .sample_rate )
192+ encoder = AudioEncoder (self .decode (asset ). data , sample_rate = asset .sample_rate )
193193 params = dict (bit_rate = bit_rate , num_channels = num_channels )
194194 if method == "to_file" :
195195 encoded_by_us = tmp_path / f"output.{ format } "
@@ -207,13 +207,17 @@ def test_against_cli(self, asset, bit_rate, num_channels, format, method, tmp_pa
207207 rtol , atol = 0 , 1e-3
208208 else :
209209 rtol , atol = None , None
210- # TODO should validate `.pts_seconds` and `duration_seconds` as well
210+ samples_by_us = self .decode (encoded_by_us )
211+ samples_by_ffmpeg = self .decode (encoded_by_ffmpeg )
211212 torch .testing .assert_close (
212- self . decode ( encoded_by_us ) ,
213- self . decode ( encoded_by_ffmpeg ) ,
213+ samples_by_us . data ,
214+ samples_by_ffmpeg . data ,
214215 rtol = rtol ,
215216 atol = atol ,
216217 )
218+ assert samples_by_us .pts_seconds == samples_by_ffmpeg .pts_seconds
219+ assert samples_by_us .duration_seconds == samples_by_ffmpeg .duration_seconds
220+ assert samples_by_us .sample_rate == samples_by_ffmpeg .sample_rate
217221
218222 if method == "to_file" :
219223 validate_frames_properties (actual = encoded_by_us , expected = encoded_by_ffmpeg )
@@ -230,7 +234,7 @@ def test_to_tensor_against_to_file(
230234 if get_ffmpeg_major_version () == 4 and format == "wav" :
231235 pytest .skip ("Swresample with FFmpeg 4 doesn't work on wav files" )
232236
233- encoder = AudioEncoder (self .decode (asset ), sample_rate = asset .sample_rate )
237+ encoder = AudioEncoder (self .decode (asset ). data , sample_rate = asset .sample_rate )
234238
235239 params = dict (bit_rate = bit_rate , num_channels = num_channels )
236240 encoded_file = tmp_path / f"output.{ format } "
@@ -240,7 +244,7 @@ def test_to_tensor_against_to_file(
240244 )
241245
242246 torch .testing .assert_close (
243- self .decode (encoded_file ), self .decode (encoded_tensor )
247+ self .decode (encoded_file ). data , self .decode (encoded_tensor ). data
244248 )
245249
246250 def test_encode_to_tensor_long_output (self ):
@@ -256,7 +260,7 @@ def test_encode_to_tensor_long_output(self):
256260 INITIAL_TENSOR_SIZE = 10_000_000
257261 assert encoded_tensor .numel () > INITIAL_TENSOR_SIZE
258262
259- torch .testing .assert_close (self .decode (encoded_tensor ), samples )
263+ torch .testing .assert_close (self .decode (encoded_tensor ). data , samples )
260264
261265 def test_contiguity (self ):
262266 # Ensure that 2 waveforms with the same values are encoded in the same
@@ -313,4 +317,4 @@ def test_num_channels(
313317
314318 if num_channels_output is None :
315319 num_channels_output = num_channels_input
316- assert self .decode (encoded_source ).shape [0 ] == num_channels_output
320+ assert self .decode (encoded_source ).data . shape [0 ] == num_channels_output
0 commit comments