Commit bdb0363

Merge pull request #48 from ylacombe/pr/Wauplin/18
Pr/wauplin/18
2 parents: b2b749d + 3f5fd26 · commit bdb0363

7 files changed: +48 additions, -37 deletions

parler_tts/configuration_parler_tts.py (1 addition, 1 deletion)

@@ -40,7 +40,7 @@ class ParlerTTSDecoderConfig(PretrainedConfig):
     Args:
         vocab_size (`int`, *optional*, defaults to 2049):
             Vocabulary size of the ParlerTTSDecoder model. Defines the number of different tokens that can be
-            represented by the `inputs_ids` passed when calling [`ParlerTTSDecoder`].
+            represented by the `inputs_ids` passed when calling [`ParlerTTSDecoder`].
         hidden_size (`int`, *optional*, defaults to 1024):
            Dimensionality of the layers and the pooler layer.
         num_hidden_layers (`int`, *optional*, defaults to 24):
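
For orientation only, a minimal sketch of constructing this config with the defaults documented above; the import path is assumed from this repository's package layout, and all other arguments are left at their defaults:

```python
# Minimal sketch, not part of this diff: build the decoder config documented above.
# Import path assumed from this repo's package; values match the docstring defaults.
from parler_tts import ParlerTTSDecoderConfig

config = ParlerTTSDecoderConfig(
    vocab_size=2049,       # number of distinct audio tokens representable by `inputs_ids`
    hidden_size=1024,      # dimensionality of the layers and the pooler layer
    num_hidden_layers=24,  # number of decoder layers
)
print(config.hidden_size)
```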

parler_tts/modeling_parler_tts.py (6 additions, 3 deletions)

@@ -1522,7 +1522,7 @@ def generate(
         output_ids = outputs.sequences
     else:
         output_ids = outputs
-
+
     # apply the pattern mask to the final ids
     output_ids = self.apply_delay_pattern_mask(output_ids, model_kwargs["delay_pattern_mask"])

@@ -2460,7 +2460,10 @@ def generate(
     if "encoder_outputs" not in model_kwargs:
         # encoder_outputs are created and added to `model_kwargs`
         model_kwargs = self._prepare_text_encoder_kwargs_for_generation(
-            inputs_tensor, model_kwargs, model_input_name, generation_config,
+            inputs_tensor,
+            model_kwargs,
+            model_input_name,
+            generation_config,
         )

     if "prompt_hidden_states" not in model_kwargs and "prompt_input_ids" in model_kwargs:

@@ -2667,4 +2670,4 @@ def generate(
         outputs.sequences = output_values
         return outputs
     else:
-        return output_values
+        return output_values

training/arguments.py (7 additions, 3 deletions)

@@ -3,6 +3,7 @@

 from transformers import Seq2SeqTrainingArguments

+
 @dataclass
 class ModelArguments:
     """

@@ -67,15 +68,18 @@ class ModelArguments:
     )
     asr_model_name_or_path: str = field(
         default="distil-whisper/distil-large-v2",
-        metadata={"help": "Used to compute WER during evaluation. Path to pretrained model or model identifier from huggingface.co/models"}
+        metadata={
+            "help": "Used to compute WER during evaluation. Path to pretrained model or model identifier from huggingface.co/models"
+        },
     )
     clap_model_name_or_path: str = field(
         default="laion/larger_clap_music_and_speech",
-        metadata={"help": "Used to compute audio similarity during evaluation. Path to pretrained model or model identifier from huggingface.co/models"}
+        metadata={
+            "help": "Used to compute audio similarity during evaluation. Path to pretrained model or model identifier from huggingface.co/models"
+        },
     )


-
 @dataclass
 class DataTrainingArguments:
     """

training/data.py (2 additions, 1 deletion)

@@ -11,6 +11,7 @@

 from accelerate import Accelerator

+
 @dataclass
 class DataCollatorEncodecWithPadding:
     """

@@ -301,4 +302,4 @@ def load_multiple_datasets(
     with accelerator.main_process_first():
         interleaved_dataset = concatenate_datasets(all_datasets)

-    return interleaved_dataset
+    return interleaved_dataset

training/eval.py (3 additions, 2 deletions)

@@ -1,4 +1,4 @@
-import torch
+import torch
 import evaluate
 from transformers import AutoModel, AutoProcessor, pipeline

@@ -20,6 +20,7 @@ def clap_similarity(clap_model_name_or_path, texts, audios, device):
     clap_inputs.to("cpu")
     return cosine_sim.mean().to("cpu")

+
 def wer(asr_model_name_or_path, prompts, audios, device, per_device_eval_batch_size, sampling_rate):
     metric = evaluate.load("wer")
     asr_pipeline = pipeline(model=asr_model_name_or_path, device=device)

@@ -32,4 +33,4 @@ def wer(asr_model_name_or_path, prompts, audios, device, per_device_eval_batch_s
         predictions=[t["text"].lower() for t in transcriptions], references=[t.lower() for t in prompts]
     )

-    return word_error, [t["text"] for t in transcriptions]
+    return word_error, [t["text"] for t in transcriptions]
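
As a usage note, the `wer` helper reformatted above is called from `compute_metrics` in the training script with the argument order shown in this diff; a minimal standalone sketch with placeholder audio (the `training.eval` import path is assumed, and the silent clips are purely illustrative):

```python
import numpy as np

from training.eval import wer  # import path assumed from this repo's layout

prompts = ["hello world", "parler tts"]
# Placeholder inputs: one second of silence per prompt at 16 kHz.
audios = [np.zeros(16_000, dtype=np.float32) for _ in prompts]

word_error, transcriptions = wer(
    asr_model_name_or_path="distil-whisper/distil-large-v2",
    prompts=prompts,
    audios=audios,
    device="cpu",
    per_device_eval_batch_size=2,
    sampling_rate=16_000,
)
print(word_error, transcriptions)
```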

training/run_parler_tts_training.py (27 additions, 26 deletions)

@@ -33,24 +33,22 @@
 import datasets
 from datasets import DatasetDict, Dataset, IterableDataset, concatenate_datasets

-from huggingface_hub import Repository, create_repo
+from huggingface_hub import HfApi
+
 import transformers
-from transformers import (
-    AutoFeatureExtractor,
-    AutoTokenizer,
-    HfArgumentParser
-)
+from transformers import AutoFeatureExtractor, AutoTokenizer, HfArgumentParser
 from transformers.trainer_pt_utils import LengthGroupedSampler
 from transformers.optimization import get_scheduler
 from transformers.utils import send_example_telemetry

+
 from accelerate import Accelerator
 from accelerate.utils import set_seed, AutocastKwargs, InitProcessGroupKwargs, TorchDynamoPlugin
 from accelerate.utils.memory import release_memory

 from parler_tts import (
-    ParlerTTSForConditionalGeneration,
     ParlerTTSConfig,
+    ParlerTTSForConditionalGeneration,
     build_delay_pattern_mask,
 )

@@ -301,9 +299,7 @@ def main():
     # update pad token id and decoder_start_token_id
     config.update(
         {
-            "pad_token_id": model_args.pad_token_id
-            if model_args.pad_token_id is not None
-            else config.pad_token_id,
+            "pad_token_id": model_args.pad_token_id if model_args.pad_token_id is not None else config.pad_token_id,
             "decoder_start_token_id": model_args.decoder_start_token_id
             if model_args.decoder_start_token_id is not None
             else config.decoder_start_token_id,

@@ -574,16 +570,18 @@ def compute_metrics(audios, descriptions, prompts, device="cpu"):
     texts = description_tokenizer.batch_decode(input_ids, skip_special_tokens=True)
     prompts = prompt_tokenizer.batch_decode(prompts, skip_special_tokens=True)
     audios = [a.cpu().numpy() for a in audios]
-
+
     clap_score = clap_similarity(model_args.clap_model_name_or_path, texts, audios, device)
     results["clap"] = clap_score

-    word_error, transcriptions = wer(model_args.asr_model_name_or_path,
-                                     prompts,
-                                     audios,
-                                     device,
-                                     training_args.per_device_eval_batch_size,
-                                     sampling_rate)
+    word_error, transcriptions = wer(
+        model_args.asr_model_name_or_path,
+        prompts,
+        audios,
+        device,
+        training_args.per_device_eval_batch_size,
+        sampling_rate,
+    )
     results["wer"] = word_error

     return results, texts, prompts, audios, transcriptions

@@ -673,14 +671,13 @@ def compute_metrics(audios, descriptions, prompts, device="cpu"):

     if accelerator.is_main_process:
         if training_args.push_to_hub:
-            # Retrieve of infer repo_name
+            api = HfApi(token=training_args.hub_token)
+
+            # Create repo (repo_name from args or inferred)
             repo_name = training_args.hub_model_id
             if repo_name is None:
                 repo_name = Path(training_args.output_dir).absolute().name
-            # Create repo and retrieve repo_id
-            repo_id = create_repo(repo_name, exist_ok=True, token=training_args.hub_token).repo_id
-            # Clone repo locally
-            repo = Repository(training_args.output_dir, clone_from=repo_id, token=training_args.hub_token)
+            repo_id = api.create_repo(repo_name, exist_ok=True).repo_id

         with open(os.path.join(training_args.output_dir, ".gitignore"), "w+") as gitignore:
             if "wandb" not in gitignore:

@@ -874,17 +871,21 @@ def generate_step(batch):
             accelerator.save_state(output_dir=intermediate_dir, safe_serialization=False)
             accelerator.wait_for_everyone()
             if accelerator.is_main_process:
-                rotate_checkpoints(training_args.save_total_limit, output_dir=training_args.output_dir, logger=logger)
+                rotate_checkpoints(
+                    training_args.save_total_limit, output_dir=training_args.output_dir, logger=logger
+                )

                 if cur_step == total_train_steps:
                     # un-wrap student model for save
                     unwrapped_model = accelerator.unwrap_model(model)
                     unwrapped_model.save_pretrained(training_args.output_dir)

                 if training_args.push_to_hub:
-                    repo.push_to_hub(
+                    api.upload_folder(
+                        repo_id=repo_id,
+                        folder_path=training_args.output_dir,
                         commit_message=f"Saving train state of step {cur_step}",
-                        blocking=False,
+                        run_as_future=True,
                     )

             if training_args.do_eval and (cur_step % eval_steps == 0 or cur_step == total_train_steps):

@@ -1014,4 +1015,4 @@ def generate_step(batch):

 if __name__ == "__main__":
     set_start_method("spawn")
-    main()
+    main()
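
The main functional change in this file swaps the deprecated `Repository` workflow for `HfApi`; a minimal sketch of the new flow mirroring the hunks above (token, repo name, and folder path are placeholders):

```python
from huggingface_hub import HfApi

api = HfApi(token="hf_xxx")  # placeholder token

# Create (or reuse) the repo and keep its fully qualified id, as in the diff above.
repo_id = api.create_repo("my-parler-tts-run", exist_ok=True).repo_id

# Upload the training output directory. run_as_future=True returns a Future and
# uploads in a background thread, replacing the old repo.push_to_hub(blocking=False) call.
future = api.upload_folder(
    repo_id=repo_id,
    folder_path="./output_dir",
    commit_message="Saving train state of step 1000",
    run_as_future=True,
)
future.result()  # optional: block until the upload finishes
```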

training/utils.py (2 additions, 1 deletion)

@@ -8,6 +8,7 @@
 import torch
 from wandb import Audio

+
 def list_field(default=None, metadata=None):
     return field(default_factory=lambda: default, metadata=metadata)

@@ -121,4 +122,4 @@ def log_pred(
             ]
         },
         step=step,
-    )
+    )
