Skip to content

Commit 0ecb5d2

Browse files
committed
Added a pipeline step for bulk stemming using Demucs
1 parent f1a0388 commit 0ecb5d2

File tree

8 files changed

+346
-14
lines changed

8 files changed

+346
-14
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
This application implements a pipeline that can be used to create audio datasets for the generation of stem continuations of music audio files. The code uses [Dask](https://www.dask.org/) in order to scale the dataset processing on a cluster of virtual machines in the cloud. The application is configured to run on AWS EC2 and to use S3 as storage. The audio files are encoded using Meta's [Encodec](https://github.com/facebookresearch/encodec) into a discrete, compressed, tokenized representation. Finally, the last step uploads the dataset to [ClearML](https://clear.ml) to be used for training and/or inference.
55

66
The dataset generation pipeline is comprised of several steps:
7+
- **Stem**. Separates MP3 files into drums, bass, guitar and other stems using [Demucs](https://github.com/adefossez/demucs)
78
- **Uncompress**. The application expects to find the stem files for a single music file (in .wav format) in a compressed zip archive. Each stem should have a predefined name in order to be identified as a guitar, bass, drum, etc.
89
- **Convert to ogg**. Conversion of wav files to the Ogg Opus audio format.
910
- **Merge**. Several different assortments of stems are generated.

poetry.lock

Lines changed: 229 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ torch = "^2.5.1"
4545
torchaudio = "^2.5.1"
4646
torchvision = "^0.20.1"
4747
accelerate = "^1.1.1"
48+
demucs = "^4.0.1"
4849

4950
[tool.poetry.dev-dependencies]
5051
flake8 = "^7.1.1"

src/stem_continuation_dataset_generator/constants.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,10 @@
88
DASK_CLUSTER_NAME = 'stem-continuation-dataset-generator-cluster'
99

1010

11+
def get_whole_tracks_files_path():
    """Return the storage path where the unsplit whole-track audio files live."""
    return os.path.join(STORAGE_BUCKET_NAME, 'whole-tracks')
13+
14+
1115
def get_original_files_path():
    """Return the storage path where the original (per-stem) audio files live."""
    return os.path.join(STORAGE_BUCKET_NAME, 'original')
1317

src/stem_continuation_dataset_generator/pipeline.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1-
from stem_continuation_dataset_generator.constants import DATASET_TAGS, get_augmented_files_path, get_distorted_files_path, get_encoded_files_path, get_merged_files_path, get_original_files_path, get_split_files_path
1+
from stem_continuation_dataset_generator.constants import DATASET_TAGS, get_augmented_files_path, get_distorted_files_path, get_encoded_files_path, get_merged_files_path
2+
from stem_continuation_dataset_generator.constants import get_original_files_path, get_split_files_path, get_whole_tracks_files_path
23
from stem_continuation_dataset_generator.steps.augment import augment_all
34
from stem_continuation_dataset_generator.steps.convert_to_ogg import convert_to_ogg
45
from stem_continuation_dataset_generator.steps.encode import encode_all
56
from stem_continuation_dataset_generator.steps.merge import assort_and_merge_all
67
from stem_continuation_dataset_generator.steps.split import split_all
8+
from stem_continuation_dataset_generator.steps.stem import stem_all
79
from stem_continuation_dataset_generator.steps.uncompress import uncompress_files
810
from stem_continuation_dataset_generator.steps.upload import upload
911
from stem_continuation_dataset_generator.steps.distort import distort_all
@@ -33,6 +35,7 @@ def dataset_creation_pipeline(stem_name: str):
3335

3436
tags = DATASET_TAGS + [f'stem-{stem_name}']
3537

38+
stem_all(get_whole_tracks_files_path(), get_original_files_path())
3639
assort_and_merge_all(get_original_files_path(), get_merged_files_path(stem_name), stem_name)
3740
augment_all(get_merged_files_path(stem_name), get_augmented_files_path(stem_name))
3841
distort_all(get_augmented_files_path(stem_name), get_distorted_files_path(stem_name))

src/stem_continuation_dataset_generator/steps/merge.py

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,14 @@
33
import os
44
import random
55
from typing import FrozenSet, List, Optional, Tuple, cast, Set
6-
import librosa
76
from pydub import AudioSegment
87
from dask.distributed import progress, Client
98
from s3fs.core import S3FileSystem
109

1110
from stem_continuation_dataset_generator.cluster import get_client
1211
from stem_continuation_dataset_generator.constants import DEFAULT_STEM_NAME, get_merged_files_path, get_original_files_path
1312
from stem_continuation_dataset_generator.utils.constants import get_random_seed
13+
from stem_continuation_dataset_generator.utils.utils import is_mostly_silent
1414

1515
STEM_NAMES = ['guitar', 'drum', 'bass', 'perc', 'fx', 'vocals', 'piano', 'synth', 'winds', 'strings']
1616
BASIC_STEM_NAMES = ['guitar', 'drum', 'bass', 'perc', 'gtr', 'drm', 'piano']
@@ -133,22 +133,17 @@ def create_stems_assortments(other_stems: List[StemFile], current_stem_file: str
133133
return [(current_stem_file, assortment) for assortment in assortments]
134134

135135

136-
def is_mostly_silent(fs: S3FileSystem, file_path: str) -> bool:
137-
with fs.open(file_path, 'rb') as file:
138-
139-
audio, sr = librosa.load(file) # type: ignore
140-
no_of_samples = audio.shape[-1]
141-
splits = librosa.effects.split(audio, top_db=60)
142-
non_silent_samples = sum([end - start for (start, end) in splits])
143-
return non_silent_samples / no_of_samples < MIN_PERCENTAGE_OF_AUDIO_IN_NON_SILENT_FILES
144-
145-
146136
def get_stem(file_path: str, silent: bool) -> StemFile:
    """Build a StemFile record from a path and its silence flag."""
    return StemFile(
        file_path=file_path,
        is_mostly_silent=silent,
    )
148138

149139

140+
def is_remote_file_mostly_silent(fs: S3FileSystem, file_path: str) -> bool:
    """
    Check whether an audio file stored on S3 is mostly silent.

    Opens the remote file in binary mode and delegates the silence analysis
    to the shared `is_mostly_silent` helper, using this step's minimum
    non-silent-audio threshold.

    Args:
        fs: The S3 filesystem used to open the remote file.
        file_path: Path of the audio file on S3.

    Returns:
        bool: True when the file's non-silent fraction is below the threshold.
    """
    with fs.open(file_path, 'rb') as file:
        # The stream is opened in binary mode, so BufferedReader is the
        # accurate cast here (TextIOWrapper is a text-mode wrapper).
        return is_mostly_silent(cast(io.BufferedReader, file), MIN_PERCENTAGE_OF_AUDIO_IN_NON_SILENT_FILES)
143+
144+
150145
def get_stems(fs: S3FileSystem, paths: List[str]) -> List[StemFile]:
    """Build a StemFile record for every path, probing each remote file for silence."""
    stems: List[StemFile] = []
    for path in paths:
        stems.append(get_stem(path, is_remote_file_mostly_silent(fs, path)))
    return stems
152147

153148

154149
def assort(fs: S3FileSystem, directory: str, stem_name: str) -> List[List[Tuple[str, FrozenSet[str]]]]:
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
import glob
2+
import os
3+
import shlex
4+
import tempfile
5+
from typing import List, Tuple, cast
6+
7+
from distributed import Client, progress
8+
import demucs.separate
9+
from s3fs.core import S3FileSystem
10+
11+
from stem_continuation_dataset_generator.cluster import get_client
12+
from stem_continuation_dataset_generator.constants import get_original_files_path, get_whole_tracks_files_path
13+
from stem_continuation_dataset_generator.steps.convert_to_ogg import convert_to_ogg
14+
from stem_continuation_dataset_generator.utils.utils import is_mostly_silent
15+
16+
17+
RUN_LOCALLY = False
18+
PERCENTAGE_OF_NON_SILENT_AUDIO_FILE = 0.25
19+
EXCLUDED_STEMS = ['piano', 'vocals'] # Piano and vocals stems produced by Demucs are low quality
20+
21+
22+
def get_whole_track_files(fs: S3FileSystem, dir: str) -> List[str]:
    """Return the paths of all MP3 files found recursively under `dir` on S3."""
    pattern = os.path.join(dir, '**/*.mp3')
    return cast(List[str], fs.glob(pattern))
24+
25+
26+
def stem_file(output_directory: str, file_path: str) -> tuple[str, list[tuple[str, str]]]:
    """
    Separate an audio file into individual instrument stems using Demucs.

    Runs the `htdemucs_6s` six-source Demucs model on the given audio file,
    writing one .wav file per stem under `output_directory`, then collects
    the produced files.

    Args:
        output_directory (str): Directory where Demucs writes the separated stems.
        file_path (str): Path of the audio file to be separated.

    Returns:
        tuple[str, list[tuple[str, str]]]: The output directory, and a list of
        (stem name, stem file path) tuples — the stem name is the .wav file's
        basename without its extension.
    """
    # shlex.split keeps the quoted paths intact even when they contain spaces
    demucs.separate.main(shlex.split(f'-n htdemucs_6s --clip-mode clamp --out "{output_directory}" "{file_path}"'))
    stem_paths = glob.glob(os.path.join(output_directory, '**', '*.wav'), recursive=True)
    return (output_directory, [(os.path.splitext(os.path.basename(path))[0], path) for path in stem_paths])
43+
44+
45+
def stem(params: Tuple[S3FileSystem, str, str, str, str]):
    """
    Download one whole track from S3, split it into stems, convert the stems
    to Ogg, and upload the non-silent, non-excluded stems back to S3.

    Args:
        params: Tuple of (S3 filesystem, remote MP3 path, artist name,
            source directory — unused here, kept for pipeline symmetry —
            and the base output directory on S3).
    """
    fs, file_path, artist, _source_directory, base_output_directory = params

    basename = os.path.basename(file_path)
    # Strip only the extension, not any other '.mp3' occurrence in the name
    song_name = os.path.splitext(basename)[0]
    output_directory = os.path.join(base_output_directory, artist, song_name)

    with tempfile.TemporaryDirectory() as local_directory:
        local_path = os.path.join(local_directory, basename)
        fs.download(file_path, local_path)
        stem_file(local_directory, local_path)
        # Free local disk space before converting the stems
        os.remove(local_path)
        convert_to_ogg(local_directory)
        ogg_files = glob.glob(os.path.join(local_directory, '**/*.ogg'), recursive=True)
        for ogg_file in ogg_files:
            stem_name = os.path.splitext(os.path.basename(ogg_file))[0]
            if stem_name in EXCLUDED_STEMS:
                continue
            with open(ogg_file, 'rb') as file:
                # Skip stems that are almost entirely silence
                if not is_mostly_silent(file, PERCENTAGE_OF_NON_SILENT_AUDIO_FILE):
                    print(f'Uploading stem {ogg_file}')
                    fs.upload(ogg_file, os.path.join(output_directory, os.path.basename(ogg_file)))
65+
66+
67+
def stem_all(source_directory: str, output_directory: str):
    """
    Split every whole track under `source_directory` into stems, fanning the
    work out over a Dask cluster, and return the output directory.
    """
    fs = S3FileSystem()
    track_files = get_whole_track_files(fs, source_directory)

    # The artist name is the last component of each file's parent directory
    params_list: List[Tuple[S3FileSystem, str, str, str, str]] = []
    for track_file in track_files:
        artist_name = os.path.dirname(track_file).split(os.path.sep)[-1]
        params_list.append((fs, track_file, artist_name, source_directory, output_directory))

    client = cast(Client, get_client(RUN_LOCALLY))

    print('Stemming audio tracks')
    futures = client.map(stem, params_list, retries=2)
    progress(futures)

    return output_directory
87+
88+
89+
if __name__ == '__main__':
90+
stem_all(get_whole_tracks_files_path(), get_original_files_path())

src/stem_continuation_dataset_generator/utils/utils.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
1+
import io
12
from clearml import Dataset
23
import numpy as np
4+
import librosa
5+
from typing import Union
36

47
from stem_continuation_dataset_generator.constants import CLEARML_DATASET_NAME
58
from stem_continuation_dataset_generator.utils.constants import get_clearml_project_name
@@ -40,3 +43,10 @@ def convert_audio_to_float_32(audio_data: np.ndarray) -> np.ndarray:
4043
raw_data = audio_data / max_32bit
4144
return raw_data.astype(np.float32)
4245

46+
47+
def is_mostly_silent(file: Union[io.TextIOWrapper, io.BufferedReader], percentage_non_silent: float) -> bool:
    """
    Return True when the audio read from `file` is mostly silence.

    The audio is decoded with librosa and split into non-silent intervals
    (anything quieter than 60 dB below peak counts as silence). The track is
    considered "mostly silent" when the fraction of non-silent samples falls
    below `percentage_non_silent`.

    Args:
        file: An open binary file object containing audio data.
        percentage_non_silent: Minimum fraction (0-1) of non-silent samples
            required for the file NOT to count as mostly silent.

    Returns:
        bool: True if the audio is mostly silent (or empty), False otherwise.
    """
    audio, _sr = librosa.load(file)  # type: ignore
    no_of_samples = audio.shape[-1]
    if no_of_samples == 0:
        # Guard the division below: an empty file has no non-silent audio
        return True
    splits = librosa.effects.split(audio, top_db=60)
    non_silent_samples = sum(end - start for (start, end) in splits)
    return non_silent_samples / no_of_samples < percentage_non_silent

0 commit comments

Comments
 (0)