-
Notifications
You must be signed in to change notification settings - Fork 12
Expand file tree
/
Copy pathapp.py
More file actions
794 lines (672 loc) · 29.7 KB
/
app.py
File metadata and controls
794 lines (672 loc) · 29.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
import base64  # image encoding for base64 data URLs
import concurrent.futures  # concurrent LLM requests
import json
import logging
import os
import random
import re  # cleaning illegal characters out of file names
import secrets
import socket
import string
import subprocess
import sys
import threading
import time
import traceback
import webbrowser
from collections import Counter  # computing the most common value (mode)
from datetime import datetime

import fitz  # PyMuPDF, used to render PDF pages to images
from flask import Flask, render_template, jsonify, request, send_file
from openai import OpenAI
from pypinyin import lazy_pinyin
# Log through gunicorn's error logger so messages land in gunicorn's log when
# the app is served by gunicorn. NOTE(review): nothing in this file attaches
# handlers or sets a level on it, so when started with `python app.py`
# INFO-level messages are likely dropped unless logging is configured
# elsewhere — confirm.
logger = logging.getLogger('gunicorn.error')
# The Flask application; every route below registers on it.
app = Flask(__name__)
# Used by the static icon routes below.
from flask import send_from_directory
def _send_static_asset(asset_name):
    """Serve ``asset_name`` from the application's static folder."""
    return send_from_directory(app.static_folder, asset_name)


@app.route('/favicon.ico')
def favicon():
    """Serve the favicon from the static folder."""
    return _send_static_asset('favicon.ico')


@app.route('/apple-touch-icon-precomposed.png')
def apple_icon_precomposed():
    """Serve the precomposed Apple touch icon from the static folder."""
    return _send_static_asset('apple-touch-icon-precomposed.png')


@app.route('/apple-touch-icon.png')
def apple_icon():
    """Serve the Apple touch icon from the static folder."""
    return _send_static_asset('apple-touch-icon.png')
def convert_to_pinyin(text):
    """Transliterate Chinese characters in ``text`` to toneless pinyin.

    The syllables returned by ``lazy_pinyin`` are concatenated without a
    separator. Handling of non-Chinese characters is whatever ``lazy_pinyin``
    does by default (presumably passed through unchanged — confirm for the
    installed pypinyin version).
    """
    syllables = lazy_pinyin(text)
    return "".join(syllables)
# Time limit, in seconds, for a single pipeline script run by run_script.
SCRIPT_TIMEOUT = 300
# Sub-directories created under data/<session_id>/ for every new session.
DATA_FOLDERS = [
    'input_pdf',
    'mark/input_image',
    'raw_content',
    'output_pdf',
    'mark/image_metadata',
    'merged_content',
]
# Ordered (script name, user-facing step description) pairs for the Qwen
# pipeline; run_script executes mainprogress/<script name>.py for each step.
QWEN_SCRIPT_SEQUENCE = [
    ('pdf_to_image', 'PDF转换为图像(下一步可能需要一分钟或更长,请耐心等待)'),
    ('qwen_vl_extract', 'OCR识别'),
    ('content_preprocessor', '内容预处理'),
    ('llm_level_adjuster', '层级调整'),
    ('pdf_generator', 'PDF生成')
]
def generate_random_string(length=6):
    """Return a random string of ``length`` ASCII letters and digits.

    The result forms the unguessable part of session IDs. Routes such as
    download_result and run_script take the session ID straight from the URL,
    so it acts as the access token for a session's files. The characters are
    therefore drawn from the OS CSPRNG via ``secrets`` instead of the
    predictable Mersenne Twister behind ``random``.

    Args:
        length: Number of characters to produce (0 gives an empty string).

    Returns:
        A string of exactly ``length`` characters from ``[A-Za-z0-9]``.
    """
    alphabet = string.ascii_letters + string.digits
    return ''.join(secrets.choice(alphabet) for _ in range(length))
def generate_session_id():
    """Build a session identifier of the form ``YYYYmmddHHMMSS_<suffix>``.

    The local-time timestamp prefix is fixed-width, so session folder names
    sort chronologically. The suffix comes from ``generate_random_string()``
    (6 characters by default) and separates sessions created in the same
    second.
    """
    stamp = format(datetime.now(), "%Y%m%d%H%M%S")
    return "_".join((stamp, generate_random_string()))
def create_data_folders(session_id):
    """Create the working directory tree for a session and return its root.

    Every relative path in ``DATA_FOLDERS`` is created under
    ``data/<session_id>``, including intermediate directories. Directories
    that already exist are left untouched.

    Returns:
        The relative path ``data/<session_id>``.
    """
    session_root = os.path.join('data', session_id)
    for sub_dir in (os.path.join(session_root, name) for name in DATA_FOLDERS):
        os.makedirs(sub_dir, exist_ok=True)
    return session_root
