-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcli.py
More file actions
133 lines (105 loc) · 4.03 KB
/
cli.py
File metadata and controls
133 lines (105 loc) · 4.03 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import argparse
from pathlib import Path
import whisper
from pygtrans import TARGET_LANGUAGES
from pygtrans import Translate
from tqdm import tqdm
from whisper.tokenizer import TO_LANGUAGE_CODE
# fmt: off
# Whisper model names shown by a bare `--model` flag.
MODELS = [
    "tiny.en", "tiny",
    "base.en", "base",
    "small.en", "small",
    "medium.en", "medium",
    "large-v1", "large-v2", "large-v3", "large",
    "large-v3-turbo", "turbo",
]

# Keys of whisper's TO_LANGUAGE_CODE table, in table order; shown by a bare
# `--source` flag together with the code each key maps to.
SOURCES = list(TO_LANGUAGE_CODE)

# Extensions (compared lower-cased) picked up when a folder is scanned
# recursively. ".aiff" and ".alac" are currently not included.
AUDIO_SUFFIXS = {
    # video containers
    ".mp4", ".mkv", ".avi", ".mov", ".flv", ".wmv",
    ".webm", ".3gp", ".mpeg", ".mpg", ".m4v",
    # audio formats
    ".mp3", ".wav", ".flac", ".aac", ".ogg", ".m4a", ".wma",
}
# fmt: on
def print_models():
    """Print the whisper model names accepted by ``--model``, one per line."""
    listing = ["Available models:"]
    listing.extend(f" - {name}" for name in MODELS)
    print("\n".join(listing))
def print_tasks():
    """Print the whisper tasks accepted by ``--task``, one per line."""
    for line in ("Available tasks:", " - transcribe", " - translate"):
        print(line)
def print_sources():
    """Print the source languages accepted by ``--source``.

    Each entry is shown as ``<code>(<key>)``, where ``<code>`` is the value
    whisper's ``TO_LANGUAGE_CODE`` maps the key to.
    """
    print("Available source languages:")
    for lang_key in SOURCES:
        short_code = TO_LANGUAGE_CODE[lang_key]
        print(f" - {short_code}({lang_key})")
def print_targets():
    """Print the translation targets accepted by ``--target`` as ``code (name)``."""
    print("Available target languages:")
    for lang_code, lang_name in TARGET_LANGUAGES.items():
        entry = f" - {lang_code} ({lang_name})"
        print(entry)
