Skip to content

Commit aa4cbf2

Browse files
committed
make style
1 parent 9271958 commit aa4cbf2

File tree

7 files changed

+36
-29
lines changed

7 files changed

+36
-29
lines changed

parler_tts/configuration_parler_tts.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ class ParlerTTSDecoderConfig(PretrainedConfig):
4040
Args:
4141
vocab_size (`int`, *optional*, defaults to 2049):
4242
Vocabulary size of the ParlerTTSDecoder model. Defines the number of different tokens that can be
43-
represented by the `inputs_ids` passed when calling [`ParlerTTSDecoder`].
43+
represented by the `inputs_ids` passed when calling [`ParlerTTSDecoder`].
4444
hidden_size (`int`, *optional*, defaults to 1024):
4545
Dimensionality of the layers and the pooler layer.
4646
num_hidden_layers (`int`, *optional*, defaults to 24):

parler_tts/modeling_parler_tts.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1522,7 +1522,7 @@ def generate(
15221522
output_ids = outputs.sequences
15231523
else:
15241524
output_ids = outputs
1525-
1525+
15261526
# apply the pattern mask to the final ids
15271527
output_ids = self.apply_delay_pattern_mask(output_ids, model_kwargs["delay_pattern_mask"])
15281528

@@ -2460,7 +2460,10 @@ def generate(
24602460
if "encoder_outputs" not in model_kwargs:
24612461
# encoder_outputs are created and added to `model_kwargs`
24622462
model_kwargs = self._prepare_text_encoder_kwargs_for_generation(
2463-
inputs_tensor, model_kwargs, model_input_name, generation_config,
2463+
inputs_tensor,
2464+
model_kwargs,
2465+
model_input_name,
2466+
generation_config,
24642467
)
24652468

24662469
if "prompt_hidden_states" not in model_kwargs and "prompt_input_ids" in model_kwargs:
@@ -2667,4 +2670,4 @@ def generate(
26672670
outputs.sequences = output_values
26682671
return outputs
26692672
else:
2670-
return output_values
2673+
return output_values

training/arguments.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from transformers import Seq2SeqTrainingArguments
55

6+
67
@dataclass
78
class ModelArguments:
89
"""
@@ -67,15 +68,18 @@ class ModelArguments:
6768
)
6869
asr_model_name_or_path: str = field(
6970
default="distil-whisper/distil-large-v2",
70-
metadata={"help": "Used to compute WER during evaluation. Path to pretrained model or model identifier from huggingface.co/models"}
71+
metadata={
72+
"help": "Used to compute WER during evaluation. Path to pretrained model or model identifier from huggingface.co/models"
73+
},
7174
)
7275
clap_model_name_or_path: str = field(
7376
default="laion/larger_clap_music_and_speech",
74-
metadata={"help": "Used to compute audio similarity during evaluation. Path to pretrained model or model identifier from huggingface.co/models"}
77+
metadata={
78+
"help": "Used to compute audio similarity during evaluation. Path to pretrained model or model identifier from huggingface.co/models"
79+
},
7580
)
7681

7782

78-
7983
@dataclass
8084
class DataTrainingArguments:
8185
"""

training/data.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from accelerate import Accelerator
1313

14+
1415
@dataclass
1516
class DataCollatorEncodecWithPadding:
1617
"""
@@ -301,4 +302,4 @@ def load_multiple_datasets(
301302
with accelerator.main_process_first():
302303
interleaved_dataset = concatenate_datasets(all_datasets)
303304

304-
return interleaved_dataset
305+
return interleaved_dataset

training/eval.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import torch
1+
import torch
22
import evaluate
33
from transformers import AutoModel, AutoProcessor, pipeline
44

@@ -20,6 +20,7 @@ def clap_similarity(clap_model_name_or_path, texts, audios, device):
2020
clap_inputs.to("cpu")
2121
return cosine_sim.mean().to("cpu")
2222

23+
2324
def wer(asr_model_name_or_path, prompts, audios, device, per_device_eval_batch_size, sampling_rate):
2425
metric = evaluate.load("wer")
2526
asr_pipeline = pipeline(model=asr_model_name_or_path, device=device)
@@ -32,4 +33,4 @@ def wer(asr_model_name_or_path, prompts, audios, device, per_device_eval_batch_s
3233
predictions=[t["text"].lower() for t in transcriptions], references=[t.lower() for t in prompts]
3334
)
3435

35-
return word_error, [t["text"] for t in transcriptions]
36+
return word_error, [t["text"] for t in transcriptions]

training/run_parler_tts_training.py

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
import re
2222
import sys
2323
import time
24-
from dataclasses import dataclass, field
2524
from datetime import timedelta
2625

2726
from tqdm import tqdm
@@ -38,11 +37,7 @@
3837
from multiprocess import set_start_method
3938
from torch.utils.data import DataLoader
4039
from tqdm import tqdm
41-
from transformers import (
42-
AutoFeatureExtractor,
43-
AutoTokenizer,
44-
HfArgumentParser
45-
)
40+
from transformers import AutoFeatureExtractor, AutoTokenizer, HfArgumentParser
4641
from transformers.trainer_pt_utils import LengthGroupedSampler
4742
from transformers.optimization import get_scheduler
4843
from transformers.trainer_pt_utils import LengthGroupedSampler
@@ -306,9 +301,7 @@ def main():
306301
# update pad token id and decoder_start_token_id
307302
config.update(
308303
{
309-
"pad_token_id": model_args.pad_token_id
310-
if model_args.pad_token_id is not None
311-
else config.pad_token_id,
304+
"pad_token_id": model_args.pad_token_id if model_args.pad_token_id is not None else config.pad_token_id,
312305
"decoder_start_token_id": model_args.decoder_start_token_id
313306
if model_args.decoder_start_token_id is not None
314307
else config.decoder_start_token_id,
@@ -579,16 +572,18 @@ def compute_metrics(audios, descriptions, prompts, device="cpu"):
579572
texts = description_tokenizer.batch_decode(input_ids, skip_special_tokens=True)
580573
prompts = prompt_tokenizer.batch_decode(prompts, skip_special_tokens=True)
581574
audios = [a.cpu().numpy() for a in audios]
582-
575+
583576
clap_score = clap_similarity(model_args.clap_model_name_or_path, texts, audios, device)
584577
results["clap"] = clap_score
585578

586-
word_error, transcriptions = wer(model_args.asr_model_name_or_path,
587-
prompts,
588-
audios,
589-
device,
590-
training_args.per_device_eval_batch_size,
591-
sampling_rate)
579+
word_error, transcriptions = wer(
580+
model_args.asr_model_name_or_path,
581+
prompts,
582+
audios,
583+
device,
584+
training_args.per_device_eval_batch_size,
585+
sampling_rate,
586+
)
592587
results["wer"] = word_error
593588

594589
return results, texts, prompts, audios, transcriptions
@@ -878,7 +873,9 @@ def generate_step(batch):
878873
accelerator.save_state(output_dir=intermediate_dir, safe_serialization=False)
879874
accelerator.wait_for_everyone()
880875
if accelerator.is_main_process:
881-
rotate_checkpoints(training_args.save_total_limit, output_dir=training_args.output_dir, logger=logger)
876+
rotate_checkpoints(
877+
training_args.save_total_limit, output_dir=training_args.output_dir, logger=logger
878+
)
882879

883880
if cur_step == total_train_steps:
884881
# un-wrap student model for save
@@ -1020,4 +1017,4 @@ def generate_step(batch):
10201017

10211018
if __name__ == "__main__":
10221019
set_start_method("spawn")
1023-
main()
1020+
main()

training/utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import torch
99
from wandb import Audio
1010

11+
1112
def list_field(default=None, metadata=None):
1213
return field(default_factory=lambda: default, metadata=metadata)
1314

@@ -121,4 +122,4 @@ def log_pred(
121122
]
122123
},
123124
step=step,
124-
)
125+
)

0 commit comments

Comments
 (0)