Skip to content

Commit 20805ab

Browse files
committed
一致性检查器使用 llm_adapters
1 parent 8afa108 commit 20805ab

File tree

2 files changed

+24
-10
lines changed

2 files changed

+24
-10
lines changed

consistency_checker.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# consistency_checker.py
22
# -*- coding: utf-8 -*-
3-
from langchain_openai import ChatOpenAI
3+
from llm_adapters import create_llm_adapter
44

55
# ============== 增加对“剧情要点/未解决冲突”进行检查的可选引导 ==============
66
CONSISTENCY_PROMPT = """\
@@ -32,7 +32,10 @@ def check_consistency(
3232
base_url: str,
3333
model_name: str,
3434
temperature: float = 0.3,
35-
plot_arcs: str = "" # 新增参数,默认空字符串
35+
plot_arcs: str = "",
36+
interface_format: str = "OpenAI",
37+
max_tokens: int = 2048,
38+
timeout: int = 600
3639
) -> str:
3740
"""
3841
调用模型做简单的一致性检查。可扩展更多提示或校验规则。
@@ -45,20 +48,25 @@ def check_consistency(
4548
plot_arcs=plot_arcs,
4649
chapter_text=chapter_text
4750
)
48-
model = ChatOpenAI(
49-
model=model_name,
50-
api_key=api_key,
51+
52+
llm_adapter = create_llm_adapter(
53+
interface_format=interface_format,
5154
base_url=base_url,
52-
temperature=temperature
55+
model_name=model_name,
56+
api_key=api_key,
57+
temperature=temperature,
58+
max_tokens=max_tokens,
59+
timeout=timeout
5360
)
61+
5462
# 调试日志
5563
print("\n[ConsistencyChecker] Prompt >>>", prompt)
5664

57-
response = model.invoke(prompt)
65+
response = llm_adapter.invoke(prompt)
5866
if not response:
5967
return "审校Agent无回复"
60-
68+
6169
# 调试日志
62-
print("[ConsistencyChecker] Response <<<", response.content.strip())
70+
print("[ConsistencyChecker] Response <<<", response)
6371

64-
return response.content.strip()
72+
return response

ui.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1083,6 +1083,9 @@ def task():
10831083
base_url = self.base_url_var.get().strip()
10841084
model_name = self.model_name_var.get().strip()
10851085
temperature = self.temperature_var.get()
1086+
interface_format = self.interface_format_var.get()
1087+
max_tokens = self.max_tokens_var.get()
1088+
timeout = self.timeout_var.get()
10861089

10871090
chap_num = self.safe_get_int(self.chapter_num_var, 1)
10881091
chap_file = os.path.join(filepath, "chapters", f"chapter_{chap_num}.txt")
@@ -1102,6 +1105,9 @@ def task():
11021105
base_url=base_url,
11031106
model_name=model_name,
11041107
temperature=temperature,
1108+
interface_format=interface_format,
1109+
max_tokens=max_tokens,
1110+
timeout=timeout,
11051111
plot_arcs=""
11061112
)
11071113
self.safe_log("审校结果:")

0 commit comments

Comments
 (0)