33
33
import datasets
34
34
from datasets import DatasetDict , Dataset , IterableDataset , concatenate_datasets
35
35
36
- from huggingface_hub import Repository , create_repo
36
+ from huggingface_hub import HfApi
37
+
37
38
import transformers
38
- from transformers import (
39
- AutoFeatureExtractor ,
40
- AutoTokenizer ,
41
- HfArgumentParser
42
- )
39
+ from transformers import AutoFeatureExtractor , AutoTokenizer , HfArgumentParser
43
40
from transformers .trainer_pt_utils import LengthGroupedSampler
44
41
from transformers .optimization import get_scheduler
45
42
from transformers .utils import send_example_telemetry
46
43
44
+
47
45
from accelerate import Accelerator
48
46
from accelerate .utils import set_seed , AutocastKwargs , InitProcessGroupKwargs , TorchDynamoPlugin
49
47
from accelerate .utils .memory import release_memory
50
48
51
49
from parler_tts import (
52
- ParlerTTSForConditionalGeneration ,
53
50
ParlerTTSConfig ,
51
+ ParlerTTSForConditionalGeneration ,
54
52
build_delay_pattern_mask ,
55
53
)
56
54
@@ -301,9 +299,7 @@ def main():
301
299
# update pad token id and decoder_start_token_id
302
300
config .update (
303
301
{
304
- "pad_token_id" : model_args .pad_token_id
305
- if model_args .pad_token_id is not None
306
- else config .pad_token_id ,
302
+ "pad_token_id" : model_args .pad_token_id if model_args .pad_token_id is not None else config .pad_token_id ,
307
303
"decoder_start_token_id" : model_args .decoder_start_token_id
308
304
if model_args .decoder_start_token_id is not None
309
305
else config .decoder_start_token_id ,
@@ -574,16 +570,18 @@ def compute_metrics(audios, descriptions, prompts, device="cpu"):
574
570
texts = description_tokenizer .batch_decode (input_ids , skip_special_tokens = True )
575
571
prompts = prompt_tokenizer .batch_decode (prompts , skip_special_tokens = True )
576
572
audios = [a .cpu ().numpy () for a in audios ]
577
-
573
+
578
574
clap_score = clap_similarity (model_args .clap_model_name_or_path , texts , audios , device )
579
575
results ["clap" ] = clap_score
580
576
581
- word_error , transcriptions = wer (model_args .asr_model_name_or_path ,
582
- prompts ,
583
- audios ,
584
- device ,
585
- training_args .per_device_eval_batch_size ,
586
- sampling_rate )
577
+ word_error , transcriptions = wer (
578
+ model_args .asr_model_name_or_path ,
579
+ prompts ,
580
+ audios ,
581
+ device ,
582
+ training_args .per_device_eval_batch_size ,
583
+ sampling_rate ,
584
+ )
587
585
results ["wer" ] = word_error
588
586
589
587
return results , texts , prompts , audios , transcriptions
@@ -673,14 +671,13 @@ def compute_metrics(audios, descriptions, prompts, device="cpu"):
673
671
674
672
if accelerator .is_main_process :
675
673
if training_args .push_to_hub :
676
- # Retrieve or infer repo_name
674
+ api = HfApi (token = training_args .hub_token )
675
+
676
+ # Create repo (repo_name from args or inferred)
677
677
repo_name = training_args .hub_model_id
678
678
if repo_name is None :
679
679
repo_name = Path (training_args .output_dir ).absolute ().name
680
- # Create repo and retrieve repo_id
681
- repo_id = create_repo (repo_name , exist_ok = True , token = training_args .hub_token ).repo_id
682
- # Clone repo locally
683
- repo = Repository (training_args .output_dir , clone_from = repo_id , token = training_args .hub_token )
680
+ repo_id = api .create_repo (repo_name , exist_ok = True ).repo_id
684
681
685
682
with open (os .path .join (training_args .output_dir , ".gitignore" ), "w+" ) as gitignore :
686
683
if "wandb" not in gitignore :
@@ -874,17 +871,21 @@ def generate_step(batch):
874
871
accelerator .save_state (output_dir = intermediate_dir , safe_serialization = False )
875
872
accelerator .wait_for_everyone ()
876
873
if accelerator .is_main_process :
877
- rotate_checkpoints (training_args .save_total_limit , output_dir = training_args .output_dir , logger = logger )
874
+ rotate_checkpoints (
875
+ training_args .save_total_limit , output_dir = training_args .output_dir , logger = logger
876
+ )
878
877
879
878
if cur_step == total_train_steps :
880
879
# un-wrap student model for save
881
880
unwrapped_model = accelerator .unwrap_model (model )
882
881
unwrapped_model .save_pretrained (training_args .output_dir )
883
882
884
883
if training_args .push_to_hub :
885
- repo .push_to_hub (
884
+ api .upload_folder (
885
+ repo_id = repo_id ,
886
+ folder_path = training_args .output_dir ,
886
887
commit_message = f"Saving train state of step { cur_step } " ,
887
- blocking = False ,
888
+ run_as_future = True ,
888
889
)
889
890
890
891
if training_args .do_eval and (cur_step % eval_steps == 0 or cur_step == total_train_steps ):
@@ -1014,4 +1015,4 @@ def generate_step(batch):
1014
1015
1015
1016
if __name__ == "__main__" :
1016
1017
set_start_method ("spawn" )
1017
- main ()
1018
+ main ()
0 commit comments