Skip to content

Commit cd3d440

Browse files
committed
Use torchcodec for loading
1 parent a3fe94e commit cd3d440

File tree

14 files changed

+36
-20
lines changed

14 files changed

+36
-20
lines changed

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11
# Minimum runtime dependencies
22
torch
3+
torchcodec
34

45
# Optional runtime dependencies
56
kaldi_io

src/torchaudio/datasets/cmuarctic.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44
from typing import Tuple, Union
55

66
import torchaudio
7+
from torchaudio.utils import load_torchcodec
78
from torch import Tensor
89
from torch.utils.data import Dataset
910
from torchaudio._internal import download_url_to_file
@@ -43,7 +44,7 @@ def load_cmuarctic_item(line: str, path: str, folder_audio: str, ext_audio: str)
4344
file_audio = os.path.join(path, folder_audio, utterance_id + ext_audio)
4445

4546
# Load audio
46-
waveform, sample_rate = torchaudio.load(file_audio)
47+
waveform, sample_rate = load_torchcodec(file_audio)
4748

4849
return (waveform, sample_rate, transcript, utterance_id.split("_")[1])
4950

src/torchaudio/datasets/commonvoice.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,7 @@
66
import torchaudio
77
from torch import Tensor
88
from torch.utils.data import Dataset
9+
from torchaudio.utils import load_torchcodec
910

1011

