Skip to content

Commit 7e154ee

Browse files
committed
添加多卡预处理 (Add multi-GPU preprocessing)
1 parent ebc6e3f commit 7e154ee

File tree

7 files changed

+48
-11
lines changed

7 files changed

+48
-11
lines changed

basics/base_binarizer.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,23 @@ def process_dataset(self, prefix, num_workers=0, apply_augmentation=False):
310310

311311
aug_map = self.arrange_data_augmentation(self.meta_data_iterator(prefix)) if apply_augmentation else {}
312312

313+
device_ids = None
314+
num_workers = int(num_workers)
315+
if (
316+
num_workers > 0 and torch.cuda.is_available()
317+
and (torch.cuda.device_count() > 1)
318+
):
319+
per_gpu_workers = self.binarization_args.get('num_workers_per_gpu')
320+
if per_gpu_workers is None and self.binarization_args.get('workers_per_gpu', False):
321+
per_gpu_workers = num_workers
322+
if per_gpu_workers:
323+
per_gpu_workers = int(per_gpu_workers)
324+
device_ids = [
325+
gpu for gpu in range(torch.cuda.device_count())
326+
for _ in range(per_gpu_workers)
327+
]
328+
num_workers = len(device_ids)
329+
313330
def postprocess(_item):
314331
nonlocal total_sec, total_raw_sec, extra_info, max_no
315332
if _item is None:
@@ -349,7 +366,9 @@ def postprocess(_item):
349366
if num_workers > 0:
350367
# code for parallel processing
351368
for item in tqdm(
352-
chunked_multiprocess_run(self.process_item, args, num_workers=num_workers),
369+
chunked_multiprocess_run(
370+
self.process_item, args, num_workers=num_workers, device_ids=device_ids
371+
),
353372
total=len(list(self.meta_data_iterator(prefix)))
354373
):
355374
postprocess(item)

configs/original/base.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ binarizer_cls: null
1111
binarization_args:
1212
shuffle: false
1313
num_workers: 0
14+
workers_per_gpu: false
1415

1516
audio_sample_rate: 44100
1617
hop_size: 512

modules/pe/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,13 @@
55
from .rmvpe import RMVPE
66

77

8-
def initialize_pe():
8+
def initialize_pe(device=None):
99
pe = hparams['pe']
1010
pe_ckpt = hparams['pe_ckpt']
1111
if pe == 'parselmouth':
1212
return ParselmouthPE()
1313
elif pe == 'rmvpe':
14-
return RMVPE(pe_ckpt)
14+
return RMVPE(pe_ckpt, device=device)
1515
elif pe == 'harvest':
1616
return HarvestPE()
1717
else:

modules/pe/rmvpe/inference.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,11 @@
1313

1414

1515
class RMVPE(BasePE):
16-
def __init__(self, model_path, hop_length=160):
16+
def __init__(self, model_path, hop_length=160, device=None):
1717
self.resample_kernel = {}
18-
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
18+
if device is None:
19+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
20+
self.device = torch.device(device)
1921
self.model = E2E0(4, 1, (2, 2)).eval().to(self.device)
2022
ckpt = torch.load(model_path, map_location=self.device)
2123
self.model.load_state_dict(ckpt['model'], strict=False)

preprocessing/acoustic_binarizer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ def process_item(self, item_name, meta_data, binarization_args):
134134
# get ground truth f0
135135
global pitch_extractor
136136
if pitch_extractor is None:
137-
pitch_extractor = initialize_pe()
137+
pitch_extractor = initialize_pe(self.device)
138138
gt_f0, uv = pitch_extractor.get_pitch(
139139
waveform, samplerate=hparams['audio_sample_rate'], length=length,
140140
hop_size=hparams['hop_size'], f0_min=hparams['f0_min'], f0_max=hparams['f0_max'],
@@ -229,7 +229,7 @@ def arrange_data_augmentation(self, data_iterator):
229229
aug_list = []
230230
all_item_names = [item_name for item_name, _ in data_iterator]
231231
total_scale = 0
232-
aug_pe = initialize_pe()
232+
aug_pe = initialize_pe(self.device)
233233
if self.augmentation_args['random_pitch_shifting']['enabled']:
234234
from augmentation.spec_stretch import SpectrogramStretchAugmentation
235235
aug_args = self.augmentation_args['random_pitch_shifting']

preprocessing/variance_binarizer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -300,7 +300,7 @@ def process_item(self, item_name, meta_data, binarization_args):
300300

301301
global pitch_extractor
302302
if pitch_extractor is None:
303-
pitch_extractor = initialize_pe()
303+
pitch_extractor = initialize_pe(self.device)
304304
f0 = uv = None
305305
if self.prefer_ds:
306306
f0_seq = self.load_attr_from_ds(ds_id, name, 'f0_seq', idx=ds_seg_idx)

utils/multiprocess_utils.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,15 @@ def main_process_print(self, *args, sep=' ', end='\n', file=None):
1212
print(self, *args, sep=sep, end=end, file=file)
1313

1414

15-
def chunked_worker_run(map_func, args, results_queue=None):
15+
def chunked_worker_run(map_func, args, results_queue=None, device_id=None):
16+
if device_id is not None:
17+
try:
18+
import torch
19+
torch.cuda.set_device(device_id)
20+
if hasattr(map_func, '__self__') and map_func.__self__ is not None:
21+
map_func.__self__.device = torch.device(f'cuda:{device_id}')
22+
except Exception:
23+
traceback.print_exc()
1624
for a in args:
1725
# noinspection PyBroadException
1826
try:
@@ -25,10 +33,15 @@ def chunked_worker_run(map_func, args, results_queue=None):
2533
results_queue.put(None)
2634

2735

28-
def chunked_multiprocess_run(map_func, args, num_workers, q_max_size=1000):
36+
def chunked_multiprocess_run(map_func, args, num_workers, q_max_size=1000, device_ids=None):
2937
num_jobs = len(args)
3038
if num_jobs < num_workers:
3139
num_workers = num_jobs
40+
if device_ids is not None:
41+
device_ids = device_ids[:num_workers]
42+
43+
if device_ids is not None:
44+
assert len(device_ids) == num_workers
3245

3346
queues = [Manager().Queue(maxsize=q_max_size // num_workers) for _ in range(num_workers)]
3447
if platform.system().lower() != 'windows':
@@ -39,7 +52,9 @@ def chunked_multiprocess_run(map_func, args, num_workers, q_max_size=1000):
3952
workers = []
4053
for i in range(num_workers):
4154
worker = process_creation_func(
42-
target=chunked_worker_run, args=(map_func, args[i::num_workers], queues[i]), daemon=True
55+
target=chunked_worker_run,
56+
args=(map_func, args[i::num_workers], queues[i], None if device_ids is None else device_ids[i]),
57+
daemon=True
4358
)
4459
workers.append(worker)
4560
worker.start()

0 commit comments

Comments (0)