def parse_args(parser, argv=None):
    """Register this tool's options on *parser* and parse the command line.

    Giving ``--model``, ``--task``, ``--source`` or ``--target`` without a
    value stores a sentinel that ``main`` interprets as "list the available
    choices": ``None`` for model/task/target, and ``False`` for source. Source
    needs a distinct sentinel because its default ``None`` is handed to
    whisper as ``language`` (whisper then detects the language itself).

    Args:
        parser: an ``argparse.ArgumentParser`` to add the options to.
        argv: argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is the original behavior.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser.add_argument("--model", nargs="?", const=None, default="turbo", help="Model to use for whisper")
    parser.add_argument("--task", nargs="?", const=None, default="transcribe", help="Task to use for whisper")
    parser.add_argument("--source", nargs="?", const=False, default=None, help="Source language of the audio")
    parser.add_argument(
        "--target", nargs="?", const=None, default="zh-CN", help="Target language for the translated subtitles"
    )
    parser.add_argument("--proxy", help="Proxy for translate subtitles")
    parser.add_argument("files", nargs="*", help="Audio/video file or folder")
    return parser.parse_args(argv)
def seconds_to_srt_time(seconds):
    """Format an offset in seconds as an SRT timestamp ``HH:MM:SS,mmm``.

    The value is rounded to the nearest millisecond before being split into
    fields. This way float noise (e.g. ``59.999 * 1000`` being slightly below
    ``59999``) cannot drop a millisecond, and carries such as ``0.9996`` ->
    ``00:00:01,000`` come out right. Negative inputs are clamped to zero
    because SRT cannot represent them. Hours are zero-padded to two digits
    but not capped.

    Args:
        seconds: offset from the start of the media, in seconds (int or float).

    Returns:
        The formatted timestamp string.
    """
    # Work in whole milliseconds. Splitting the float directly is error-prone:
    # the old code replaced `seconds` with its integer part before taking the
    # fraction, so the millisecond field was always "000".
    total_ms = max(0, int(round(seconds * 1000)))
    hours, rem = divmod(total_ms, 3_600_000)
    minutes, rem = divmod(rem, 60_000)
    secs, millis = divmod(rem, 1000)
    return f"{hours:02d}:{minutes:02d}:{secs:02d},{millis:03d}"
def generate_srt(results):
    """Render whisper-style results as SubRip (SRT) text.

    Every item of ``results["segments"]`` must provide ``id``, ``start``,
    ``end`` and ``text``; the cue number written is ``id + 1``. Cues are
    separated by a blank line, and non-empty output ends with one newline.
    An empty segment list yields ``""``.
    """
    cues = []
    for seg in results["segments"]:
        number = f'{seg["id"] + 1}'
        timing = f"{seconds_to_srt_time(seg['start'])} --> {seconds_to_srt_time(seg['end'])}"
        # The trailing "" makes each cue end with a newline; joining cues with
        # "\n" below then yields the blank separator line between them.
        cues.append("\n".join((number, timing, seg["text"], "")))
    return "\n".join(cues)
def save_srt(file: Path, lang, srt):
    """Write *srt* (UTF-8) next to *file* as ``<stem>-<lang>.srt``.

    An existing file with that name is overwritten.
    """
    out_name = "{}-{}.srt".format(file.stem, lang)
    out_path = file.with_name(out_name)
    out_path.write_text(srt, encoding="utf8")
def process(model, file, task, source, target, at: Translate):
    """Transcribe one media file and write original and translated subtitles.

    Two files are written next to *file* via ``save_srt``:

    * ``<stem>-<lang>.srt`` holds whisper's own output. With
      ``task="translate"`` whisper emits English, so it is labelled ``en``.
      Otherwise the label is the language whisper reports in
      ``results["language"]``.
    * ``<stem>-<target>.srt`` holds each segment's text after passing it
      through *at*.

    Args:
        model: a loaded whisper model (from ``whisper.load_model``).
        file: path of the audio/video file.
        task: whisper task, ``"transcribe"`` or ``"translate"``.
        source: language passed to whisper; ``None`` lets whisper detect it.
        target: language code used in the translated file's name. It should
            match the target *at* was configured with.
        at: ``pygtrans.Translate`` client used for the second pass.
    """
    results = model.transcribe(str(file), task=task, language=source)
    # whisper's "translate" task always outputs English, but results["language"]
    # still holds the detected *spoken* language. Using it as the label would
    # save an English subtitle as e.g. "-ja.srt".
    source_lang = "en" if task == "translate" else results["language"]
    save_srt(file, source_lang, generate_srt(results))

    segments = results["segments"]
    texts = [seg["text"] for seg in segments]
    # Skip the network round-trip entirely when whisper found no speech.
    translations = at.translate(texts) if texts else []
    # NOTE(review): pygtrans may return a non-list error object when the
    # request fails — confirm and handle explicitly if that happens.
    for seg, translated in zip(segments, translations):
        seg["text"] = translated.translatedText
    save_srt(file, target, generate_srt(results))
def _collect_files(paths):
    """Expand CLI path arguments into a sorted list of media files.

    An argument that is a file is taken as-is, regardless of its extension.
    An existing directory is searched recursively for files whose lower-cased
    suffix is in ``AUDIO_SUFFIXS``. Missing paths are reported and skipped.
    """
    found = set()
    for raw in tqdm(paths, desc="收集文件"):
        path = Path(raw)
        if path.is_file():
            found.add(path)
        elif path.exists():
            # The is_file() check keeps out directories whose names merely
            # end in a media suffix (e.g. "clips.mp4/").
            found.update(p for p in path.rglob("*") if p.suffix.lower() in AUDIO_SUFFIXS and p.is_file())
        else:
            tqdm.write(f"Skipping missing path: {path}")
    # A set has no stable order; sort so runs process files predictably.
    return sorted(found)


def main():
    """Command-line entry point: parse arguments, then transcribe and translate.

    A bare ``--model``, ``--task``, ``--source`` or ``--target`` flag prints the
    matching list of choices and returns. With no file arguments, the help
    text is printed instead.
    """
    parser = argparse.ArgumentParser()
    args = parse_args(parser)
    # Sentinel values from parse_args mean "list the available choices".
    if args.model is None:
        print_models()
        return
    if args.task is None:
        print_tasks()
        return
    if args.source is False:
        print_sources()
        return
    if args.target is None:
        print_targets()
        return
    if not args.files:
        parser.print_help()
        return

    files = _collect_files(args.files)
    if not files:
        # Return before loading the whisper model when there is nothing to do.
        print("No audio/video files found.")
        return

    at = Translate(target=args.target, fmt="text", proxies={"https": args.proxy} if args.proxy else None)
    model = whisper.load_model(args.model)
    with tqdm(total=len(files), desc="转录字幕") as pbar:
        for path in files:
            pbar.set_postfix_str(str(path))
            process(model, path, task=args.task, source=args.source, target=args.target, at=at)
            pbar.update(1)


if __name__ == "__main__":
    main()