Skip to content

Commit 779942d

Browse files
authored
Merge pull request #89 from sangyuxiaowu/azure
添加对 Azure OpenAI 的支持,新增适配器并更新界面选项 (Add support for Azure OpenAI: new adapters and updated UI options)
2 parents 1ce0edc + 20805ab commit 779942d

File tree

4 files changed

+100
-14
lines changed

4 files changed

+100
-14
lines changed

consistency_checker.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# consistency_checker.py
22
# -*- coding: utf-8 -*-
3-
from langchain_openai import ChatOpenAI
3+
from llm_adapters import create_llm_adapter
44

55
# ============== 增加对“剧情要点/未解决冲突”进行检查的可选引导 ==============
66
CONSISTENCY_PROMPT = """\
@@ -32,7 +32,10 @@ def check_consistency(
3232
base_url: str,
3333
model_name: str,
3434
temperature: float = 0.3,
35-
plot_arcs: str = "" # 新增参数,默认空字符串
35+
plot_arcs: str = "",
36+
interface_format: str = "OpenAI",
37+
max_tokens: int = 2048,
38+
timeout: int = 600
3639
) -> str:
3740
"""
3841
调用模型做简单的一致性检查。可扩展更多提示或校验规则。
@@ -45,20 +48,25 @@ def check_consistency(
4548
plot_arcs=plot_arcs,
4649
chapter_text=chapter_text
4750
)
48-
model = ChatOpenAI(
49-
model=model_name,
50-
api_key=api_key,
51+
52+
llm_adapter = create_llm_adapter(
53+
interface_format=interface_format,
5154
base_url=base_url,
52-
temperature=temperature
55+
model_name=model_name,
56+
api_key=api_key,
57+
temperature=temperature,
58+
max_tokens=max_tokens,
59+
timeout=timeout
5360
)
61+
5462
# 调试日志
5563
print("\n[ConsistencyChecker] Prompt >>>", prompt)
5664

57-
response = model.invoke(prompt)
65+
response = llm_adapter.invoke(prompt)
5866
if not response:
5967
return "审校Agent无回复"
60-
68+
6169
# 调试日志
62-
print("[ConsistencyChecker] Response <<<", response.content.strip())
70+
print("[ConsistencyChecker] Response <<<", response)
6371

64-
return response.content.strip()
72+
return response

embedding_adapters.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import requests
55
import traceback
66
from typing import List
7-
from langchain_openai import OpenAIEmbeddings
7+
from langchain_openai import OpenAIEmbeddings, AzureOpenAIEmbeddings
88

99
def ensure_openai_base_url_has_v1(url: str) -> str:
1010
"""
@@ -45,6 +45,33 @@ def embed_documents(self, texts: List[str]) -> List[List[float]]:
4545

4646
def embed_query(self, query: str) -> List[float]:
4747
return self._embedding.embed_query(query)
48+
49+
class AzureOpenAIEmbeddingAdapter(BaseEmbeddingAdapter):
    """
    Embedding adapter backed by langchain's AzureOpenAIEmbeddings (or a
    compatible interface).

    Azure routes requests by endpoint + deployment + api-version, all of which
    are parsed out of a single full request URL of the form:

        https://<resource>.openai.azure.com/openai/deployments/<deployment>/embeddings?api-version=<version>
    """

    # Groups: (1) host, (2) deployment name, (3) api-version query value.
    _URL_PATTERN = r'https://(.+?)/openai/deployments/(.+?)/embeddings\?api-version=(.+)'

    def __init__(self, api_key: str, base_url: str, model_name: str):
        import re
        match = re.match(self._URL_PATTERN, base_url)
        if not match:
            # Fail fast and include the offending value so misconfiguration
            # is easy to diagnose from the traceback alone.
            raise ValueError(
                "Invalid Azure OpenAI base_url format: "
                f"{base_url!r}. Expected https://<resource>/openai/deployments/"
                "<deployment>/embeddings?api-version=<version>"
            )
        self.azure_endpoint = f"https://{match.group(1)}"
        self.azure_deployment = match.group(2)
        self.api_version = match.group(3)

        # NOTE(review): model_name is accepted for signature consistency with
        # the other adapters, but Azure selects the model via the deployment
        # name embedded in the URL, so it is intentionally not passed through.
        self._embedding = AzureOpenAIEmbeddings(
            azure_endpoint=self.azure_endpoint,
            azure_deployment=self.azure_deployment,
            openai_api_key=api_key,
            api_version=self.api_version,
        )

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a batch of documents; returns one vector per input text."""
        return self._embedding.embed_documents(texts)

    def embed_query(self, query: str) -> List[float]:
        """Embed a single query string; returns one vector."""
        return self._embedding.embed_query(query)