def extract_env_var_name(api_key_value):
    """Return the environment-variable name referenced by an API-key setting.

    The LLM config may store the key indirectly as ``$NAME$``, meaning "read
    the real key from environment variable NAME". For example,
    ``$CHERRY_IN_API_KEY$`` -> ``CHERRY_IN_API_KEY``.

    Args:
        api_key_value: The raw ``api_key`` value from the config. A malformed
            config file may supply ``None`` or a non-string here.

    Returns:
        The name without its surrounding ``$`` delimiters. Returns ``None``
        when the value is not of the ``$NAME$`` form, including non-strings
        and degenerate values such as ``"$"`` or ``"$$"``. Callers only test
        the result for truthiness.
    """
    if not isinstance(api_key_value, str):
        return None
    # Require at least one character between the delimiters; a lone "$"
    # satisfies both startswith and endswith on its own.
    if len(api_key_value) > 2 and api_key_value.startswith('$') and api_key_value.endswith('$'):
        return api_key_value[1:-1]
    return None
def extract_first_page_and_recognize(pdf_path, original_filename, session_dir, config):
    """Ask the vision LLM for the book title and store it in ``book_name.txt``.

    The first PDF page is rendered to JPEG, with the longest side scaled down
    to at most 1000 px and never upscaled. It is sent together with the
    original upload name. The model's answer is sanitized for use in a file
    name and written to ``<session_dir>/book_name.txt``.

    Errors are logged, not raised. On failure, or when the model returns an
    empty answer, no title file is written and download_result falls back to
    the original file name.

    Args:
        pdf_path: Path of the uploaded PDF.
        original_filename: Client-side name of the upload, included in the prompt.
        session_dir: Session root directory that receives ``book_name.txt``.
        config: LLM settings dict with optional ``api_key``, ``base_url``,
            ``model``; an ``api_key`` of the form ``$NAME$`` is read from the
            environment.
    """
    try:
        # The context manager closes the document even if rendering fails.
        with fitz.open(pdf_path) as doc:
            page = doc[0]
            rect = page.rect
            max_dim = max(rect.width, rect.height)
            zoom = 1000.0 / max_dim if max_dim > 1000 else 1.0
            pix = page.get_pixmap(matrix=fitz.Matrix(zoom, zoom))
            img_data = pix.tobytes("jpeg")
        base64_image = base64.b64encode(img_data).decode('utf-8')
        image_data_url = f"data:image/jpeg;base64,{base64_image}"
        api_key_value = config.get("api_key", "")
        env_var_name = extract_env_var_name(api_key_value)
        actual_api_key = os.environ.get(env_var_name, "") if env_var_name else api_key_value
        client = OpenAI(
            api_key=actual_api_key,
            base_url=config.get("base_url", "https://dashscope.aliyuncs.com/compatible-mode/v1")
        )
        prompt = f"这是PDF文件的第一页。该文件的原始文件名为:{original_filename}。请结合图片内容和原始文件名,识别并输出这本书的书名。只需输出书名文本,不要包含任何其他说明、标点或多余内容。"
        completion = client.chat.completions.create(
            model=config.get("model", "qwen-vl-max"),
            messages=[
                {
                    "role": "user",
                    "content": [
                        {"type": "image_url", "image_url": {"url": image_data_url}},
                        {"type": "text", "text": prompt}
                    ]
                }
            ],
            extra_body={"enable_thinking": False}
        )
        # The SDK may return None content (e.g. a filtered response).
        raw_name = (completion.choices[0].message.content or "").strip()
        # The name becomes part of the download file name. Replace characters
        # that are illegal in Windows file names, plus control characters
        # such as newlines, which are unsafe in a Content-Disposition header.
        book_name = re.sub(r'[\\/:*?"<>|\x00-\x1f]', '_', raw_name)
        if not book_name:
            logger.warning("识别书名失败: 模型未返回书名")
            return
        info_path = os.path.join(session_dir, 'book_name.txt')
        with open(info_path, 'w', encoding='utf-8') as f:
            f.write(book_name)
    except Exception as e:
        # Best effort on a background thread: log with traceback, never raise.
        logger.exception(f"识别书名失败: {str(e)}")
