forked from chaosen315/AIwork4translator
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
226 lines (224 loc) · 10.9 KB
/
main.py
File metadata and controls
226 lines (224 loc) · 10.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
from dotenv import load_dotenv
from time import perf_counter
import csv
# NOTE: load_dotenv runs *before* the project-module imports below —
# presumably so environment variables from data/.env are already set when
# modules.config builds global_config at import time. TODO confirm; do not
# reorder these lines.
load_dotenv(dotenv_path="data/.env")
import os, time, json
from modules.config import global_config, setup_runtime_config
from modules.read_tool import read_structured_paragraphs
from modules.csv_process_tool import get_valid_path, validate_csv_file, load_terms_dict, find_matching_terms
from modules.terminology_tool import load_glossary_df, merge_new_terms, save_glossary_df, dict_to_df
from modules.api_tool import LLMService
from modules.write_out_tool import write_to_markdown
from modules.markitdown_tool import markitdown_tool
from modules.count_tool import count_md_words, count_structured_paragraphs
def main():
    """Interactive CLI driver for the document-translation pipeline.

    Flow: load saved preferences -> prompt for LLM provider and source file
    (converting non-Markdown input via markitdown) -> obtain or generate a
    terminology glossary -> translate the document paragraph by paragraph
    with retry-on-failure -> write translated output, an updated glossary,
    a stats row in counting_table.csv, and refreshed preferences.

    Returns None. Exits early (plain ``return``) on user cancellation,
    NER-extraction failure, or repeated API failures.
    """
    # Runtime knobs taken from the project-wide config object.
    PS = global_config.preserve_structure          # bool: keep document structure?
    CHUNK_SIZE = global_config.max_chunk_size      # max characters per translation chunk

    # Best-effort load of last-session preferences; any failure just means
    # "no defaults" rather than an error.
    prefs_path = os.path.join("data", ".prefs.json")
    prefs = {}
    if os.path.exists(prefs_path):
        try:
            with open(prefs_path, "r", encoding="utf-8") as f:
                prefs = json.load(f)
        except Exception:
            prefs = {}

    # --- Input-acquisition loop: repeats until a usable Markdown file path
    # is resolved (then `break` at the bottom of the loop body).
    while True:
        # Provider prompt, pre-filled with the previous session's choice.
        default_provider = prefs.get("last_provider")
        provider_prompt = "需要确认API平台(\"kimi\",\"gpt\",\"deepseek\",\"sillion\",\"gemini\")。"
        if default_provider:
            provider_prompt += f"上一次使用{default_provider},如继续使用请按回车,如有更换请输入:"
        else:
            provider_prompt += ": "
        provider_in = input(provider_prompt).strip()
        if not provider_in and default_provider:
            # Empty input means "reuse last provider".
            provider_in = default_provider
        llm_service = LLMService(provider=provider_in)
        print('')

        # Source-file prompt, also pre-filled from preferences.
        default_input = prefs.get("last_input_md_file")
        file_prompt = "需要原文文件路径。"
        if default_input:
            file_prompt += f"上一次使用{default_input},如继续使用请按回车,如有更换请输入:"
        else:
            file_prompt += ": "
        input_file = input(file_prompt).strip()
        if not input_file and default_input:
            input_file = default_input
        if not input_file:
            print("输入不能为空,请重新输入。")
            continue
        # Smart quote stripping — paths pasted from a file manager often
        # arrive wrapped in quotes; remove leading/trailing " and '.
        input_file = input_file.strip('"\'')
        if not os.path.exists(input_file):
            print(f"错误:文件 {input_file} 不存在,请检查路径后重试。")
            continue

        _, check_extension = os.path.splitext(input_file)
        if check_extension.lower() != '.md':
            # Non-Markdown input: convert it with the markitdown tool, show
            # conversion stats, and ask the user to inspect the result.
            print(f"已输入的文件格式不是Markdown格式,而是{check_extension}。正在使用Markitdown插件进行格式转换……")
            input_md_file = markitdown_tool(input_file)
            if not input_md_file:
                # Conversion failed (unsupported format) — re-prompt.
                print("请更换有效文件格式后重试。有效的文件格式例如:PDF,PowerPoint,Word,Excel,HTML,基于文本的格式(CSV,JSON,XML),EPubs")
                continue
            file_size = os.path.getsize(input_md_file)
            with open(input_md_file, 'r', encoding='utf-8') as f:
                content = f.read()
            # word_count is a character count (len of the string); paragraphs
            # are blank-line-separated non-empty chunks.
            word_count = len(content)
            paragraph_count = len([p for p in content.split('\n\n') if p.strip()])
            print(f"文件转换完成,大小为【{file_size/1024:.2f} KB】,字数为【{word_count}】,段落数量为【{paragraph_count}】")
            print(f"转换后的文件路径:")
            print(input_md_file)
            print("请打开并检查转换后的md文件内容")
            # Confirmation loop: only 'y' proceeds, 'n' exits the program.
            while True:
                confirm = input("是否继续翻译?如果选择n将结束程序。(y/n): ").strip().lower()
                if confirm == 'y':
                    break
                elif confirm == 'n':
                    print("程序已退出。您可以手动修改md文件后重新运行程序。")
                    return
                else:
                    print("输入无效,请输入y或n")
            # Converted files get flat (non-structured) processing.
            PS = False
        else:
            input_md_file = input_file
        break  # usable input_md_file resolved — leave the acquisition loop

    # Derive the output path: <input>_output<ext> alongside the input file.
    input_dir = os.path.dirname(input_md_file)
    input_filename = os.path.basename(input_md_file)
    base_name, extension = os.path.splitext(input_filename)
    output_base_filename = f"{base_name}_output"
    print('')
    output_md_file = os.path.join(input_dir, f"{output_base_filename}{extension}")

    has_glossary = input("您是否已有术语表文件(csv,xlsx)?(y/n): ").strip().lower()
    if has_glossary == 'n':
        # No glossary yet: run the local NER model to produce a blank term
        # table for the user to fill in, then exit. Import is deferred here
        # so the (heavy) model code only loads when this path is taken.
        print(f'开始调取NER模型...')
        print(f'如您没有下载模型,可前往https://huggingface.co/zhayunduo/ner-bert-chinese-base下载。并将模型文件放入./models目录下。')
        from modules.ner_list_tool import EntityRecognizer
        print("正在从文档中提取专业名词生成空白名词表...")
        try:
            recognizer = EntityRecognizer()
            base_csv_path = recognizer.process_file(input_md_file)
        except Exception as e:
            print(f"名词提取失败: {str(e)}")
            print("请检查文档格式后重试")
            return
        print(f"\n空白名词表已生成: {base_csv_path}")
        print("请填写该文件中的译文列,然后重新运行程序使用名词表")
        print("程序将在5秒后退出...")
        time.sleep(5)
        return

    # Glossary exists: load it both as a lookup dict (for matching) and as a
    # DataFrame (for later merging/saving).
    csv_file = get_valid_path("需要术语表文件地址(csv,xlsx)路径: ", validate_csv_file, prefs.get("last_csv_path"))
    start_time = perf_counter()
    terms_dict = load_terms_dict(csv_file)
    glossary_df = load_glossary_df(csv_file)
    aggregated_new_terms = []  # new terms reported by the LLM across all paragraphs
    merge_choice = input("是否将新术语合并到术语表?(y/n): ").strip().lower()
    merge_in_place = (merge_choice == 'y')
    new_terms_df = None
    counter = 1
    print(f"开始翻译文档...")
    # Avoid clobbering previous runs: append _1, _2, ... until a free
    # output filename is found.
    while os.path.exists(output_md_file):
        output_md_file = os.path.join(input_dir, f"{output_base_filename}_{counter}{extension}")
        counter += 1

    total_paragraphs = count_structured_paragraphs(input_md_file, max_chunk_size=CHUNK_SIZE, preserve_structure=PS)
    print(f"文档总段落数为【{total_paragraphs}】")
    print(f"开始调取文档段落...")
    paragraphs = read_structured_paragraphs(input_md_file, max_chunk_size=CHUNK_SIZE, preserve_structure=PS)

    total_token = 0
    current_paragraph = 0
    consecutive_api_failures = 0
    last_api_error = None
    for segment in paragraphs:
        current_paragraph += 1
        print(f"开始翻译段落【{current_paragraph}】/【{total_paragraphs}】")
        # In structured mode each segment is a (text, metadata) pair;
        # in flat mode it is the text alone.
        if PS:
            paragraph, meta_data = segment
        else:
            paragraph = segment
            meta_data = None

        # Build the glossary view for this paragraph: base terms plus any
        # new terms discovered so far (existing entries take precedence).
        union_terms_dict = terms_dict.copy()
        if aggregated_new_terms:
            for nt in aggregated_new_terms:
                k = str(nt.get('term', '')).strip()
                v = str(nt.get('translation', '')).strip()
                if k and k not in union_terms_dict:
                    union_terms_dict[k] = v
        specific_terms_dict = find_matching_terms(paragraph, union_terms_dict)
        prompt = llm_service.create_prompt(paragraph, specific_terms_dict)

        # Retry loop for the API call: retries the same paragraph until it
        # succeeds, or aborts the whole run after MAX_RETRIES consecutive
        # failures (default 3, overridable via the MAX_RETRIES env var).
        while True:
            try:
                response_obj, usage_tokens = llm_service.call_ai_model_api(prompt)
                translation = response_obj.get('translation', '')
                notes = response_obj.get('notes', '')
                new_terms = response_obj.get('newterminology', [])
                aggregated_new_terms.extend(new_terms)
                # Translation and notes are joined with a horizontal rule.
                # NOTE(review): the separator is emitted even when notes is
                # empty — confirm whether that trailing rule is intended.
                response = "\n\n---\n\n".join([translation, notes])
                print(response)
                consecutive_api_failures = 0
                break
            except Exception as e:
                last_api_error = e
                consecutive_api_failures += 1
                print(f"\nAPI调用失败:{str(e)}")
                print(f"连续翻译失败次数: {consecutive_api_failures}/{os.getenv('MAX_RETRIES', 3)}")
                if consecutive_api_failures >= int(os.getenv('MAX_RETRIES', 3)):
                    # NOTE(review): this message hard-codes "3次" while the
                    # actual threshold is MAX_RETRIES — they can disagree.
                    print("\n连续3次翻译失败,开始进行API配置测试...")
                    # Run a diagnostic API test before giving up, so the user
                    # sees whether the problem is config/network related.
                    try:
                        test_results = llm_service.test_api()
                        print("\n=== API测试完成 ===")
                        print(f"测试结果: {test_results}")
                        print(f"最后一次API调用错误: {str(last_api_error)}")
                        print(f"\n请检查API配置或网络连接后重新运行程序。")
                        print("配置文件位置: data/.env")
                    except Exception as test_e:
                        print(f"\nAPI测试过程中发生错误: {str(test_e)}")
                        print(f"最后一次API调用错误: {str(last_api_error)}")
                        print("\n请检查API配置或网络连接后重新运行程序。")
                        print("配置文件位置: data/.env")
                    return
                print("正在重试当前段落...")

        # Only the final successful call's token usage is counted here;
        # tokens spent on failed attempts are not included.
        total_token += usage_tokens
        # Write this paragraph's translation immediately (incremental output).
        if PS:
            write_to_markdown(
                output_md_file,
                (response, meta_data),
                mode='structured'
            )
        else:
            write_to_markdown(
                output_md_file,
                (response, meta_data),
                mode='flat'
            )
        print(f"已处理第{current_paragraph}段内容,输出已保存到:")
        print(output_md_file)

    # --- Post-run reporting and persistence.
    end_time = perf_counter()
    time_taken = end_time-start_time
    print(time.strftime('共耗时:%H时%M分%S秒', time.gmtime(int(time_taken))))
    raw_len = count_md_words(input_md_file)
    processed_len = count_md_words(output_md_file)
    if merge_in_place:
        # Merge the newly discovered terms into the loaded glossary.
        merged_glossary = merge_new_terms(glossary_df, aggregated_new_terms)
        new_glossary_path = save_glossary_df(merged_glossary, csv_file)
    else:
        # Save only the new terms. NOTE(review): this passes the original
        # csv_file path — verify save_glossary_df writes to a derived path
        # rather than overwriting the user's glossary with new terms only.
        new_terms_df = dict_to_df(aggregated_new_terms)
        new_glossary_path = save_glossary_df(new_terms_df, csv_file)
    print("新的术语表已保存:")
    print(new_glossary_path)

    # Append a stats row to counting_table.csv (in the current working
    # directory), writing the header only when the file is new/empty.
    new_row = [str(input_md_file),raw_len,str(output_md_file),processed_len,total_token,time_taken]
    file_exists = os.path.isfile('counting_table.csv') and os.path.getsize('counting_table.csv') > 0
    with open('counting_table.csv','a',newline='',encoding='utf-8') as csvfile:
        writer = csv.writer(csvfile)
        if not file_exists:
            writer.writerow(['Input file','Input len','Output file','Output len','Tokens','Taken time'])
        writer.writerow(new_row)

    # Persist this session's choices as defaults for the next run;
    # failure to save preferences is deliberately non-fatal.
    new_prefs = {
        "last_provider": provider_in,
        "last_input_md_file": input_md_file,
        "last_csv_path": csv_file,
    }
    try:
        os.makedirs("data", exist_ok=True)
        with open(prefs_path, "w", encoding="utf-8") as f:
            json.dump(new_prefs, f, ensure_ascii=False, indent=2)
    except Exception:
        pass
# Script entry point: run the interactive pipeline only when executed
# directly, not when imported.
if __name__ == "__main__":
    main()