4875

4976
class OllamaEmbeddingAdapter(BaseEmbeddingAdapter):
5077
"""
@@ -112,6 +139,8 @@ def create_embedding_adapter(
112139
"""
113140
if interface_format.lower() == "openai":
114141
return OpenAIEmbeddingAdapter(api_key, base_url, model_name)
142+
elif interface_format.lower() == "azure openai":
143+
return AzureOpenAIEmbeddingAdapter(api_key, base_url, model_name)
115144
elif interface_format.lower() == "ollama":
116145
return OllamaEmbeddingAdapter(model_name, base_url)
117146
elif interface_format.lower() == "ml studio":

llm_adapters.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# -*- coding: utf-8 -*-
33
import logging
44
from typing import Optional
5-
from langchain_openai import ChatOpenAI
5+
from langchain_openai import ChatOpenAI, AzureChatOpenAI
66

77
def ensure_openai_base_url_has_v1(url: str) -> str:
88
import re
@@ -77,6 +77,43 @@ def invoke(self, prompt: str) -> str:
7777
return ""
7878
return response.content
7979

80+
class AzureOpenAIAdapter(BaseLLMAdapter):
    """
    LLM adapter for the Azure OpenAI chat interface, backed by langchain's
    AzureChatOpenAI.

    The endpoint, deployment name and API version are parsed out of a single
    full chat-completions request URL of the form:

        https://<resource>.openai.azure.com/openai/deployments/<deployment>/chat/completions?api-version=<version>
    """

    # Groups: (1) host, (2) deployment name, (3) api-version query value.
    _URL_PATTERN = r'https://(.+?)/openai/deployments/(.+?)/chat/completions\?api-version=(.+)'

    def __init__(self, api_key: str, base_url: str, model_name: str,
                 max_tokens: int, temperature: float = 0.7,
                 timeout: Optional[int] = 600):
        import re
        match = re.match(self._URL_PATTERN, base_url)
        if not match:
            # Fail fast and include the offending value so misconfiguration
            # is easy to diagnose from the traceback alone.
            raise ValueError(
                "Invalid Azure OpenAI base_url format: "
                f"{base_url!r}. Expected https://<resource>/openai/deployments/"
                "<deployment>/chat/completions?api-version=<version>"
            )
        self.azure_endpoint = f"https://{match.group(1)}"
        self.azure_deployment = match.group(2)
        self.api_version = match.group(3)

        self.api_key = api_key
        # Azure selects the model via the deployment name in the URL, so the
        # deployment overrides the caller-supplied model_name.
        self.model_name = self.azure_deployment
        self.max_tokens = max_tokens
        self.temperature = temperature
        self.timeout = timeout

        self._client = AzureChatOpenAI(
            azure_endpoint=self.azure_endpoint,
            azure_deployment=self.azure_deployment,
            api_version=self.api_version,
            api_key=self.api_key,
            max_tokens=self.max_tokens,
            temperature=self.temperature,
            timeout=self.timeout
        )

    def invoke(self, prompt: str) -> str:
        """
        Send a single prompt to the model and return the response text.

        Returns an empty string (and logs a warning) when the client yields
        no response.
        """
        response = self._client.invoke(prompt)
        if not response:
            logging.warning("No response from AzureOpenAIAdapter.")
            return ""
        return response.content