def calculate_offset(pdf_path, json_path, config):
    """Estimate the body-text page offset and store it as ``content_start``.

    Up to five pages are sampled at random, preferably from the 20%-80%
    band of the document (the whole document when that band is empty).
    They are rendered to JPEG and sent concurrently to the vision LLM. The
    model is asked for ``PDF physical page number - printed page number``.
    The most common integer answer is written to ``content_start`` in the
    JSON file at ``json_path``.

    Errors are logged, not raised. When no answer parses as an integer, a
    warning is logged and the JSON file is not modified.

    Args:
        pdf_path: Path of the uploaded PDF.
        json_path: Upload settings JSON to update in place.
        config: LLM settings dict with optional ``api_key``, ``base_url``,
            ``model``; an ``api_key`` of the form ``$NAME$`` is read from the
            environment.
    """
    try:
        # The context manager closes the document even if rendering fails.
        with fitz.open(pdf_path) as doc:
            total_pages = len(doc)
            # Presumably the middle of the file is most likely to carry
            # printed page numbers, so sample from the 20%-80% band.
            start_idx = int(total_pages * 0.2)
            end_idx = int(total_pages * 0.8)
            if end_idx <= start_idx:
                # The band is empty for very short documents: use every page.
                end_idx = total_pages - 1
                start_idx = 0
            pool = list(range(start_idx, end_idx + 1))
            selected_pages = random.sample(pool, min(5, len(pool)))
            images_data = []
            for p in selected_pages:
                page = doc[p]
                rect = page.rect
                max_dim = max(rect.width, rect.height)
                # Large pages are scaled so the longest side is 1500 px;
                # smaller pages use a fixed 2x zoom (which can exceed 1500 px).
                zoom = 1500.0 / max_dim if max_dim > 1500 else 2.0
                pix = page.get_pixmap(matrix=fitz.Matrix(zoom, zoom))
                base64_image = base64.b64encode(pix.tobytes("jpeg")).decode('utf-8')
                images_data.append((p + 1, base64_image))  # p + 1 is the 1-based physical page number
        if not images_data:
            logger.warning("未能自动计算出有效的偏移量")
            return

        api_key_value = config.get("api_key", "")
        env_var_name = extract_env_var_name(api_key_value)
        actual_api_key = os.environ.get(env_var_name, "") if env_var_name else api_key_value
        client = OpenAI(
            api_key=actual_api_key,
            base_url=config.get("base_url", "https://dashscope.aliyuncs.com/compatible-mode/v1")
        )

        def fetch_offset(page_num, b64_img):
            """Ask the model for one page's offset; return its raw answer, or "Error" on failure."""
            prompt = f"""你是一个专业的文档页码识别专家。你的任务是识别图片中页面底部或顶部标注的实际印刷页码,并计算正文偏移量。
计算公式:正文偏移量 = PDF物理页码 - 印刷页码。
当前图片的PDF物理页码是:{page_num}
【示例1】
物理页码:25
图片中底部写着:"10"
输出:15
【示例2】
物理页码:12
图片中顶部写着:"- 2 -"
输出:10
【示例3】
物理页码:100
图片中没有明确的阿拉伯数字页码
输出:Error
请仔细观察图片,找到印刷页码,并严格按照上述格式,仅输出计算后的正文偏移量数字。不要输出任何解释。"""
            try:
                completion = client.chat.completions.create(
                    model=config.get("model", "qwen-vl-max"),
                    messages=[
                        {
                            "role": "user",
                            "content": [
                                {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{b64_img}"}},
                                {"type": "text", "text": prompt}
                            ]
                        }
                    ],
                    extra_body={"enable_thinking": False}
                )
                # The SDK may return None content; treat it as an empty answer.
                return (completion.choices[0].message.content or "").strip()
            except Exception as e:
                logger.error(f"获取偏移量失败: {e}")
                return "Error"

        offsets = []
        with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
            futures = [executor.submit(fetch_offset, p_num, b64) for p_num, b64 in images_data]
            for future in concurrent.futures.as_completed(futures):
                # Parse with int() and skip anything that is not an integer
                # ("Error", prose). A str.isdigit() pre-check is not enough:
                # it accepts characters such as "²" that int() then rejects
                # with a ValueError, which would abort the whole calculation.
                try:
                    offsets.append(int(future.result()))
                except ValueError:
                    continue
        if offsets:
            most_common_offset = Counter(offsets).most_common(1)[0][0]
            with open(json_path, 'r', encoding='utf-8') as f:
                j_data = json.load(f)
            j_data['content_start'] = most_common_offset
            with open(json_path, 'w', encoding='utf-8') as f:
                json.dump(j_data, f, ensure_ascii=False, indent=4)
            logger.info(f"自动计算偏移量成功: {most_common_offset}")
        else:
            logger.warning("未能自动计算出有效的偏移量")
    except Exception as e:
        # Best effort on a background thread: log with traceback, never raise.
        logger.exception(f"自动计算偏移量过程发生异常: {str(e)}")
def background_tasks(pdf_path, original_filename, session_dir, config, json_path, auto_offset):
    """Run the post-upload LLM jobs one after another (started on a thread by upload_files).

    Title recognition always runs first. Body-offset detection runs only when
    ``auto_offset`` is true. Both helpers catch and log their own exceptions.
    """
    extract_first_page_and_recognize(pdf_path, original_filename, session_dir, config)
    if not auto_offset:
        return
    calculate_offset(pdf_path, json_path, config)
@app.route('/')
def home():
    """Render the main page template ``index.html``."""
    page_template = 'index.html'
    return render_template(page_template)
def _safe_upload_stem(original_filename):
    """Return a filesystem-safe ASCII stem (no extension) for an uploaded file.

    Chinese characters are transliterated to pinyin and the result is cut to
    25 characters. Every character outside ``[A-Za-z0-9._-]`` is then
    replaced with ``_`` and leading dots are dropped. A client-supplied name
    such as ``../../templates/index.html`` therefore cannot escape the
    session's ``input_pdf`` folder. Falls back to ``'input'`` if nothing
    usable remains.
    """
    # The multipart filename is client-controlled and may carry a Windows or
    # POSIX path; only its last component is meaningful.
    bare_name = os.path.basename(original_filename.replace('\\', '/'))
    stem, _ = os.path.splitext(bare_name)
    pinyin_stem = convert_to_pinyin(stem)[:25]
    safe_stem = re.sub(r'[^A-Za-z0-9._-]', '_', pinyin_stem).lstrip('.')
    return safe_stem or 'input'


def _load_upload_llm_config():
    """Read ``static/llm_config.json``; fall back to DashScope / qwen-vl-max defaults."""
    config_path = os.path.join(app.static_folder, 'llm_config.json')
    if os.path.exists(config_path):
        with open(config_path, 'r', encoding='utf-8') as f:
            return json.load(f)
    return {
        "api_key": os.getenv("DASHSCOPE_API_KEY", ""),
        "base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1",
        "model": "qwen-vl-max"
    }


@app.route('/upload', methods=['POST'])
def upload_files():
    """Store an uploaded PDF and its page-range settings, then start a session.

    Expects multipart form data with a ``pdf`` file and the fields
    ``tocStart``, ``tocEnd``, ``contentStart`` (optional when ``autoOffset``
    is ``'true'``) and ``tocStructure`` (default ``'original'``).

    On success:
      * a new ``data/<session_id>/`` tree is created;
      * the PDF and a settings ``.json`` with the same stem are written to
        its ``input_pdf`` folder;
      * title recognition (plus offset detection when requested) starts on
        a background thread.

    Returns:
        JSON ``{'status': 'success', 'message': ..., 'session_id': ...}``, or
        ``{'status': 'error', 'message': ...}`` on invalid input or failure.
    """
    try:
        if 'pdf' not in request.files:
            return jsonify({'status': 'error', 'message': '未找到PDF文件'})
        pdf_file = request.files['pdf']
        # filename is '' when nothing was selected and may be None if the
        # client sent no filename at all.
        if not pdf_file.filename:
            return jsonify({'status': 'error', 'message': '未选择PDF文件'})
        toc_start = request.form.get('tocStart')
        toc_end = request.form.get('tocEnd')
        content_start = request.form.get('contentStart')
        auto_offset = request.form.get('autoOffset') == 'true'
        toc_structure = request.form.get('tocStructure', 'original')
        if not toc_start or not toc_end or (not auto_offset and not content_start):
            return jsonify({'status': 'error', 'message': '页码信息不完整'})
        original_filename = pdf_file.filename
        # Parse the page numbers before touching the disk, so bad input no
        # longer leaves an empty session folder behind. download_result also
        # counts those folders for its reminder logic.
        json_data = {
            "toc_start": int(toc_start),
            "toc_end": int(toc_end),
            "content_start": int(content_start) if content_start else 0,
            "original_filename": original_filename,
            "toc_structure": toc_structure
        }
        session_id = generate_session_id()
        base_dir = create_data_folders(session_id)
        upload_folder = os.path.join(base_dir, 'input_pdf')
        # The PDF and its settings JSON share a stem; the JSON is later found
        # as "<pdf stem>.json". A fixed lowercase ".pdf" suffix keeps
        # endswith('.pdf') lookups working for uploads named e.g. "X.PDF".
        safe_stem = _safe_upload_stem(original_filename)
        pdf_path = os.path.join(upload_folder, safe_stem + '.pdf')
        json_path = os.path.join(upload_folder, safe_stem + '.json')
        pdf_file.save(pdf_path)
        with open(json_path, 'w', encoding='utf-8') as f:
            json.dump(json_data, f, ensure_ascii=False, indent=4)
        llm_config = _load_upload_llm_config()
        threading.Thread(
            target=background_tasks,
            args=(pdf_path, original_filename, base_dir, llm_config, json_path, auto_offset)
        ).start()
        return jsonify({
            'status': 'success',
            'message': '文件上传成功',
            'session_id': session_id
        })
    except Exception as e:
        return jsonify({'status': 'error', 'message': str(e)})
def _read_original_filename(input_folder):
    """Return the ``original_filename`` recorded at upload time, or ``None``.

    Finds the first ``*.pdf`` in ``input_folder`` (case-insensitive) and
    reads the JSON file with the same stem. Any failure is logged and yields
    ``None``. The value only affects the suggested download name, so it must
    never block the download itself.
    """
    try:
        original_pdf = next(
            (f for f in os.listdir(input_folder) if f.lower().endswith('.pdf')), None
        )
        if not original_pdf:
            logger.warning('未找到原始PDF文件')
            return None
        json_path = os.path.join(input_folder, os.path.splitext(original_pdf)[0] + '.json')
        with open(json_path, 'r', encoding='utf-8') as f:
            return json.load(f).get('original_filename', None)
    except Exception as e:
        logger.error(f"[ERROR] 读取 JSON 时出错: {e}")
        return None


@app.route('/download_result/<session_id>')
def download_result(session_id):
    """Send a session's generated PDF as an attachment.

    The download is named ``<book name>-<yymmddHHMMSS>-TOC.pdf``. The book
    name comes from ``book_name.txt``, falling back to the original upload
    name, then to a generic label.

    Every 5th session folder under ``data/`` (by count), and only while
    ``data/no_reminder`` does not exist, sets the ``X-Show-Reminder``
    header. From 15 folders on it also sets ``X-No-Reminder-Option``. Both
    headers are listed in ``Access-Control-Expose-Headers``.

    Returns:
        The file response, or a JSON error if ``session_id`` is invalid or
        no output PDF exists.
    """
    # session_id comes straight from the URL and is joined into file paths;
    # accept only the characters generate_session_id produces so values such
    # as ".." cannot point outside data/.
    if not re.fullmatch(r'[A-Za-z0-9_]+', session_id):
        return jsonify({'status': 'error', 'message': '无效的会话ID'})
    data_dir = 'data'
    show_reminder = False
    no_reminder_option = False
    if os.path.exists(data_dir):
        folders = [f for f in os.listdir(data_dir) if os.path.isdir(os.path.join(data_dir, f))]
        folder_count = len(folders)
        if folder_count % 5 == 0:
            no_reminder_file = os.path.join(data_dir, 'no_reminder')
            if not os.path.exists(no_reminder_file):
                show_reminder = True
                # NOTE(review): the original indentation was lost; the opt-out
                # is offered only together with the reminder — confirm.
                if folder_count >= 15:
                    no_reminder_option = True
    session_dir = os.path.join(data_dir, session_id)
    output_folder = os.path.join(session_dir, 'output_pdf')
    input_folder = os.path.join(session_dir, 'input_pdf')
    try:
        pdf_files = [f for f in os.listdir(output_folder) if f.lower().endswith('.pdf')]
    except FileNotFoundError:
        # Unknown session or pipeline not yet run: report it instead of a 500.
        pdf_files = []
    if not pdf_files:
        return jsonify({'status': 'error', 'message': '未找到输出PDF文件'})
    file_path = os.path.join(output_folder, pdf_files[0])
    original_filename = _read_original_filename(input_folder)
    book_name = ""
    book_name_path = os.path.join(session_dir, 'book_name.txt')
    if os.path.exists(book_name_path):
        try:
            with open(book_name_path, 'r', encoding='utf-8') as f:
                book_name = f.read().strip()
        except Exception:
            # An unreadable title file just triggers the fallback below.
            logger.warning(f"读取书名文件失败: {book_name_path}")
    if not book_name:
        if original_filename:
            book_name, _ = os.path.splitext(original_filename)
        else:
            book_name = "处理结果"
    time_str = datetime.now().strftime("%y%m%d%H%M%S")
    download_filename = f"{book_name}-{time_str}-TOC.pdf"
    response = send_file(file_path, as_attachment=True, download_name=download_filename)
    expose_headers = ['Content-Disposition']
    if show_reminder:
        response.headers['X-Show-Reminder'] = 'true'
        expose_headers.append('X-Show-Reminder')
    if no_reminder_option:
        response.headers['X-No-Reminder-Option'] = 'true'
        expose_headers.append('X-No-Reminder-Option')
    response.headers['Access-Control-Expose-Headers'] = ', '.join(expose_headers)
    return response
@app.route('/set_no_reminder', methods=['POST'])
def set_no_reminder():
    """Record the user's "don't remind me again" choice.

    Creates ``data/`` if needed and writes the marker file
    ``data/no_reminder``. download_result suppresses its reminder header
    while this file exists.

    Returns:
        JSON success, or a 500 JSON error with the exception text.
    """
    marker_path = os.path.join('data', 'no_reminder')
    try:
        os.makedirs('data', exist_ok=True)
        with open(marker_path, 'w') as marker:
            marker.write('do not remind')
    except Exception as exc:
        return jsonify({'status': 'error', 'message': str(exc)}), 500
    return jsonify({'status': 'success', 'message': '已设置不再提醒'})
AZURE_TIMEOUT = 30
def run_azure_with_timeout(python_executable, script_path, env, script_dir):
try:
result = subprocess.run(
[python_executable, script_path],
env=env,
cwd=script_dir,
capture_output=True,
text=True,
timeout=AZURE_TIMEOUT
)
return {
'success': result.returncode == 0,
'stdout': result.stdout,
'stderr': result.stderr
}
except subprocess.TimeoutExpired as e:
return {
'success': False,
'error': 'Azure OCR timeout',
'stdout': e.stdout.decode() if e.stdout else '',
'stderr': e.stderr.decode() if e.stderr else ''
}
except Exception as e:
return {
'success': False,
'error': str(e)
}
def get_api_key_from_config():
    """Resolve the API key configured in ``static/llm_config.json``.

    A ``$NAME$`` value is looked up in environment variable ``NAME``
    (``''`` if unset). Any other value is returned as stored.

    Returns:
        The resolved key; ``''`` when the config file does not exist.
    """
    config_path = os.path.join(app.static_folder, 'llm_config.json')
    if not os.path.exists(config_path):
        return ''
    with open(config_path, 'r', encoding='utf-8') as fh:
        raw_value = json.load(fh).get('api_key', '')
    env_name = extract_env_var_name(raw_value)
    return os.environ.get(env_name, '') if env_name else raw_value
def _captured_text(data):
    """Return output attached to ``subprocess.TimeoutExpired`` as ``str``.

    The attribute is ``None`` when nothing was captured, ``bytes`` on POSIX,
    and may already be ``str`` on Windows when ``text=True`` was used.
    Partial output can end mid-character, so bytes are decoded leniently.
    """
    if data is None:
        return ''
    if isinstance(data, bytes):
        return data.decode('utf-8', errors='replace')
    return data


def _build_script_env(base_dir):
    """Return a copy of ``os.environ`` extended for one pipeline run.

    Adds the API key from ``static/llm_config.json`` (as DASHSCOPE_API_KEY,
    and under the referenced variable name for ``$NAME$`` values) and the
    per-session input/output directories the pipeline scripts read.
    """
    env = os.environ.copy()
    config_path = os.path.join(app.static_folder, 'llm_config.json')
    if os.path.exists(config_path):
        with open(config_path, 'r', encoding='utf-8') as f:
            config = json.load(f)
        api_key_value = config.get('api_key', '')
        env_var_name = extract_env_var_name(api_key_value)
        if env_var_name:
            actual_api_key = os.environ.get(env_var_name, '')
            env[env_var_name] = actual_api_key
            env['DASHSCOPE_API_KEY'] = actual_api_key
        else:
            env['DASHSCOPE_API_KEY'] = api_key_value
    env.update({
        'BASE_DIR': base_dir,
        'PDF2JPG_INPUT': f"{base_dir}/input_pdf",
        'PDF2JPG_OUTPUT': f"{base_dir}/mark/input_image",
        'CONTENT_PREPROCESSOR_INPUT': f"{base_dir}/raw_content",
        'CONTENT_PREPROCESSOR_OUTPUT': f"{base_dir}/merged_content",
        'PDF_GENERATOR_INPUT_1': f"{base_dir}/level_adjusted_content",
        'PDF_GENERATOR_INPUT_2': f"{base_dir}/input_pdf",
        'PDF_GENERATOR_OUTPUT_1': f"{base_dir}/output_pdf",
        'QWEN_VL_INPUT': f"{base_dir}/mark/input_image",
        'QWEN_VL_OUTPUT': f"{base_dir}/automark_raw_data",
        'LEVEL_ADJUSTER_INPUT': f"{base_dir}/merged_content",
        'LEVEL_ADJUSTER_OUTPUT': f"{base_dir}/level_adjusted_content",
        'LEVEL_ADJUSTER_CACHE': f"{base_dir}/level_adjuster_cache",
        'LEVEL_ADJUSTER_PICTURES': f"{base_dir}/mark/input_image"
    })
    return env


@app.route('/run_script/<session_id>/<int:script_index>/<int:retry_count>')
def run_script(session_id, script_index, retry_count):
    """Run one pipeline step for a session in a child Python process.

    Steps come from ``QWEN_SCRIPT_SEQUENCE`` when the ``ocr_model`` query
    parameter is ``'qwen'``. Step ``script_index`` runs
    ``mainprogress/<name>.py`` with the session directories passed through
    environment variables. The time limit is ``SCRIPT_TIMEOUT`` seconds.

    Args:
        session_id: Session folder name under ``data/``.
        script_index: Index of the step to run.
        retry_count: Echoed back in error responses.

    Returns:
        JSON whose ``status`` is one of:
          * ``'completed'``: no step exists at this index;
          * ``'success'``: includes ``nextIndex`` and ``totalScripts``;
          * ``'error'``: includes ``retryCount`` and ``scriptIndex``.
    """
    ocr_model = request.args.get('ocr_model', 'aliyun')
    # NOTE(review): only the Qwen pipeline is defined; any other ocr_model,
    # including the 'aliyun' default, has no steps and completes at once.
    script_sequence = QWEN_SCRIPT_SEQUENCE if ocr_model == 'qwen' else []
    total_scripts = len(script_sequence)
    if script_index >= total_scripts:
        return jsonify({
            'status': 'completed',
            'message': '所有脚本执行完成'
        })
    script_name, script_desc = script_sequence[script_index]
    # session_id comes from the URL and becomes BASE_DIR for the child
    # scripts; reject values such as ".." that would escape data/.
    if not re.fullmatch(r'[A-Za-z0-9_]+', session_id):
        return jsonify({
            'status': 'error',
            'currentScript': script_desc,
            'message': '无效的会话ID',
            'retryCount': retry_count,
            'scriptIndex': script_index,
            'session_id': session_id
        })
    try:
        script_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), 'mainprogress'))
        script_path = os.path.join(script_dir, f'{script_name}.py')
        base_dir = os.path.abspath(os.path.join('data', session_id))
        env = _build_script_env(base_dir)
        python_executable = sys.executable
        # NOTE(review): QWEN_SCRIPT_SEQUENCE contains no *_hybrid step and the
        # sequence is only non-empty for ocr_model == 'qwen', so this
        # OCR-provider dispatch is currently unreachable.
        if script_name in ['ocr_hybrid', 'ocr_and_projection_hybrid']:
            if ocr_model == 'azure':
                script_path = os.path.join(script_dir, f'{script_name.replace("hybrid", "azure")}.py')
                azure_result = run_azure_with_timeout(python_executable, script_path, env, script_dir)
                if azure_result.get('success', False):
                    return jsonify({
                        'status': 'success',
                        'currentScript': script_desc,
                        'message': f'{script_desc} (Azure) 执行成功',
                        'nextIndex': script_index + 1,
                        'totalScripts': total_scripts,
                        'retryCount': 0,
                        'session_id': session_id,
                        'stdout': azure_result.get('stdout', ''),
                        'stderr': azure_result.get('stderr', '')
                    })
                return jsonify({
                    'status': 'error',
                    'currentScript': script_desc,
                    'message': f'{script_desc} (Azure) 执行失败',
                    'stdout': azure_result.get('stdout', ''),
                    'stderr': azure_result.get('stderr', ''),
                    'retryCount': retry_count,
                    'scriptIndex': script_index,
                    'session_id': session_id
                })
            script_path = os.path.join(script_dir, f'{script_name.replace("hybrid", "aliyun")}.py')
        try:
            result = subprocess.run(
                [python_executable, script_path],
                env=env,
                cwd=script_dir,
                capture_output=True,
                text=True,
                timeout=SCRIPT_TIMEOUT
            )
        except subprocess.TimeoutExpired as e:
            return jsonify({
                'status': 'error',
                'currentScript': script_desc,
                'message': f'脚本执行超时({SCRIPT_TIMEOUT}秒)',
                'stdout': _captured_text(e.stdout),
                'stderr': _captured_text(e.stderr),
                'retryCount': retry_count,
                'scriptIndex': script_index,
                'session_id': session_id
            })
        if result.returncode == 0:
            return jsonify({
                'status': 'success',
                'currentScript': script_desc,
                'message': f'{script_desc}执行成功',
                'nextIndex': script_index + 1,
                'totalScripts': total_scripts,
                'retryCount': 0,
                'session_id': session_id,
                'stdout': result.stdout,
                'stderr': result.stderr
            })
        return jsonify({
            'status': 'error',
            'currentScript': script_desc,
            'message': f'{script_desc}执行失败',
            'stdout': result.stdout,
            'stderr': result.stderr,
            'retryCount': retry_count,
            'scriptIndex': script_index,
            'session_id': session_id
        })
    except Exception as e:
        logger.exception(f"执行脚本时发生错误: {str(e)}")
        return jsonify({
            'status': 'error',
            'currentScript': script_desc,
            'message': f'执行出错: {str(e)}',
            'retryCount': retry_count,
            'scriptIndex': script_index,
            'session_id': session_id
        })
