Skip to content

Commit a04d685

Browse files
committed
Don't use deprecated Repository anymore
1 parent 10016fb commit a04d685

File tree

1 file changed

+24
-31
lines changed

1 file changed

+24
-31
lines changed

training/run_parler_tts_training.py

Lines changed: 24 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -19,55 +19,47 @@
1919
import logging
2020
import os
2121
import re
22-
import sys
2322
import shutil
23+
import sys
2424
import time
25-
from multiprocess import set_start_method
25+
from dataclasses import dataclass, field
2626
from datetime import timedelta
27-
28-
29-
import evaluate
30-
from tqdm import tqdm
3127
from pathlib import Path
32-
from dataclasses import dataclass, field
33-
from typing import Dict, List, Optional, Union, Set
28+
from typing import Dict, List, Optional, Set, Union
3429

3530
import datasets
31+
import evaluate
3632
import numpy as np
3733
import torch
38-
from torch.utils.data import DataLoader
39-
40-
from datasets import DatasetDict, load_dataset, Dataset, IterableDataset, interleave_datasets, concatenate_datasets
41-
42-
from huggingface_hub import Repository, create_repo
4334
import transformers
35+
from accelerate import Accelerator
36+
from accelerate.utils import AutocastKwargs, InitProcessGroupKwargs, TorchDynamoPlugin, set_seed
37+
from accelerate.utils.memory import release_memory
38+
from datasets import Dataset, DatasetDict, IterableDataset, concatenate_datasets, interleave_datasets, load_dataset
39+
from huggingface_hub import HfApi
40+
from multiprocess import set_start_method
41+
from torch.utils.data import DataLoader
42+
from tqdm import tqdm
4443
from transformers import (
4544
AutoFeatureExtractor,
4645
AutoModel,
4746
AutoProcessor,
4847
AutoTokenizer,
4948
HfArgumentParser,
5049
Seq2SeqTrainingArguments,
50+
pipeline,
5151
)
52-
from transformers.trainer_pt_utils import LengthGroupedSampler
53-
from transformers import pipeline
5452
from transformers.optimization import get_scheduler
53+
from transformers.trainer_pt_utils import LengthGroupedSampler
5554
from transformers.utils import send_example_telemetry
56-
from transformers import AutoModel
57-
58-
59-
from accelerate import Accelerator
60-
from accelerate.utils import set_seed, AutocastKwargs, InitProcessGroupKwargs, TorchDynamoPlugin
61-
from accelerate.utils.memory import release_memory
55+
from wandb import Audio
6256

6357
from parler_tts import (
64-
ParlerTTSForConditionalGeneration,
6558
ParlerTTSConfig,
59+
ParlerTTSForConditionalGeneration,
6660
build_delay_pattern_mask,
6761
)
6862

69-
from wandb import Audio
70-
7163

7264
logger = logging.getLogger(__name__)
7365

@@ -1415,14 +1407,13 @@ def compute_metrics(audios, descriptions, prompts, device="cpu"):
14151407

14161408
if accelerator.is_main_process:
14171409
if training_args.push_to_hub:
1418-
# Retrieve or infer repo_name
1410+
api = HfApi(token=training_args.hub_token)
1411+
1412+
# Create repo (repo_name from args or inferred)
14191413
repo_name = training_args.hub_model_id
14201414
if repo_name is None:
14211415
repo_name = Path(training_args.output_dir).absolute().name
1422-
# Create repo and retrieve repo_id
1423-
repo_id = create_repo(repo_name, exist_ok=True, token=training_args.hub_token).repo_id
1424-
# Clone repo locally
1425-
repo = Repository(training_args.output_dir, clone_from=repo_id, token=training_args.hub_token)
1416+
repo_id = api.create_repo(repo_name, exist_ok=True).repo_id
14261417

14271418
with open(os.path.join(training_args.output_dir, ".gitignore"), "w+") as gitignore:
14281419
if "wandb" not in gitignore:
@@ -1624,9 +1615,11 @@ def generate_step(batch):
16241615
unwrapped_model.save_pretrained(training_args.output_dir)
16251616

16261617
if training_args.push_to_hub:
1627-
repo.push_to_hub(
1618+
api.upload_folder(
1619+
repo_id=repo_id,
1620+
folder_path=training_args.output_dir,
16281621
commit_message=f"Saving train state of step {cur_step}",
1629-
blocking=False,
1622+
run_as_future=True,
16301623
)
16311624

16321625
if training_args.do_eval and (cur_step % eval_steps == 0 or cur_step == total_train_steps):

0 commit comments

Comments
 (0)