@@ -343,31 +343,45 @@ def test_encode_to_tensor_long_output(self):
343343
344344 torch .testing .assert_close (self .decode (encoded_tensor ).data , samples )
345345
346- def test_contiguity (self ):
346+ @pytest .mark .parametrize ("method" , ("to_file" , "to_tensor" , "to_file_like" ))
347+ def test_contiguity (self , method , tmp_path ):
347348 # Ensure that 2 waveforms with the same values are encoded in the same
348349 # way, regardless of their memory layout. Here we encode 2 equal
349350 # waveforms, one is row-aligned while the other is column-aligned.
350- # TODO: Ideally we'd be testing all encoding methods here
351351
352352 num_samples = 10_000 # per channel
353353 contiguous_samples = torch .rand (2 , num_samples ).contiguous ()
354354 assert contiguous_samples .stride () == (num_samples , 1 )
355355
356- params = dict (format = "flac" , bit_rate = 44_000 )
357- encoded_from_contiguous = AudioEncoder (
358- contiguous_samples , sample_rate = 16_000
359- ).to_tensor (** params )
360-
361356 non_contiguous_samples = contiguous_samples .T .contiguous ().T
362357 assert non_contiguous_samples .stride () == (1 , 2 )
363358
364359 torch .testing .assert_close (
365360 contiguous_samples , non_contiguous_samples , rtol = 0 , atol = 0
366361 )
367362
368- encoded_from_non_contiguous = AudioEncoder (
369- non_contiguous_samples , sample_rate = 16_000
370- ).to_tensor (** params )
363+ def encode_to_tensor (samples ):
364+ params = dict (bit_rate = 44_000 )
365+ if method == "to_file" :
366+ dest = str (tmp_path / "output.flac" )
367+ AudioEncoder (samples , sample_rate = 16_000 ).to_file (dest = dest , ** params )
368+ with open (dest , "rb" ) as f :
369+ return torch .frombuffer (f .read (), dtype = torch .uint8 )
370+ elif method == "to_tensor" :
371+ return AudioEncoder (samples , sample_rate = 16_000 ).to_tensor (
372+ format = "flac" , ** params
373+ )
374+ elif method == "to_file_like" :
375+ file_like = io .BytesIO ()
376+ AudioEncoder (samples , sample_rate = 16_000 ).to_file_like (
377+ file_like , format = "flac" , ** params
378+ )
379+ return torch .frombuffer (file_like .getvalue (), dtype = torch .uint8 )
380+ else :
381+ raise ValueError (f"Unknown method: { method } " )
382+
383+ encoded_from_contiguous = encode_to_tensor (contiguous_samples )
384+ encoded_from_non_contiguous = encode_to_tensor (non_contiguous_samples )
371385
372386 torch .testing .assert_close (
373387 encoded_from_contiguous , encoded_from_non_contiguous , rtol = 0 , atol = 0
0 commit comments