@app.route('/save_prompt/<filename>', methods=['POST'])
def save_prompt(filename):
    """Overwrite one of the user-editable prompt files in the static folder.

    Only the three prompt files listed below may be written (403 otherwise),
    so the URL parameter cannot select an arbitrary path. The raw request
    body is saved as UTF-8 text.

    Returns:
        JSON success; 403 for a disallowed name; 500 if writing fails.
    """
    editable_prompts = {'extract_prompt.md', 'adjuster_prompt_route.md', 'adjuster_prompt.md'}
    try:
        if filename in editable_prompts:
            prompt_text = request.get_data(as_text=True)
            target_dir = app.static_folder
            if not os.path.exists(target_dir):
                os.makedirs(target_dir)
            with open(os.path.join(target_dir, filename), 'w', encoding='utf-8') as out:
                out.write(prompt_text)
            return jsonify({'status': 'success', 'message': f'{filename} 保存成功'})
        return jsonify({'status': 'error', 'message': '不允许保存该文件'}), 403
    except Exception as exc:
        logger.error(f"保存提示词文件失败: {str(exc)}")
        return jsonify({'status': 'error', 'message': f'保存失败: {str(exc)}'}), 500
@app.route('/get_llm_config')
def get_llm_config():
    """Return the saved LLM settings, or built-in defaults if none are saved.

    NOTE(review): the stored ``api_key`` is returned verbatim. If it holds a
    literal key rather than a ``$ENV_VAR$`` reference, the secret is sent to
    the browser.

    Returns:
        JSON ``{'status': 'success', 'config': {...}}``, or a 500 JSON error.
    """
    try:
        config_path = os.path.join(app.static_folder, 'llm_config.json')
        if not os.path.exists(config_path):
            settings = {
                "api_key": "$DASHSCOPE_API_KEY$",
                "base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1",
                "model": "qwen3.5-397b-a17b"
            }
        else:
            with open(config_path, 'r', encoding='utf-8') as fh:
                settings = json.load(fh)
        return jsonify({'status': 'success', 'config': settings})
    except Exception as exc:
        logger.error(f"获取 LLM 配置失败: {str(exc)}")
        return jsonify({'status': 'error', 'message': f'获取配置失败: {str(exc)}'}), 500