116+
80117
class OllamaAdapter(BaseLLMAdapter):
81118
"""
82119
Ollama 同样有一个 OpenAI-like /v1/chat 接口,可直接使用 ChatOpenAI。
@@ -147,6 +184,8 @@ def create_llm_adapter(
147184
return DeepSeekAdapter(api_key, base_url, model_name, max_tokens, temperature, timeout)
148185
elif interface_format.lower() == "openai":
149186
return OpenAIAdapter(api_key, base_url, model_name, max_tokens, temperature, timeout)
187+
elif interface_format.lower() == "azure openai":
188+
return AzureOpenAIAdapter(api_key, base_url, model_name, max_tokens, temperature, timeout)
150189
elif interface_format.lower() == "ollama":
151190
return OllamaAdapter(api_key, base_url, model_name, max_tokens, temperature, timeout)
152191
elif interface_format.lower() == "ml studio":

ui.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,8 @@ def on_interface_format_changed(new_value):
290290
self.base_url_var.set("http://localhost:1234/v1")
291291
elif new_value == "OpenAI":
292292
self.base_url_var.set("https://api.openai.com/v1")
293+
elif new_value == "Azure OpenAI":
294+
self.base_url_var.set("https://[az].openai.azure.com/openai/deployments/[model]/chat/completions?api-version=2024-08-01-preview")
293295
elif new_value == "DeepSeek":
294296
self.base_url_var.set("https://api.deepseek.com/v1")
295297

@@ -332,7 +334,7 @@ def on_interface_format_changed(new_value):
332334
column=0,
333335
font=("Microsoft YaHei", 12)
334336
)
335-
interface_options = ["DeepSeek", "OpenAI", "Ollama", "ML Studio"]
337+
interface_options = ["DeepSeek", "OpenAI", "Azure OpenAI", "Ollama", "ML Studio"]
336338
interface_dropdown = ctk.CTkOptionMenu(
337339
self.ai_config_tab,
338340
values=interface_options,
@@ -452,6 +454,8 @@ def on_embedding_interface_changed(new_value):
452454
self.embedding_url_var.set("http://localhost:1234/v1")
453455
elif new_value == "OpenAI":
454456
self.embedding_url_var.set("https://api.openai.com/v1")
457+
elif new_value == "Azure OpenAI":
458+
self.embedding_url_var.set("https://[az].openai.azure.com/openai/deployments/[model]/embeddings?api-version=2023-05-15")
455459
elif new_value == "DeepSeek":
456460
self.embedding_url_var.set("https://api.deepseek.com/v1")
457461

@@ -482,7 +486,7 @@ def on_embedding_interface_changed(new_value):
482486
column=0,
483487
font=("Microsoft YaHei", 12)
484488
)
485-
emb_interface_options = ["DeepSeek", "OpenAI", "Ollama", "ML Studio"]
489+
emb_interface_options = ["DeepSeek", "OpenAI", "Azure OpenAI", "Ollama", "ML Studio"]
486490
emb_interface_dropdown = ctk.CTkOptionMenu(
487491
self.embeddings_config_tab,
488492
values=emb_interface_options,
@@ -1079,6 +1083,9 @@ def task():
10791083
base_url = self.base_url_var.get().strip()
10801084
model_name = self.model_name_var.get().strip()
10811085
temperature = self.temperature_var.get()
1086+
interface_format = self.interface_format_var.get()
1087+
max_tokens = self.max_tokens_var.get()
1088+
timeout = self.timeout_var.get()
10821089

10831090
chap_num = self.safe_get_int(self.chapter_num_var, 1)
10841091
chap_file = os.path.join(filepath, "chapters", f"chapter_{chap_num}.txt")
@@ -1098,6 +1105,9 @@ def task():
10981105
base_url=base_url,
10991106
model_name=model_name,
11001107
temperature=temperature,
1108+
interface_format=interface_format,
1109+
max_tokens=max_tokens,
1110+
timeout=timeout,
11011111
plot_arcs=""
11021112
)
11031113
self.safe_log("审校结果:")

0 commit comments

Comments
 (0)