1112
def load_commonvoice_item(
@@ -20,7 +21,7 @@ def load_commonvoice_item(
2021
filename = os.path.join(path, folder_audio, fileid)
2122
if not filename.endswith(ext_audio):
2223
filename += ext_audio
23-
waveform, sample_rate = torchaudio.load(filename)
24+
waveform, sample_rate = load_torchcodec(filename)
2425

2526
dic = dict(zip(header, line))
2627

src/torchaudio/datasets/dr_vctk.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,7 @@
66
from torch.utils.data import Dataset
77
from torchaudio._internal import download_url_to_file
88
from torchaudio.datasets.utils import _extract_zip
9+
from torchaudio.utils import load_torchcodec
910

1011

1112
_URL = "https://datashare.ed.ac.uk/bitstream/handle/10283/3038/DR-VCTK.zip"
@@ -75,8 +76,8 @@ def _load_dr_vctk_item(self, filename: str) -> Tuple[Tensor, int, Tensor, int, s
7576
source, channel_id = self._config[filename]
7677
file_clean_audio = self._clean_audio_dir / filename
7778
file_noisy_audio = self._noisy_audio_dir / filename
78-
waveform_clean, sample_rate_clean = torchaudio.load(file_clean_audio)
79-
waveform_noisy, sample_rate_noisy = torchaudio.load(file_noisy_audio)
79+
waveform_clean, sample_rate_clean = load_torchcodec(file_clean_audio)
80+
waveform_noisy, sample_rate_noisy = load_torchcodec(file_noisy_audio)
8081
return (
8182
waveform_clean,
8283
sample_rate_clean,

src/torchaudio/datasets/gtzan.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,7 @@
77
from torch.utils.data import Dataset
88
from torchaudio._internal import download_url_to_file
99
from torchaudio.datasets.utils import _extract_tar
10+
from torchaudio.utils import load_torchcodec
1011

1112
# The following lists prefixed with `filtered_` provide a filtered split
1213
# that:
@@ -990,7 +991,7 @@ def load_gtzan_item(fileid: str, path: str, ext_audio: str) -> Tuple[Tensor, str
990991

991992
# Read wav
992993
file_audio = os.path.join(path, label, fileid + ext_audio)
993-
waveform, sample_rate = torchaudio.load(file_audio)
994+
waveform, sample_rate = load_torchcodec(file_audio)
994995

995996
return waveform, sample_rate, label
996997

src/torchaudio/datasets/librilight_limited.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,7 @@
88
from torchaudio._internal import download_url_to_file
99
from torchaudio.datasets.librispeech import _get_librispeech_metadata
1010
from torchaudio.datasets.utils import _extract_tar
11+
from torchaudio.utils import load_torchcodec
1112

1213

1314
_ARCHIVE_NAME = "librispeech_finetuning"
@@ -104,7 +105,7 @@ def __getitem__(self, n: int) -> Tuple[Tensor, int, str, int, int, int]:
104105
"""
105106
file_path, fileid = self._fileids_paths[n]
106107
metadata = _get_librispeech_metadata(fileid, self._path, file_path, self._ext_audio, self._ext_txt)
107-
waveform, _ = torchaudio.load(os.path.join(self._path, metadata[0]))
108+
waveform, _ = load_torchcodec(os.path.join(self._path, metadata[0]))
108109
return (waveform,) + metadata[1:]
109110

110111
def __len__(self) -> int:

src/torchaudio/datasets/libritts.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,7 @@
77
from torch.utils.data import Dataset
88
from torchaudio._internal import download_url_to_file
99
from torchaudio.datasets.utils import _extract_tar
10+
from torchaudio.utils import load_torchcodec
1011

1112
URL = "train-clean-100"
1213
FOLDER_IN_ARCHIVE = "LibriTTS"
@@ -41,7 +42,7 @@ def load_libritts_item(
4142
file_audio = os.path.join(path, speaker_id, chapter_id, file_audio)
4243

4344
# Load audio
44-
waveform, sample_rate = torchaudio.load(file_audio)
45+
waveform, sample_rate = load_torchcodec(file_audio)
4546

4647
# Load original text
4748
with open(original_text) as ft:

src/torchaudio/datasets/ljspeech.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,7 @@
88
from torch.utils.data import Dataset
99
from torchaudio._internal import download_url_to_file
1010
from torchaudio.datasets.utils import _extract_tar
11-
11+
from torchaudio.utils import load_torchcodec
1212

1313
_RELEASE_CONFIGS = {
1414
"release1": {
@@ -94,7 +94,7 @@ def __getitem__(self, n: int) -> Tuple[Tensor, int, str, str]:
9494
fileid_audio = self._path / (fileid + ".wav")
9595

9696
# Load audio
97-
waveform, sample_rate = torchaudio.load(fileid_audio)
97+
waveform, sample_rate = load_torchcodec(fileid_audio)
9898

9999
return (
100100
waveform,

src/torchaudio/datasets/musdb_hq.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from torch.utils.data import Dataset
88
from torchaudio._internal import download_url_to_file
99
from torchaudio.datasets.utils import _extract_zip
10+
from torchaudio.utils import load_torchcodec
1011

1112
_URL = "https://zenodo.org/record/3338373/files/musdb18hq.zip"
1213
_CHECKSUM = "baac80d0483c61d74b2e5f3be75fa557eec52898339e6aa45c1fa48833c5d21d"
@@ -87,7 +88,7 @@ def _load_sample(self, n: int) -> Tuple[torch.Tensor, int, int, str]:
8788
num_frames = None
8889
for source in self.sources:
8990
track = self._get_track(name, source)
90-
wav, sr = torchaudio.load(str(track))
91+
wav, sr = load_torchcodec(str(track))
9192
if sr != _SAMPLE_RATE:
9293
raise ValueError(f"expected sample rate {_SAMPLE_RATE}, but got {sr}")
9394
if num_frames is None:

src/torchaudio/datasets/tedlium.py

Lines changed: 2 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,7 @@
77
from torch.utils.data import Dataset
88
from torchaudio._internal import download_url_to_file
99
from torchaudio.datasets.utils import _extract_tar
10+
from torchaudio.utils import load_torchcodec
1011

1112

1213
_RELEASE_CONFIGS = {
@@ -163,12 +164,7 @@ def _load_audio(self, path: str, start_time: float, end_time: float, sample_rate
163164
Returns:
164165
[Tensor, int]: Audio tensor representation and sample rate
165166
"""
166-
start_time = int(float(start_time) * sample_rate)
167-
end_time = int(float(end_time) * sample_rate)
168-
169-
kwargs = {"frame_offset": start_time, "num_frames": end_time - start_time}
170-
171-
return torchaudio.load(path, **kwargs)
167+
return load_torchcodec(path, start_seconds=float(start_time), stop_seconds=float(end_time))
172168

173169
def __getitem__(self, n: int) -> Tuple[Tensor, int, str, int, int, int]:
174170
"""Load the n-th sample from the dataset.

0 commit comments

Comments
 (0)