@app.route('/save_llm_config', methods=['POST'])
def save_llm_config():
    """Persist LLM connection settings posted as a JSON object.

    The body must be a JSON object containing at least ``api_key``,
    ``base_url`` and ``model``. It is written unchanged to
    ``static/llm_config.json``.

    Returns:
        JSON status:
          * 400 for a missing or non-object body, or a missing field;
          * 500 if writing fails.
    """
    try:
        # silent=True yields None for a missing or non-JSON body instead of
        # raising, and the isinstance check also rejects null/list/number
        # bodies. Those would otherwise fail below as a 500.
        config = request.get_json(silent=True)
        if not isinstance(config, dict):
            return jsonify({'status': 'error', 'message': '请求体必须是JSON对象'}), 400
        for field in ('api_key', 'base_url', 'model'):
            if field not in config:
                return jsonify({'status': 'error', 'message': f'缺少必需字段: {field}'}), 400
        static_dir = app.static_folder
        os.makedirs(static_dir, exist_ok=True)
        config_path = os.path.join(static_dir, 'llm_config.json')
        with open(config_path, 'w', encoding='utf-8') as f:
            json.dump(config, f, ensure_ascii=False, indent=2)
        return jsonify({'status': 'success', 'message': 'LLM 配置保存成功'})
    except Exception as e:
        logger.error(f"保存 LLM 配置失败: {str(e)}")
        return jsonify({'status': 'error', 'message': f'保存配置失败: {str(e)}'}), 500
