 import re
 import sys
 import time
-from dataclasses import dataclass, field
 from datetime import timedelta
 
 from tqdm import tqdm
@@ -38,11 +37,7 @@
 from multiprocess import set_start_method
 from torch.utils.data import DataLoader
 from tqdm import tqdm
-from transformers import (
-    AutoFeatureExtractor,
-    AutoTokenizer,
-    HfArgumentParser
-)
+from transformers import AutoFeatureExtractor, AutoTokenizer, HfArgumentParser
 from transformers.trainer_pt_utils import LengthGroupedSampler
 from transformers.optimization import get_scheduler
 from transformers.trainer_pt_utils import LengthGroupedSampler
@@ -306,9 +301,7 @@ def main():
     # update pad token id and decoder_start_token_id
     config.update(
         {
-            "pad_token_id": model_args.pad_token_id
-            if model_args.pad_token_id is not None
-            else config.pad_token_id,
+            "pad_token_id": model_args.pad_token_id if model_args.pad_token_id is not None else config.pad_token_id,
             "decoder_start_token_id": model_args.decoder_start_token_id
             if model_args.decoder_start_token_id is not None
             else config.decoder_start_token_id,
@@ -579,16 +572,18 @@ def compute_metrics(audios, descriptions, prompts, device="cpu"):
         texts = description_tokenizer.batch_decode(input_ids, skip_special_tokens=True)
         prompts = prompt_tokenizer.batch_decode(prompts, skip_special_tokens=True)
         audios = [a.cpu().numpy() for a in audios]
-
+
         clap_score = clap_similarity(model_args.clap_model_name_or_path, texts, audios, device)
         results["clap"] = clap_score
 
-        word_error, transcriptions = wer(model_args.asr_model_name_or_path,
-                                         prompts,
-                                         audios,
-                                         device,
-                                         training_args.per_device_eval_batch_size,
-                                         sampling_rate)
+        word_error, transcriptions = wer(
+            model_args.asr_model_name_or_path,
+            prompts,
+            audios,
+            device,
+            training_args.per_device_eval_batch_size,
+            sampling_rate,
+        )
         results["wer"] = word_error
 
         return results, texts, prompts, audios, transcriptions
@@ -878,7 +873,9 @@ def generate_step(batch):
                     accelerator.save_state(output_dir=intermediate_dir, safe_serialization=False)
                     accelerator.wait_for_everyone()
                     if accelerator.is_main_process:
-                        rotate_checkpoints(training_args.save_total_limit, output_dir=training_args.output_dir, logger=logger)
+                        rotate_checkpoints(
+                            training_args.save_total_limit, output_dir=training_args.output_dir, logger=logger
+                        )
 
                     if cur_step == total_train_steps:
                         # un-wrap student model for save
@@ -1020,4 +1017,4 @@ def generate_step(batch):
 
 if __name__ == "__main__":
     set_start_method("spawn")
-    main()
+    main()