@@ -49,9 +49,13 @@ class Audio:
4949 Args:
5050 sampling_rate (`int`, *optional*):
5151 Target sampling rate. If `None`, the native sampling rate is used.
52- mono (`bool`, defaults to `True`):
53- Whether to convert the audio signal to mono by averaging samples across
54- channels.
52+ num_channels (`int`, *optional*):
53+ The desired number of channels of the samples. By default, the number of channels of the source is used.
54+ Audio decoding will return samples with shape `(num_channels, num_samples)`.
55+ Currently supported values are `None` (keep the number of channels of the source, the default), `1` (mono) and `2` (stereo).
56+ The `num_channels` argument is passed to `torchcodec.decoders.AudioDecoder`.
57+
58+ <Added version="4.4.0"/>
5559 decode (`bool`, defaults to `True`):
5660 Whether to decode the audio data. If `False`,
5761 returns the underlying dictionary in the format `{"path": audio_path, "bytes": audio_bytes}`.
@@ -63,7 +67,7 @@ class Audio:
6367 ```py
6468 >>> from datasets import load_dataset, Audio
6569 >>> ds = load_dataset("PolyAI/minds14", name="en-US", split="train")
66- >>> ds = ds.cast_column("audio", Audio(sampling_rate=44100))
70+ >>> ds = ds.cast_column("audio", Audio(sampling_rate=44100, num_channels=2 ))
6771 >>> ds[0]["audio"]
6872 <datasets.features._torchcodec.AudioDecoder object at 0x11642b6a0>
6973 >>> audio = ds[0]["audio"]
@@ -78,6 +82,7 @@ class Audio:
7882
7983 sampling_rate : Optional [int ] = None
8084 decode : bool = True
85+ num_channels : Optional [int ] = None
8186 stream_index : Optional [int ] = None
8287 id : Optional [str ] = field (default = None , repr = False )
8388 # Automatically constructed
@@ -126,7 +131,7 @@ def encode_example(self, value: Union[str, bytes, bytearray, dict, "AudioDecoder
126131 buffer = BytesIO ()
127132 AudioEncoder (
128133 torch .from_numpy (value ["array" ].astype (np .float32 )), sample_rate = value ["sampling_rate" ]
129- ).to_file_like (buffer , format = "wav" )
134+ ).to_file_like (buffer , format = "wav" , num_channels = self . num_channels )
130135 return {"bytes" : buffer .getvalue (), "path" : None }
131136 elif value .get ("path" ) is not None and os .path .isfile (value ["path" ]):
132137 # we set "bytes": None to not duplicate the data if they're already available locally
@@ -143,7 +148,7 @@ def encode_example(self, value: Union[str, bytes, bytearray, dict, "AudioDecoder
143148
144149 buffer = BytesIO ()
145150 AudioEncoder (torch .from_numpy (bytes_value ), sample_rate = value ["sampling_rate" ]).to_file_like (
146- buffer , format = "wav"
151+ buffer , format = "wav" , num_channels = self . num_channels
147152 )
148153 return {"bytes" : buffer .getvalue (), "path" : None }
149154 else :
@@ -188,7 +193,9 @@ def decode_example(
188193 raise ValueError (f"An audio sample should have one of 'path' or 'bytes' but both are None in { value } ." )
189194
190195 if bytes is None and is_local_path (path ):
191- audio = AudioDecoder (path , stream_index = self .stream_index , sample_rate = self .sampling_rate )
196+ audio = AudioDecoder (
197+ path , stream_index = self .stream_index , sample_rate = self .sampling_rate , num_channels = self .num_channels
198+ )
192199
193200 elif bytes is None :
194201 token_per_repo_id = token_per_repo_id or {}
@@ -201,10 +208,14 @@ def decode_example(
201208
202209 download_config = DownloadConfig (token = token )
203210 f = xopen (path , "rb" , download_config = download_config )
204- audio = AudioDecoder (f , stream_index = self .stream_index , sample_rate = self .sampling_rate )
211+ audio = AudioDecoder (
212+ f , stream_index = self .stream_index , sample_rate = self .sampling_rate , num_channels = self .num_channels
213+ )
205214
206215 else :
207- audio = AudioDecoder (bytes , stream_index = self .stream_index , sample_rate = self .sampling_rate )
216+ audio = AudioDecoder (
217+ bytes , stream_index = self .stream_index , sample_rate = self .sampling_rate , num_channels = self .num_channels
218+ )
208219 audio ._hf_encoded = {"path" : path , "bytes" : bytes }
209220 audio .metadata .path = path
210221 return audio
@@ -312,5 +323,8 @@ def encode_torchcodec_audio(audio: "AudioDecoder") -> dict:
312323
313324 samples = audio .get_all_samples ()
314325 buffer = BytesIO ()
315- AudioEncoder (samples .data .cpu (), sample_rate = samples .sample_rate ).to_file_like (buffer , format = "wav" )
326+ num_channels = samples .data .shape [0 ]
327+ AudioEncoder (samples .data .cpu (), sample_rate = samples .sample_rate ).to_file_like (
328+ buffer , format = "wav" , num_channels = num_channels
329+ )
316330 return {"bytes" : buffer .getvalue (), "path" : None }
0 commit comments