@app.route('/test_qwen_service', methods=['POST'])
def test_qwen_service():
    """Check that the saved LLM settings can complete a trivial chat request.

    Settings come from ``static/llm_config.json``. If that file is absent,
    DashScope defaults are used with the key from ``DASHSCOPE_API_KEY``. An
    ``api_key`` of the form ``$NAME$`` is resolved from environment variable
    ``NAME``.

    Returns:
        JSON with the model's reply on success. On any exception, a 500
        response with the message and the exception class name.
    """
    try:
        settings_file = os.path.join(app.static_folder, 'llm_config.json')
        if not os.path.exists(settings_file):
            settings = {
                "api_key": os.getenv("DASHSCOPE_API_KEY", ""),
                "base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1",
                "model": "qwen3.5-397b-a17b"
            }
        else:
            with open(settings_file, 'r', encoding='utf-8') as fh:
                settings = json.load(fh)
        raw_key = settings["api_key"]
        env_name = extract_env_var_name(raw_key)
        resolved_key = os.environ.get(env_name, "") if env_name else raw_key
        client = OpenAI(api_key=resolved_key, base_url=settings["base_url"])
        reply = client.chat.completions.create(
            model=settings["model"],
            messages=[
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": "正在测试通义千问服务访问状态,请输出`正常`这两个中文字符,不要附带任何其他内容"},
            ],
        )
        answer = reply.choices[0].message.content if reply.choices else ''
        return jsonify({
            'status': 'success',
            'message': '通义千问服务状态正常',
            'response': answer
        })
    except Exception as exc:
        logger.error(f"通义千问服务测试失败: {str(exc)}")
        logger.error(traceback.format_exc())
        return jsonify({
            'status': 'error',
            'message': f'测试失败: {str(exc)}',
            'error_code': type(exc).__name__
        }), 500
