|
19 | 19 | import logging
|
20 | 20 | import os
|
21 | 21 | import re
|
22 |
| -import sys |
23 | 22 | import shutil
|
| 23 | +import sys |
24 | 24 | import time
|
25 |
| -from multiprocess import set_start_method |
| 25 | +from dataclasses import dataclass, field |
26 | 26 | from datetime import timedelta
|
27 |
| - |
28 |
| - |
29 |
| -import evaluate |
30 |
| -from tqdm import tqdm |
31 | 27 | from pathlib import Path
|
32 |
| -from dataclasses import dataclass, field |
33 |
| -from typing import Dict, List, Optional, Union, Set |
| 28 | +from typing import Dict, List, Optional, Set, Union |
34 | 29 |
|
35 | 30 | import datasets
|
| 31 | +import evaluate |
36 | 32 | import numpy as np
|
37 | 33 | import torch
|
38 |
| -from torch.utils.data import DataLoader |
39 |
| - |
40 |
| -from datasets import DatasetDict, load_dataset, Dataset, IterableDataset, interleave_datasets, concatenate_datasets |
41 |
| - |
42 |
| -from huggingface_hub import Repository, create_repo |
43 | 34 | import transformers
|
| 35 | +from accelerate import Accelerator |
| 36 | +from accelerate.utils import AutocastKwargs, InitProcessGroupKwargs, TorchDynamoPlugin, set_seed |
| 37 | +from accelerate.utils.memory import release_memory |
| 38 | +from datasets import Dataset, DatasetDict, IterableDataset, concatenate_datasets, interleave_datasets, load_dataset |
| 39 | +from huggingface_hub import HfApi |
| 40 | +from multiprocess import set_start_method |
| 41 | +from torch.utils.data import DataLoader |
| 42 | +from tqdm import tqdm |
44 | 43 | from transformers import (
|
45 | 44 | AutoFeatureExtractor,
|
46 | 45 | AutoModel,
|
47 | 46 | AutoProcessor,
|
48 | 47 | AutoTokenizer,
|
49 | 48 | HfArgumentParser,
|
50 | 49 | Seq2SeqTrainingArguments,
|
| 50 | + pipeline, |
51 | 51 | )
|
52 |
| -from transformers.trainer_pt_utils import LengthGroupedSampler |
53 |
| -from transformers import pipeline |
54 | 52 | from transformers.optimization import get_scheduler
|
| 53 | +from transformers.trainer_pt_utils import LengthGroupedSampler |
55 | 54 | from transformers.utils import send_example_telemetry
|
56 |
| -from transformers import AutoModel |
57 |
| - |
58 |
| - |
59 |
| -from accelerate import Accelerator |
60 |
| -from accelerate.utils import set_seed, AutocastKwargs, InitProcessGroupKwargs, TorchDynamoPlugin |
61 |
| -from accelerate.utils.memory import release_memory |
| 55 | +from wandb import Audio |
62 | 56 |
|
63 | 57 | from parler_tts import (
|
64 |
| - ParlerTTSForConditionalGeneration, |
65 | 58 | ParlerTTSConfig,
|
| 59 | + ParlerTTSForConditionalGeneration, |
66 | 60 | build_delay_pattern_mask,
|
67 | 61 | )
|
68 | 62 |
|
69 |
| -from wandb import Audio |
70 |
| - |
71 | 63 |
|
72 | 64 | logger = logging.getLogger(__name__)
|
73 | 65 |
|
@@ -1415,14 +1407,13 @@ def compute_metrics(audios, descriptions, prompts, device="cpu"):
|
1415 | 1407 |
|
1416 | 1408 | if accelerator.is_main_process:
|
1417 | 1409 | if training_args.push_to_hub:
|
1418 |
| - # Retrieve of infer repo_name |
| 1410 | + api = HfApi(token=training_args.hub_token) |
| 1411 | + |
| 1412 | + # Create repo (repo_name from args or inferred) |
1419 | 1413 | repo_name = training_args.hub_model_id
|
1420 | 1414 | if repo_name is None:
|
1421 | 1415 | repo_name = Path(training_args.output_dir).absolute().name
|
1422 |
| - # Create repo and retrieve repo_id |
1423 |
| - repo_id = create_repo(repo_name, exist_ok=True, token=training_args.hub_token).repo_id |
1424 |
| - # Clone repo locally |
1425 |
| - repo = Repository(training_args.output_dir, clone_from=repo_id, token=training_args.hub_token) |
| 1416 | + repo_id = api.create_repo(repo_name, exist_ok=True).repo_id |
1426 | 1417 |
|
1427 | 1418 | with open(os.path.join(training_args.output_dir, ".gitignore"), "w+") as gitignore:
|
1428 | 1419 | if "wandb" not in gitignore:
|
@@ -1624,9 +1615,11 @@ def generate_step(batch):
|
1624 | 1615 | unwrapped_model.save_pretrained(training_args.output_dir)
|
1625 | 1616 |
|
1626 | 1617 | if training_args.push_to_hub:
|
1627 |
| - repo.push_to_hub( |
| 1618 | + api.upload_folder( |
| 1619 | + repo_id=repo_id, |
| 1620 | + folder_path=training_args.output_dir, |
1628 | 1621 | commit_message=f"Saving train state of step {cur_step}",
|
1629 |
| - blocking=False, |
| 1622 | + run_as_future=True, |
1630 | 1623 | )
|
1631 | 1624 |
|
1632 | 1625 | if training_args.do_eval and (cur_step % eval_steps == 0 or cur_step == total_train_steps):
|
|
0 commit comments