Skip to content

Commit 0e7c6ca

Browse files
CloseChoice and lhoestq authored
Add num channels to audio (#7840)
* WIP: add audio, tests failing
* WIP: add mono argument, tests failing
* change from mono to num_channels in documentation, audio tests passing
* update docs and move test for audio
* update audio
* update docstring for audio
* Apply suggestions from code review

---------

Co-authored-by: Quentin Lhoest <[email protected]>
1 parent 8b1bd4e commit 0e7c6ca

File tree

2 files changed

+52
-10
lines changed

2 files changed

+52
-10
lines changed

src/datasets/features/audio.py

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,13 @@ class Audio:
4949
Args:
5050
sampling_rate (`int`, *optional*):
5151
Target sampling rate. If `None`, the native sampling rate is used.
52-
mono (`bool`, defaults to `True`):
53-
Whether to convert the audio signal to mono by averaging samples across
54-
channels.
52+
num_channels (`int`, *optional*):
53+
The desired number of channels of the samples. By default, the number of channels of the source is used.
54+
Audio decoding will return samples with shape (num_channels, num_samples)
55+
Currently `None` (number of channels of the source, default), `1` (mono) or `2` (stereo) channels are supported.
56+
The `num_channels` argument is passed to `torchcodec.decoders.AudioDecoder`.
57+
58+
<Added version="4.4.0"/>
5559
decode (`bool`, defaults to `True`):
5660
Whether to decode the audio data. If `False`,
5761
returns the underlying dictionary in the format `{"path": audio_path, "bytes": audio_bytes}`.
@@ -63,7 +67,7 @@ class Audio:
6367
```py
6468
>>> from datasets import load_dataset, Audio
6569
>>> ds = load_dataset("PolyAI/minds14", name="en-US", split="train")
66-
>>> ds = ds.cast_column("audio", Audio(sampling_rate=44100))
70+
>>> ds = ds.cast_column("audio", Audio(sampling_rate=44100, num_channels=2))
6771
>>> ds[0]["audio"]
6872
<datasets.features._torchcodec.AudioDecoder object at 0x11642b6a0>
6973
>>> audio = ds[0]["audio"]
@@ -78,6 +82,7 @@ class Audio:
7882

7983
sampling_rate: Optional[int] = None
8084
decode: bool = True
85+
num_channels: Optional[int] = None
8186
stream_index: Optional[int] = None
8287
id: Optional[str] = field(default=None, repr=False)
8388
# Automatically constructed
@@ -126,7 +131,7 @@ def encode_example(self, value: Union[str, bytes, bytearray, dict, "AudioDecoder
126131
buffer = BytesIO()
127132
AudioEncoder(
128133
torch.from_numpy(value["array"].astype(np.float32)), sample_rate=value["sampling_rate"]
129-
).to_file_like(buffer, format="wav")
134+
).to_file_like(buffer, format="wav", num_channels=self.num_channels)
130135
return {"bytes": buffer.getvalue(), "path": None}
131136
elif value.get("path") is not None and os.path.isfile(value["path"]):
132137
# we set "bytes": None to not duplicate the data if they're already available locally
@@ -143,7 +148,7 @@ def encode_example(self, value: Union[str, bytes, bytearray, dict, "AudioDecoder
143148

144149
buffer = BytesIO()
145150
AudioEncoder(torch.from_numpy(bytes_value), sample_rate=value["sampling_rate"]).to_file_like(
146-
buffer, format="wav"
151+
buffer, format="wav", num_channels=self.num_channels
147152
)
148153
return {"bytes": buffer.getvalue(), "path": None}
149154
else:
@@ -188,7 +193,9 @@ def decode_example(
188193
raise ValueError(f"An audio sample should have one of 'path' or 'bytes' but both are None in {value}.")
189194

190195
if bytes is None and is_local_path(path):
191-
audio = AudioDecoder(path, stream_index=self.stream_index, sample_rate=self.sampling_rate)
196+
audio = AudioDecoder(
197+
path, stream_index=self.stream_index, sample_rate=self.sampling_rate, num_channels=self.num_channels
198+
)
192199

193200
elif bytes is None:
194201
token_per_repo_id = token_per_repo_id or {}
@@ -201,10 +208,14 @@ def decode_example(
201208

202209
download_config = DownloadConfig(token=token)
203210
f = xopen(path, "rb", download_config=download_config)
204-
audio = AudioDecoder(f, stream_index=self.stream_index, sample_rate=self.sampling_rate)
211+
audio = AudioDecoder(
212+
f, stream_index=self.stream_index, sample_rate=self.sampling_rate, num_channels=self.num_channels
213+
)
205214

206215
else:
207-
audio = AudioDecoder(bytes, stream_index=self.stream_index, sample_rate=self.sampling_rate)
216+
audio = AudioDecoder(
217+
bytes, stream_index=self.stream_index, sample_rate=self.sampling_rate, num_channels=self.num_channels
218+
)
208219
audio._hf_encoded = {"path": path, "bytes": bytes}
209220
audio.metadata.path = path
210221
return audio
@@ -312,5 +323,8 @@ def encode_torchcodec_audio(audio: "AudioDecoder") -> dict:
312323

313324
samples = audio.get_all_samples()
314325
buffer = BytesIO()
315-
AudioEncoder(samples.data.cpu(), sample_rate=samples.sample_rate).to_file_like(buffer, format="wav")
326+
num_channels = samples.data.shape[0]
327+
AudioEncoder(samples.data.cpu(), sample_rate=samples.sample_rate).to_file_like(
328+
buffer, format="wav", num_channels=num_channels
329+
)
316330
return {"bytes": buffer.getvalue(), "path": None}

tests/features/test_audio.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -789,3 +789,31 @@ def test_audio_embed_storage(shared_datadir):
789789
embedded_storage = Audio().embed_storage(storage)
790790
embedded_example = embedded_storage.to_pylist()[0]
791791
assert embedded_example == {"bytes": open(audio_path, "rb").read(), "path": "test_audio_44100.wav"}
792+
793+
794+
@require_torchcodec
795+
def test_audio_decode_example_opus_convert_to_stereo(shared_datadir):
796+
# GH 7837
797+
from torchcodec.decoders import AudioDecoder
798+
799+
audio_path = str(shared_datadir / "test_audio_48000.opus") # mono file
800+
audio = Audio(num_channels=2)
801+
decoded_example = audio.decode_example(audio.encode_example(audio_path))
802+
assert isinstance(decoded_example, AudioDecoder)
803+
samples = decoded_example.get_all_samples()
804+
assert samples.sample_rate == 48000
805+
assert samples.data.shape == (2, 48000)
806+
807+
808+
@require_torchcodec
809+
def test_audio_decode_example_opus_convert_to_mono(shared_datadir):
810+
# GH 7837
811+
from torchcodec.decoders import AudioDecoder
812+
813+
audio_path = str(shared_datadir / "test_audio_44100.wav") # stereo file
814+
audio = Audio(num_channels=1)
815+
decoded_example = audio.decode_example(audio.encode_example(audio_path))
816+
assert isinstance(decoded_example, AudioDecoder)
817+
samples = decoded_example.get_all_samples()
818+
assert samples.sample_rate == 44100
819+
assert samples.data.shape == (1, 202311)

0 commit comments

Comments
 (0)