@app.route('/test_llm_service', methods=['POST'])
def test_llm_service():
    """Check connectivity for LLM settings supplied in the request body.

    Expects a JSON object with ``api_key``, ``base_url`` and ``model``. An
    ``api_key`` of the form ``$NAME$`` is resolved from environment variable
    ``NAME``.

    Returns:
        JSON with the model's reply on success; 400 when any setting is
        missing (including a missing or non-JSON body); 500 with the message
        and exception class name if the request fails.
    """
    try:
        # silent=True returns None instead of raising for a missing or
        # non-JSON body. Treating that (or a non-object body) as empty routes
        # it to the 400 "incomplete" response below rather than a 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            data = {}
        api_key = data.get('api_key', '')
        base_url = data.get('base_url', '')
        model = data.get('model', '')
        if not api_key or not base_url or not model:
            return jsonify({
                'status': 'error',
                'message': 'API配置信息不完整,请检查API Key、Base URL和Model是否都已填写'
            }), 400
        env_var_name = extract_env_var_name(api_key)
        if env_var_name:
            actual_api_key = os.environ.get(env_var_name, "")
        else:
            actual_api_key = api_key
        client = OpenAI(
            api_key=actual_api_key,
            base_url=base_url,
        )
        completion = client.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": "正在测试LLM服务访问状态,请输出`正常`这两个中文字符,不要附带任何其他内容"},
            ],
            extra_body={"enable_thinking": False}
        )
        return jsonify({
            'status': 'success',
            'message': 'LLM服务状态正常',
            'response': completion.choices[0].message.content if completion.choices else ''
        })
    except Exception as e:
        logger.error(f"LLM服务测试失败: {str(e)}")
        logger.error(traceback.format_exc())
        return jsonify({
            'status': 'error',
            'message': f'测试失败: {str(e)}',
            'error_code': type(e).__name__
        }), 500
def find_available_port(start_port=5000, max_port=6000):
    """Return the first TCP port in ``[start_port, max_port]`` that can be bound.

    Each candidate is probed by binding a throwaway IPv4 socket on all
    interfaces and closing it again. The port is not reserved, so another
    process could still claim it before the server starts.

    Returns:
        The port number, or ``None`` if no port in the range could be bound.
    """
    for candidate in range(start_port, max_port + 1):
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
            try:
                probe.bind(('', candidate))
            except OSError:
                continue
        return candidate
    return None
if __name__ == '__main__':
    port = find_available_port()
    if port is not None:
        def open_browser():
            """Wait briefly for the server to start, then open the UI in a browser."""
            time.sleep(1.5)
            webbrowser.open_new(f'http://127.0.0.1:{port}')

        threading.Thread(target=open_browser).start()
        print(f"Starting server on port {port}")
        # NOTE(review): debug=True enables the Werkzeug interactive debugger;
        # acceptable for a local tool, but do not expose this server to a network.
        app.run(debug=True, port=port, use_reloader=False)
    else:
        print("Error: No available ports found between 5000 and 6000")