Skip to content

Commit b2d9fcc

Browse files
authored
fix(audio): migrate from deprecated torchaudio (#1406)
1 parent 0e27a8e commit b2d9fcc

File tree

3 files changed

+93
-93
lines changed

3 files changed

+93
-93
lines changed

pyproject.toml

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,6 @@ torch = [
7373
"transformers>=4.36.0"
7474
]
7575
audio = [
76-
"torchaudio",
7776
"soundfile"
7877
]
7978
remote = [
@@ -88,7 +87,11 @@ hf = [
8887
"datasets[vision]>=4.0.0",
8988
# https://github.com/pytorch/torchcodec/issues/640
9089
"datasets[audio]>=4.0.0 ; (sys_platform == 'linux' or sys_platform == 'darwin')",
91-
"fsspec>=2024.12.0"
90+
"fsspec>=2024.12.0",
91+
# Until datasets solve the issue, run test_hf_audio test to see if this can be removed
92+
# https://github.com/meta-pytorch/torchcodec/issues/912
93+
# https://github.com/huggingface/transformers/pull/41610
94+
"torch<2.9.0"
9295
]
9396
video = [
9497
"ffmpeg-python",
@@ -134,7 +137,9 @@ examples = [
134137
"huggingface_hub[hf_transfer]",
135138
"ultralytics",
136139
"open_clip_torch",
137-
"openai"
140+
"openai",
141+
# Transformers still require it
142+
"torchaudio<2.9.0"
138143
]
139144

140145
[project.urls]

src/datachain/lib/audio.py

Lines changed: 54 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import posixpath
2+
import re
23
from typing import TYPE_CHECKING
34

45
from datachain.lib.file import FileError
@@ -9,7 +10,7 @@
910
from datachain.lib.file import Audio, AudioFile, File
1011

1112
try:
12-
import torchaudio
13+
import soundfile as sf
1314
except ImportError as exc:
1415
raise ImportError(
1516
"Missing dependencies for processing audio.\n"
@@ -26,18 +27,25 @@ def audio_info(file: "File | AudioFile") -> "Audio":
2627

2728
try:
2829
with file.open() as f:
29-
info = torchaudio.info(f)
30+
info = sf.info(f)
31+
32+
sample_rate = int(info.samplerate)
33+
channels = int(info.channels)
34+
frames = int(info.frames)
35+
duration = float(info.duration)
3036

31-
sample_rate = int(info.sample_rate)
32-
channels = int(info.num_channels)
33-
frames = int(info.num_frames)
34-
duration = float(frames / sample_rate) if sample_rate > 0 else 0.0
37+
# soundfile provides format and subtype
38+
if info.format:
39+
format_name = info.format.lower()
40+
else:
41+
format_name = file.get_file_ext().lower()
3542

36-
codec_name = getattr(info, "encoding", "")
37-
file_ext = file.get_file_ext().lower()
38-
format_name = _encoding_to_format(codec_name, file_ext)
43+
if not format_name:
44+
format_name = "unknown"
45+
codec_name = info.subtype if info.subtype else ""
3946

40-
bits_per_sample = getattr(info, "bits_per_sample", 0)
47+
# Calculate bit rate from subtype
48+
bits_per_sample = _get_bits_per_sample(info.subtype)
4149
bit_rate = (
4250
bits_per_sample * sample_rate * channels if bits_per_sample > 0 else -1
4351
)
@@ -58,44 +66,39 @@ def audio_info(file: "File | AudioFile") -> "Audio":
5866
)
5967

6068

61-
def _encoding_to_format(encoding: str, file_ext: str) -> str:
69+
def _get_bits_per_sample(subtype: str) -> int:
6270
"""
63-
Map torchaudio encoding to a format name.
71+
Map soundfile subtype to bits per sample.
6472
6573
Args:
66-
encoding: The encoding string from torchaudio.info()
67-
file_ext: The file extension as a fallback
74+
subtype: The subtype string from soundfile
6875
6976
Returns:
70-
Format name as a string
77+
Bits per sample, or 0 if unknown
7178
"""
72-
# Direct mapping for formats that match exactly
73-
encoding_map = {
74-
"FLAC": "flac",
75-
"MP3": "mp3",
76-
"VORBIS": "ogg",
77-
"AMR_WB": "amr",
78-
"AMR_NB": "amr",
79-
"OPUS": "opus",
80-
"GSM": "gsm",
79+
if not subtype:
80+
return 0
81+
82+
# Common PCM and floating-point subtypes
83+
pcm_bits = {
84+
"PCM_16": 16,
85+
"PCM_24": 24,
86+
"PCM_32": 32,
87+
"PCM_S8": 8,
88+
"PCM_U8": 8,
89+
"FLOAT": 32,
90+
"DOUBLE": 64,
8191
}
8292

83-
if encoding in encoding_map:
84-
return encoding_map[encoding]
93+
if subtype in pcm_bits:
94+
return pcm_bits[subtype]
8595

86-
# For PCM variants, use file extension to determine format
87-
if encoding.startswith("PCM_"):
88-
# Common PCM formats by extension
89-
pcm_formats = {
90-
"wav": "wav",
91-
"aiff": "aiff",
92-
"au": "au",
93-
"raw": "raw",
94-
}
95-
return pcm_formats.get(file_ext, "wav") # Default to wav for PCM
96+
# Handle variants such as PCM_S16LE, PCM_F32LE, etc.
97+
match = re.search(r"PCM_(?:[A-Z]*?)(\d+)", subtype)
98+
if match:
99+
return int(match.group(1))
96100

97-
# Fallback to file extension if encoding is unknown
98-
return file_ext if file_ext else "unknown"
101+
return 0
99102

100103

101104
def audio_to_np(
@@ -114,27 +117,27 @@ def audio_to_np(
114117

115118
try:
116119
with audio.open() as f:
117-
info = torchaudio.info(f)
118-
sample_rate = info.sample_rate
120+
info = sf.info(f)
121+
sample_rate = info.samplerate
119122

120123
frame_offset = int(start * sample_rate)
121124
num_frames = int(duration * sample_rate) if duration is not None else -1
122125

123126
# Reset file pointer to the beginning
124-
# This is important to ensure we read from the correct position later
125127
f.seek(0)
126128

127-
waveform, sr = torchaudio.load(
128-
f, frame_offset=frame_offset, num_frames=num_frames
129+
# Read audio data with offset and frame count
130+
audio_np, sr = sf.read(
131+
f,
132+
start=frame_offset,
133+
frames=num_frames,
134+
always_2d=False,
135+
dtype="float32",
129136
)
130137

131-
audio_np = waveform.numpy()
132-
133-
if audio_np.shape[0] > 1:
134-
audio_np = audio_np.T
135-
else:
136-
audio_np = audio_np.squeeze()
137-
138+
# soundfile returns shape (frames,) for mono or
139+
# (frames, channels) for multi-channel
140+
# We keep this format as it matches expected output
138141
return audio_np, int(sr)
139142
except Exception as exc:
140143
raise FileError(
@@ -152,11 +155,9 @@ def audio_to_bytes(
152155
153156
If duration is None, converts from start to end of file.
154157
If start is 0 and duration is None, converts entire file."""
155-
y, sr = audio_to_np(audio, start, duration)
156-
157158
import io
158159

159-
import soundfile as sf
160+
y, sr = audio_to_np(audio, start, duration)
160161

161162
buffer = io.BytesIO()
162163
sf.write(buffer, y, sr, format=format)

tests/unit/lib/test_audio.py

Lines changed: 31 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -274,9 +274,7 @@ def test_save_audio_auto_format(tmp_path, catalog):
274274

275275
def test_audio_info_file_error(audio_file):
276276
"""Test audio_info handles file errors properly."""
277-
with patch(
278-
"datachain.lib.audio.torchaudio.info", side_effect=Exception("Test error")
279-
):
277+
with patch("datachain.lib.audio.sf.info", side_effect=Exception("Test error")):
280278
with pytest.raises(
281279
FileError, match="unable to extract metadata from audio file"
282280
):
@@ -285,9 +283,7 @@ def test_audio_info_file_error(audio_file):
285283

286284
def test_audio_fragment_np_file_error(audio_file):
287285
"""Test audio_fragment_np handles file errors properly."""
288-
with patch(
289-
"datachain.lib.audio.torchaudio.info", side_effect=Exception("Test error")
290-
):
286+
with patch("datachain.lib.audio.sf.info", side_effect=Exception("Test error")):
291287
with pytest.raises(FileError, match="unable to read audio fragment"):
292288
audio_to_np(audio_file)
293289

@@ -322,34 +318,30 @@ def test_audio_to_bytes_formats(audio_file, format_type):
322318

323319

324320
@pytest.mark.parametrize(
325-
"encoding,file_ext,expected_format",
321+
"format_str,subtype,file_ext,expected_format,expected_bit_rate",
326322
[
327-
# Test direct encoding mappings
328-
("FLAC", "flac", "flac"),
329-
("MP3", "mp3", "mp3"),
330-
("VORBIS", "ogg", "ogg"),
331-
("OPUS", "opus", "opus"),
332-
("AMR_WB", "amr", "amr"),
333-
("AMR_NB", "amr", "amr"),
334-
("GSM", "gsm", "gsm"),
335-
# Test PCM variants with different extensions
336-
("PCM_S16LE", "wav", "wav"),
337-
("PCM_S24LE", "aiff", "aiff"),
338-
("PCM_F32LE", "au", "au"),
339-
("PCM_U8", "raw", "raw"),
340-
("PCM_S16BE", "unknown_ext", "wav"), # Default for PCM
341-
# Test unknown encoding falls back to file extension
342-
("UNKNOWN_CODEC", "mp3", "mp3"),
343-
("UNKNOWN_CODEC", "flac", "flac"),
344-
# Test files without extension
345-
("UNKNOWN_CODEC", "", "unknown"),
346-
("", "", "unknown"),
323+
# Direct format mappings from soundfile
324+
("WAV", "PCM_16", "wav", "wav", 16 * 16000),
325+
("FLAC", "PCM_16", "flac", "flac", 16 * 16000),
326+
("OGG", "VORBIS", "ogg", "ogg", -1),
327+
("AIFF", "PCM_24", "aiff", "aiff", 24 * 16000),
328+
# Format fallback to file extension when subtype is PCM
329+
(None, "PCM_16", "wav", "wav", 16 * 16000),
330+
(None, "PCM_24", "aiff", "aiff", 24 * 16000),
331+
(None, "PCM_S16LE", "au", "au", 16 * 16000),
332+
(None, "PCM_F32LE", "wav", "wav", 32 * 16000),
333+
# Unknown format with extension falls back to extension
334+
(None, "UNKNOWN_CODEC", "mp3", "mp3", -1),
335+
("", "UNKNOWN_CODEC", "flac", "flac", -1),
336+
# Files without extension should fall back to "unknown"
337+
(None, "PCM_16", "", "unknown", 16 * 16000),
338+
("", "UNKNOWN_CODEC", "", "unknown", -1),
347339
],
348340
)
349341
def test_audio_info_format_detection(
350-
tmp_path, catalog, encoding, file_ext, expected_format
342+
tmp_path, catalog, format_str, subtype, file_ext, expected_format, expected_bit_rate
351343
):
352-
"""Test audio format detection for different file extensions and encodings."""
344+
"""Test audio format detection for different file extensions and formats."""
353345
# Create a test audio file with the specified extension
354346
filename = f"test_audio.{file_ext}" if file_ext else "test_audio"
355347
audio_data = generate_test_wav(duration=0.1, sample_rate=16000)
@@ -359,18 +351,20 @@ def test_audio_info_format_detection(
359351
audio_file = AudioFile(path=str(audio_path), source="file://")
360352
audio_file._set_stream(catalog, caching_enabled=False)
361353

362-
# Mock torchaudio.info to return controlled encoding
363-
with patch("datachain.lib.audio.torchaudio.info") as mock_info:
364-
mock_info.return_value.sample_rate = 16000
365-
mock_info.return_value.num_channels = 1
366-
mock_info.return_value.num_frames = 1600 # 0.1 seconds
367-
mock_info.return_value.encoding = encoding
368-
mock_info.return_value.bits_per_sample = 16
354+
# Mock soundfile.info to return controlled format
355+
with patch("datachain.lib.audio.sf.info") as mock_info:
356+
mock_info.return_value.samplerate = 16000
357+
mock_info.return_value.channels = 1
358+
mock_info.return_value.frames = 1600 # 0.1 seconds
359+
mock_info.return_value.duration = 0.1
360+
mock_info.return_value.format = format_str
361+
mock_info.return_value.subtype = subtype
369362

370363
result = audio_info(audio_file)
371364

372365
assert result.format == expected_format
373-
assert result.codec == encoding
366+
assert result.codec == subtype
367+
assert result.bit_rate == expected_bit_rate
374368

375369

376370
def test_audio_info_stereo(stereo_audio_file):

0 commit comments

Comments
 (0)