diff --git a/bins/svc/preprocess.py b/bins/svc/preprocess.py index 453b5001..554a17bf 100644 --- a/bins/svc/preprocess.py +++ b/bins/svc/preprocess.py @@ -37,11 +37,8 @@ def extract_acoustic_features(dataset, output_path, cfg, n_workers=1): with open(dataset_file, "r") as f: metadata.extend(json.load(f)) - # acoustic_extractor.extract_utt_acoustic_features_parallel( - # metadata, dataset_output, cfg, n_workers=n_workers - # ) - acoustic_extractor.extract_utt_acoustic_features_serial( - metadata, dataset_output, cfg + acoustic_extractor.extract_utt_acoustic_features_parallel( + metadata, dataset_output, cfg, n_workers=n_workers ) diff --git a/processors/acoustic_extractor.py b/processors/acoustic_extractor.py index 9c4d9be7..bd645899 100644 --- a/processors/acoustic_extractor.py +++ b/processors/acoustic_extractor.py @@ -2,8 +2,9 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. - import os +from functools import partial + import torch import numpy as np @@ -23,6 +24,7 @@ extract_linear_features, extract_mel_features_tts, ) +from concurrent.futures import as_completed, ProcessPoolExecutor ZERO = 1e-12 @@ -39,15 +41,24 @@ def extract_utt_acoustic_features_parallel(metadata, dataset_output, cfg, n_work Returns: list: acoustic features """ - for utt in tqdm(metadata): - if cfg.task_type == "tts": - extract_utt_acoustic_features_tts(dataset_output, cfg, utt) - if cfg.task_type == "svc": - extract_utt_acoustic_features_svc(dataset_output, cfg, utt) - if cfg.task_type == "vocoder": - extract_utt_acoustic_features_vocoder(dataset_output, cfg, utt) - if cfg.task_type == "tta": - extract_utt_acoustic_features_tta(dataset_output, cfg, utt) + extractor = None + if cfg.task_type == "tts": + extractor = partial(extract_utt_acoustic_features_tts, dataset_output, cfg) + if cfg.task_type == "svc": + extractor = partial(extract_utt_acoustic_features_svc, dataset_output, cfg) + if cfg.task_type == "vocoder": + extractor = partial(extract_utt_acoustic_features_vocoder, dataset_output, cfg) + if cfg.task_type == "tta": + extractor = partial(extract_utt_acoustic_features_tta, dataset_output, cfg) + + with ProcessPoolExecutor(max_workers=n_workers) as pool: + future_to_utt = {pool.submit(extractor, utt): utt for utt in metadata} + for future in tqdm(as_completed(future_to_utt), total=len(future_to_utt)): + utt = future_to_utt[future] + try: + future.result() + except Exception as exc: + print("%r generated an exception: %s" % (utt, exc)) def avg_phone_feature(feature, duration, interpolation=False):