
Commit 61b8427 (1 parent: 6894384)

- Added retries to dask steps
- Updated README with development information

9 files changed: +89 -103 lines

README.md

Lines changed: 20 additions & 0 deletions
@@ -65,3 +65,23 @@ poetry run python -m stem_continuation_dataset_generator.process <STEM_NAME>
 ```
 
 The pipeline will augment, distort, encode and split the samples into chunks, generating three different folders for the train, validation and test sets. The result will be uploaded to ClearML into 3 different datasets.
+
+### Development
+
+Download the repository and install the package:
+
+```sh
+git clone https://github.com/energydrink9/stem_continuation_dataset_generator.git
+cd stem_continuation_dataset_generator
+poetry install
+```
+
+Once you've downloaded the repository and installed the package, please run the following command to setup the pre-commit hooks:
+```sh
+pre-commit install
+```
+
+Please run the tests before submitting a PR:
+```sh
+pytest
+```

entrypoint.sh

Lines changed: 0 additions & 20 deletions
This file was deleted.

src/stem_continuation_dataset_generator/cluster.py

Lines changed: 4 additions & 1 deletion
@@ -3,6 +3,8 @@
 import dask.config
 from dask.distributed import Client, LocalCluster
 
+from stem_continuation_dataset_generator.constants import DASK_CLUSTER_NAME
+
 NUM_WORKERS = [4, 50]
 BUCKET = 's3://stem-continuation-dataset'
 
@@ -16,10 +18,11 @@ def get_client(
     dask.config.set({'distributed.scheduler.allowed-failures': 12})
 
     if run_locally is True:
-        cluster = LocalCluster(n_workers=2, threads_per_worker=1, **kwargs)
+        cluster = LocalCluster(n_workers=2, threads_per_worker=1)
 
     else:
         cluster = coiled.Cluster(
+            name=DASK_CLUSTER_NAME,
            n_workers=n_workers,
             package_sync_conda_extras=['portaudio', 'ffmpeg'],
             idle_timeout="5 minutes",
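Naming the Coiled cluster lets successive pipeline steps reattach to an already-running cluster instead of provisioning a fresh one each time. A minimal sketch of that intent, assuming the `get_client(run_locally, n_workers, ...)` shape implied by the diff; `get_client_sketch` and the inlined literals are illustrative, not the repo's exact code:

```python
from dask.distributed import Client, LocalCluster


def get_client_sketch(run_locally: bool) -> Client:
    if run_locally:
        # Small local cluster for development runs
        cluster = LocalCluster(n_workers=2, threads_per_worker=1)
        return Client(cluster)

    import coiled  # only needed on the remote path

    # Passing the same name reattaches to a running Coiled cluster
    cluster = coiled.Cluster(
        name='stem-continuation-dataset-generator-cluster',  # DASK_CLUSTER_NAME
        n_workers=[4, 50],         # NUM_WORKERS adaptive range from this file
        idle_timeout='5 minutes',  # shut down when idle, as in the diff
    )
    return Client(cluster)
```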

src/stem_continuation_dataset_generator/codec.py

Lines changed: 5 additions & 5 deletions
@@ -11,7 +11,7 @@
 
 from stem_continuation_dataset_generator.utils.device import Device
 
-ENCODER_BATCH_SIZE = 32
+ENCODER_BATCH_SIZE = 1
 ENCODED_TOKENS_PER_CHUNK = 512 # large values (over 1024) require a large amount of memory and can produce OOM errors
 
 
@@ -28,10 +28,10 @@ def get_processor(device: Device):
     return AutoProcessor.from_pretrained("facebook/encodec_32khz", device_map=device)
 
 
-def encode_file(audio_path: Union[BinaryIO, str, PathLike], device: Device, format: Optional[str] = None) -> Tuple[Tensor, float]:
+def encode_file(audio_path: Union[BinaryIO, str, PathLike], device: Device, format: Optional[str] = None, batch_size: int = ENCODER_BATCH_SIZE) -> Tuple[Tensor, float]:
     # Load and pre-process the audio waveform
     wav, sr = torchaudio.load(audio_path, format=format, normalize=False) # Normalization is later performed using librosa as it seems to work better
-    return encode(wav, sr, device)
+    return encode(wav, sr, device, batch_size=batch_size)
 
 
 def get_total_chunks(samples_per_chunk: int, num_samples: int) -> int:
@@ -57,7 +57,7 @@ def chunk_list(lst, n: int):
         yield lst[i:i + n]
 
 
-def encode(audio: Tensor, sr: int, device: Device) -> Tuple[Tensor, float]:
+def encode(audio: Tensor, sr: int, device: Device, batch_size: int = ENCODER_BATCH_SIZE) -> Tuple[Tensor, float]:
 
     device = device if not device.startswith('mps') else 'cpu' # Encoding is not supported on MPS
     processor = get_processor(device)
@@ -85,7 +85,7 @@ def encode(audio: Tensor, sr: int, device: Device) -> Tuple[Tensor, float]:
     encoded_chunks = []
 
     # create audio chunks
-    batches: List[List[Tensor]] = list(chunk_list(chunks, ENCODER_BATCH_SIZE))
+    batches: List[List[Tensor]] = list(chunk_list(chunks, batch_size))
 
     for batch in batches:
         inputs = processor(raw_audio=batch, sampling_rate=processor.sampling_rate, return_tensors="pt")
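With this change the default encoder batch size drops from 32 to 1, but the new `batch_size` parameter threads through `encode_file` into `encode`, so callers can still raise it per call (the encode step below passes `batch_size=2`). A quick, runnable check of the batching helper; `chunk_list`'s `yield` line comes from the diff's context, the `for` line is inferred, and plain integers stand in for chunk tensors:

```python
def chunk_list(lst, n: int):
    # As in codec.py: yield successive n-sized slices of lst
    for i in range(0, len(lst), n):
        yield lst[i:i + n]


chunks = list(range(10))               # stand-ins for audio chunk tensors
batches = list(chunk_list(chunks, 4))  # batch_size=4 for illustration
assert batches == [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]  # last batch is partial
```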

src/stem_continuation_dataset_generator/constants.py

Lines changed: 1 addition & 0 deletions
@@ -5,6 +5,7 @@
 CLEARML_DATASET_TRAINING_VERSION = '1.0.0'
 DEFAULT_STEM_NAME = 'drum'
 STORAGE_BUCKET_NAME = 'stem-continuation-dataset'
+DASK_CLUSTER_NAME = 'stem-continuation-dataset-generator-cluster'
 
 
 def get_original_files_path():

src/stem_continuation_dataset_generator/steps/augment.py

Lines changed: 32 additions & 38 deletions
@@ -1,6 +1,5 @@
 import io
 import os
-import traceback
 from typing import Any, List, Tuple, cast
 from dask.distributed import Client
 from distributed import progress
@@ -69,47 +68,42 @@ def augment(params: Tuple[S3FileSystem, str, str, str]) -> None:
 
     fs, file_path, source_directory, output_directory = params
 
-    try:
-        file_dir = os.path.dirname(file_path)
-        stem_file_path = os.path.join(file_dir, 'stem.ogg')
+    file_dir = os.path.dirname(file_path)
+    stem_file_path = os.path.join(file_dir, 'stem.ogg')
+    file_dir = os.path.dirname(file_path)
+    relative_path = os.path.relpath(file_dir, source_directory)
+    output_file_path = os.path.join(output_directory, relative_path + '-original')
+
+    full_track_output_file_path = os.path.join(output_file_path, os.path.basename(file_path))
+
+    if not fs.exists(full_track_output_file_path):
+        fs.makedirs(os.path.dirname(full_track_output_file_path), exist_ok=True)
+        if fs.exists(file_path):
+            fs.copy(file_path, full_track_output_file_path)
+
+    stem_output_file_path = os.path.join(output_file_path, os.path.basename(stem_file_path))
+
+    if not fs.exists(stem_output_file_path):
+        fs.makedirs(os.path.dirname(full_track_output_file_path), exist_ok=True)
+        if fs.exists(stem_file_path):
+            fs.copy(stem_file_path, stem_output_file_path)
+
+    for i in range(AUGMENTATIONS_COUNT):
         file_dir = os.path.dirname(file_path)
         relative_path = os.path.relpath(file_dir, source_directory)
-        output_file_path = os.path.join(output_directory, relative_path + '-original')
-
+        output_file_path = os.path.join(output_directory, relative_path + f'-augmented{i}')
         full_track_output_file_path = os.path.join(output_file_path, os.path.basename(file_path))
-
-        if not fs.exists(full_track_output_file_path):
-            fs.makedirs(os.path.dirname(full_track_output_file_path), exist_ok=True)
-            if fs.exists(file_path):
-                fs.copy(file_path, full_track_output_file_path)
-
         stem_output_file_path = os.path.join(output_file_path, os.path.basename(stem_file_path))
 
-        if not fs.exists(stem_output_file_path):
-            fs.makedirs(os.path.dirname(full_track_output_file_path), exist_ok=True)
-            if fs.exists(stem_file_path):
-                fs.copy(stem_file_path, stem_output_file_path)
-
-        for i in range(AUGMENTATIONS_COUNT):
-            file_dir = os.path.dirname(file_path)
-            relative_path = os.path.relpath(file_dir, source_directory)
-            output_file_path = os.path.join(output_directory, relative_path + f'-augmented{i}')
-            full_track_output_file_path = os.path.join(output_file_path, os.path.basename(file_path))
-            stem_output_file_path = os.path.join(output_file_path, os.path.basename(stem_file_path))
-
-            if not fs.exists(full_track_output_file_path) or not fs.exists(stem_output_file_path):
-                fs.makedirs(output_file_path, exist_ok=True)
-                augment_pitch_and_tempo(
-                    fs,
-                    [
-                        (file_path, full_track_output_file_path),
-                        (stem_file_path, stem_output_file_path)
-                    ]
-                )
-
-    except Exception as e:
-        print(f'Error augmenting file {file_path}: {e}')
-        print(traceback.format_exc())
+        if not fs.exists(full_track_output_file_path) or not fs.exists(stem_output_file_path):
+            fs.makedirs(output_file_path, exist_ok=True)
+            augment_pitch_and_tempo(
+                fs,
+                [
+                    (file_path, full_track_output_file_path),
+                    (stem_file_path, stem_output_file_path)
+                ]
+            )
 
 
 def augment_all(source_directory: str, output_directory: str):
@@ -127,7 +121,7 @@ def augment_all(source_directory: str, output_directory: str):
     params_list: List[Tuple[S3FileSystem, str, str, str]] = [(fs, file_path, source_directory, output_directory) for file_path in files]
 
     print('Augmenting audio tracks')
-    futures = client.map(augment, params_list)
+    futures = client.map(augment, params_list, retries=2)
     progress(futures)
 
     return output_directory
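The same pattern repeats in the distort and encode steps below: the per-task try/except that printed and swallowed errors is removed, and `retries=2` on `client.map` lets the Dask scheduler resubmit a failed task, so only tasks that fail three attempts surface as errors. A minimal local illustration of the behavior, not the repo's code; `flaky` simulates a transient S3 or worker failure:

```python
import random

from dask.distributed import Client, LocalCluster


def flaky(x: int) -> int:
    # Fails about half the time, standing in for a transient I/O error
    if random.random() < 0.5:
        raise IOError('transient failure')
    return x * 2


if __name__ == '__main__':
    client = Client(LocalCluster(n_workers=2, threads_per_worker=1))
    # Each failing task is retried up to twice; an exception propagates
    # through gather() only if a task fails all three attempts.
    futures = client.map(flaky, range(8), retries=2)
    print(client.gather(futures))
    client.close()
```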

src/stem_continuation_dataset_generator/steps/distort.py

Lines changed: 12 additions & 18 deletions
@@ -1,6 +1,5 @@
 import io
 import os
-import traceback
 from typing import List, Tuple, cast
 from fsspec import AbstractFileSystem
 import numpy as np
@@ -78,24 +77,19 @@ def distort(params: Tuple[S3FileSystem, Tuple[str, str], str, str]) -> None:
 
     fs, (full_track_file_path, stem_file_path), source_directory, output_directory = params
 
-    try:
-        file_dir = os.path.dirname(full_track_file_path)
-        full_track_relative_path = os.path.relpath(file_dir, source_directory)
-        actual_output_dir = os.path.join(output_directory, full_track_relative_path)
-        fs.makedirs(actual_output_dir, exist_ok=True)
-        full_track_output_file_path = os.path.join(actual_output_dir, os.path.basename(full_track_file_path))
+    file_dir = os.path.dirname(full_track_file_path)
+    full_track_relative_path = os.path.relpath(file_dir, source_directory)
+    actual_output_dir = os.path.join(output_directory, full_track_relative_path)
+    fs.makedirs(actual_output_dir, exist_ok=True)
+    full_track_output_file_path = os.path.join(actual_output_dir, os.path.basename(full_track_file_path))
 
-        if not fs.exists(full_track_output_file_path):
-            distort_file(fs, full_track_file_path, full_track_output_file_path)
+    if not fs.exists(full_track_output_file_path):
+        distort_file(fs, full_track_file_path, full_track_output_file_path)
 
-        stem_relative_path = os.path.relpath(stem_file_path, source_directory)
-        stem_output_file_path = os.path.join(output_directory, stem_relative_path)
-        if not fs.exists(stem_output_file_path):
-            fs.copy(stem_file_path, stem_output_file_path)
-
-    except Exception as e:
-        print(f'Error processing {full_track_file_path} or {stem_file_path}: {e}')
-        print(traceback.format_exc())
+    stem_relative_path = os.path.relpath(stem_file_path, source_directory)
+    stem_output_file_path = os.path.join(output_directory, stem_relative_path)
+    if not fs.exists(stem_output_file_path):
+        fs.copy(stem_file_path, stem_output_file_path)
 
 
 def distort_all(source_directory: str, output_directory: str):
@@ -110,7 +104,7 @@ def distort_all(source_directory: str, output_directory: str):
     ))
 
     print('Distorting audio tracks')
-    futures = client.map(distort, params_list)
+    futures = client.map(distort, params_list, retries=2)
     progress(futures)
 
     return output_directory

src/stem_continuation_dataset_generator/steps/encode.py

Lines changed: 14 additions & 20 deletions
@@ -1,6 +1,5 @@
 import os
 import pickle
-import traceback
 from typing import List, Tuple, cast
 from distributed import Client, progress
 from s3fs.core import S3FileSystem
@@ -22,26 +21,22 @@ def encode(params: Tuple[S3FileSystem, str, str, str]):
     fs, file_path, source_directory, output_directory = params
     device = get_device()
 
-    try:
-        file_dir = os.path.dirname(file_path)
-        relative_path = os.path.relpath(file_dir, source_directory)
-        file_output_directory = os.path.join(output_directory, relative_path)
-        fs.makedirs(file_output_directory, exist_ok=True)
+    file_dir = os.path.dirname(file_path)
+    relative_path = os.path.relpath(file_dir, source_directory)
+    file_output_directory = os.path.join(output_directory, relative_path)
 
-        output_filename = os.path.basename(file_path)
-        output_file_path = os.path.join(file_output_directory, output_filename)
+    output_filename = os.path.basename(file_path).split('.')[0] + '.pkl'
+    output_file_path = os.path.join(file_output_directory, output_filename)
 
-        if not fs.exists(output_file_path):
-            with fs.open(file_path, 'rb') as file:
-                encoded_audio, frame_rate = encode_file(file, device)
+    if not fs.exists(output_file_path):
+        with fs.open(file_path, 'rb') as file:
+            encoded_audio, frame_rate = encode_file(file, device, batch_size=2)
 
-            if not fs.exists(output_file_path):
-                with fs.open(output_file_path, 'wb') as output_file:
-                    pickle.dump(encoded_audio.detach().to('cpu'), output_file)
-
-    except Exception:
-        print(f'Error while encoding file {file_path}')
-        print(traceback.format_exc())
+        fs.makedirs(file_output_directory, exist_ok=True)
+        with fs.open(output_file_path, 'wb') as output_file:
+            pickle.dump(encoded_audio.detach().to('cpu'), output_file)
+    else:
+        print(f'path {output_file_path} already exists')
 
 
 def encode_all(source_directory: str, output_directory: str):
@@ -53,7 +48,6 @@ def encode_all(source_directory: str, output_directory: str):
     client = cast(Client, get_client(
         RUN_LOCALLY,
         n_workers=[1, 1],
-        # worker_vm_types=['c6a.xlarge'],
         worker_vm_types=['g4dn.xlarge'],
         scheduler_vm_types=['t3.medium'],
         spot_policy='spot',
@@ -66,7 +60,7 @@
     # print(f'Processing {i} of {len(params_list)} {round(cast(float, i) / len(params_list) * 100)}')
     # encode(params_list[i])
 
-    futures = client.map(encode, params_list)
+    futures = client.map(encode, params_list, retries=2, batch_size=8)
     progress(futures)
 
     return output_directory
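Note the two unrelated batch sizes in this step: `batch_size=2` passed to `encode_file` is the encoder batch (how many audio chunks go through EnCodec per forward pass), while `batch_size=8` on `client.map` controls how many task submissions are sent to the Dask scheduler per round trip. A hedged sketch of the submission call; `submit_encode_tasks` is an illustrative wrapper, while `retries` and `batch_size` are real `Client.map` keywords:

```python
from typing import Iterable, List

from dask.distributed import Client, Future, progress


def submit_encode_tasks(client: Client, encode_fn, params_list: Iterable) -> List[Future]:
    # retries=2: resubmit a failed task up to two more times.
    # batch_size=8: group submissions to the scheduler, reducing overhead
    # when params_list is long.
    futures = client.map(encode_fn, params_list, retries=2, batch_size=8)
    progress(futures)
    return futures
```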

src/stem_continuation_dataset_generator/steps/merge.py

Lines changed: 1 addition & 1 deletion
@@ -218,7 +218,7 @@ def assort_and_merge_all(source_directory: str, output_directory: str, stem_name
     params_list: List[Tuple[S3FileSystem, str, str, str, str]] = [(fs, source_directory, output_directory, directory, stem_name) for directory in dirs]
 
     print('Assorting and merging audio tracks')
-    progress(client.map(assort_directory, params_list))
+    progress(client.map(assort_directory, params_list, retries=2))
 
     return output_directory
