
Commit 1e186dd

✨ Model services that connect to ModelEngine model enablement

2 parents 5805b5f + d7cb555

37 files changed: +1252 -354 lines

backend/agents/create_agent_info.py

Lines changed: 6 additions & 3 deletions
@@ -44,7 +44,8 @@ async def create_model_config_list(tenant_id):
                     model_name=record["model_name"],
                 ),
                 url=record["base_url"],
-                ssl_verify=record.get("ssl_verify", True)))
+                ssl_verify=record.get("ssl_verify", True),
+                model_factory=record.get("model_factory")))
     # fit for old version, main_model and sub_model use default model
     main_model_config = tenant_config_manager.get_model_config(
         key=MODEL_CONFIG_MAPPING["llm"], tenant_id=tenant_id)
@@ -54,14 +55,16 @@ async def create_model_config_list(tenant_id):
                     model_name=get_model_name_from_config(main_model_config) if main_model_config.get(
                         "model_name") else "",
                     url=main_model_config.get("base_url", ""),
-                    ssl_verify=main_model_config.get("ssl_verify", True)))
+                    ssl_verify=main_model_config.get("ssl_verify", True),
+                    model_factory=main_model_config.get("model_factory")))
     model_list.append(
         ModelConfig(cite_name="sub_model",
                     api_key=main_model_config.get("api_key", ""),
                     model_name=get_model_name_from_config(main_model_config) if main_model_config.get(
                         "model_name") else "",
                     url=main_model_config.get("base_url", ""),
-                    ssl_verify=main_model_config.get("ssl_verify", True)))
+                    ssl_verify=main_model_config.get("ssl_verify", True),
+                    model_factory=main_model_config.get("model_factory")))

     return model_list
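
The ModelConfig call sites above now pass ssl_verify explicitly and thread through the new model_factory field. As a reading aid, a hypothetical minimal definition reconstructed from the keyword arguments alone (the project's actual class is not shown in this diff):

from dataclasses import dataclass
from typing import Optional

@dataclass
class ModelConfig:
    cite_name: str                       # e.g. "main_model" or "sub_model"
    api_key: str = ""
    model_name: str = ""
    url: str = ""
    ssl_verify: bool = True              # already existed; now always set from config
    model_factory: Optional[str] = None  # new in this commit, e.g. "modelengine"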

backend/services/conversation_management_service.py

Lines changed: 7 additions & 1 deletion
@@ -268,7 +268,9 @@ def call_llm_for_title(content: str, tenant_id: str, language: str = LANGUAGE["ZH"]):
         api_base=model_config.get("base_url", ""),
         api_key=model_config.get("api_key", ""),
         temperature=0.7,
-        top_p=0.95
+        top_p=0.95,
+        model_factory=model_config.get("model_factory", None),
+        ssl_verify=model_config.get("ssl_verify", True)
     )

     # Build messages
@@ -280,6 +282,10 @@ def call_llm_for_title(content: str, tenant_id: str, language: str = LANGUAGE["ZH"]):
                 {"role": MESSAGE_ROLE["USER"],
                  "content": user_prompt}]

+    # ModelEngine only accepts a simple role/content structure, so flatten the messages up front
+    if model_config.get("model_factory", "").lower() == "modelengine":
+        messages = [{"role": msg["role"], "content": str(msg.get("content", ""))} for msg in messages]
+
     # Call the model
     response = llm.generate(messages)
     if not response or not response.content or not response.content.strip():
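
The flattening guard is the only ModelEngine-specific branch in this file. The same idea as a standalone sketch (the function name and sample message are illustrative, not from the commit):

def flatten_messages(messages: list[dict]) -> list[dict]:
    # Coerce every content value to a plain string so the payload is just role/content pairs
    return [{"role": m["role"], "content": str(m.get("content", ""))} for m in messages]

# A structured content value collapses to its string form rather than being rejected upstream
print(flatten_messages([{"role": "user", "content": [{"type": "text", "text": "hi"}]}]))
# [{'role': 'user', 'content': "[{'type': 'text', 'text': 'hi'}]"}]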

backend/services/file_management_service.py

Lines changed: 2 additions & 1 deletion
@@ -192,6 +192,7 @@ def get_llm_model(tenant_id: str):
         model_id=get_model_name_from_config(main_model_config),
         api_base=main_model_config.get("base_url"),
         api_key=main_model_config.get("api_key"),
-        max_context_tokens=main_model_config.get("max_tokens")
+        max_context_tokens=main_model_config.get("max_tokens"),
+        ssl_verify=main_model_config.get("ssl_verify", True),
     )
     return long_text_to_text_model
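
The diff does not show how the model wrappers consume ssl_verify. One plausible wiring, assuming an httpx-based OpenAI-compatible client underneath (the function name here is illustrative):

import httpx

def build_http_client(ssl_verify: bool = True) -> httpx.Client:
    # verify=False skips TLS certificate checks, which self-hosted
    # ModelEngine endpoints with self-signed certificates may require
    return httpx.Client(verify=ssl_verify)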

backend/services/image_service.py

Lines changed: 2 additions & 1 deletion
@@ -42,5 +42,6 @@ def get_vlm_model(tenant_id: str):
         temperature=0.7,
         top_p=0.7,
         frequency_penalty=0.5,
-        max_tokens=512
+        max_tokens=512,
+        ssl_verify=vlm_model_config.get("ssl_verify", True),
     )

backend/services/model_health_service.py

Lines changed: 14 additions & 7 deletions
@@ -17,15 +17,17 @@ async def _embedding_dimension_check(
     model_name: str,
     model_type: str,
     model_base_url: str,
-    model_api_key: str
+    model_api_key: str,
+    ssl_verify: bool = True,
 ):
     # Test connectivity based on different model types
     if model_type == "embedding":
         embedding = await OpenAICompatibleEmbedding(
             model_name=model_name,
             base_url=model_base_url,
             api_key=model_api_key,
-            embedding_dim=0
+            embedding_dim=0,
+            ssl_verify=ssl_verify,
         ).dimension_check()
         if len(embedding) > 0:
             return len(embedding[0])
@@ -37,7 +39,8 @@ async def _embedding_dimension_check(
             model_name=model_name,
             base_url=model_base_url,
             api_key=model_api_key,
-            embedding_dim=0
+            embedding_dim=0,
+            ssl_verify=ssl_verify,
         ).dimension_check()
         if len(embedding) > 0:
             return len(embedding[0])
@@ -78,14 +81,16 @@ async def _perform_connectivity_check(
             model_name=model_name,
             base_url=model_base_url,
             api_key=model_api_key,
-            embedding_dim=0
+            embedding_dim=0,
+            ssl_verify=ssl_verify
         ).dimension_check()) > 0
     elif model_type == "multi_embedding":
         connectivity = len(await JinaEmbedding(
             model_name=model_name,
             base_url=model_base_url,
             api_key=model_api_key,
-            embedding_dim=0
+            embedding_dim=0,
+            ssl_verify=ssl_verify
         ).dimension_check()) > 0
     elif model_type == "llm":
         observer = MessageObserver()
@@ -104,7 +109,8 @@ async def _perform_connectivity_check(
             observer,
             model_id=model_name,
             api_base=model_base_url,
-            api_key=model_api_key
+            api_key=model_api_key,
+            ssl_verify=ssl_verify
         ).check_connectivity()
     elif model_type in ["tts", "stt"]:
         voice_service = get_voice_service()
@@ -227,8 +233,9 @@ async def embedding_dimension_check(model_config: dict):
     model_api_key = model_config["api_key"]

     try:
+        ssl_verify = model_config.get("ssl_verify", True)
        dimension = await _embedding_dimension_check(
-            model_name, model_type, model_base_url, model_api_key
+            model_name, model_type, model_base_url, model_api_key, ssl_verify
        )
        return dimension
    except ValueError as e:
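
With the new parameter, a health probe against an endpoint with a self-signed certificate only needs the flag in its config dict. A sketch of the call, assuming the other keys mirror those read earlier in the function (only api_key is visible in this hunk):

import asyncio

config = {
    "model_name": "bge-m3",                    # hypothetical embedding model
    "model_type": "embedding",
    "base_url": "https://modelengine.local/v1",  # hypothetical endpoint
    "api_key": "sk-placeholder",
    "ssl_verify": False,                       # skip certificate verification for the probe
}
dimension = asyncio.run(embedding_dimension_check(config))
print(dimension)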

backend/services/model_provider_service.py

Lines changed: 5 additions & 5 deletions
@@ -77,10 +77,10 @@ class ModelEngineProvider(AbstractModelProvider):
     async def get_models(self, provider_config: Dict) -> List[Dict]:
         """
         Fetch models from ModelEngine API.
-
+
         Args:
             provider_config: Configuration dict containing model_type
-
+
         Returns:
             List of models with canonical fields
         """
@@ -111,19 +111,19 @@ async def get_models(self, provider_config: Dict) -> List[Dict]:
             "asr": "stt",
             "tts": "tts",
             "rerank": "rerank",
-            "vlm": "vlm",
+            "multimodal": "vlm",
         }

         # Filter models by type if specified
         filtered_models = []
         for model in all_models:
             me_type = model.get("type", "")
             internal_type = type_map.get(me_type)
-
+
             # If model_type filter is provided, only include matching models
             if model_type and internal_type != model_type:
                 continue
-
+
             if internal_type:
                 filtered_models.append({
                     "id": model.get("id", ""),

backend/services/vectordatabase_service.py

Lines changed: 26 additions & 14 deletions
@@ -204,9 +204,21 @@ def get_embedding_model(tenant_id: str):

     if model_type == "embedding":
         # Get the es core
-        return OpenAICompatibleEmbedding(api_key=model_config.get("api_key", ""), base_url=model_config.get("base_url", ""), model_name=get_model_name_from_config(model_config) or "", embedding_dim=model_config.get("max_tokens", 1024))
+        return OpenAICompatibleEmbedding(
+            api_key=model_config.get("api_key", ""),
+            base_url=model_config.get("base_url", ""),
+            model_name=get_model_name_from_config(model_config) or "",
+            embedding_dim=model_config.get("max_tokens", 1024),
+            ssl_verify=model_config.get("ssl_verify", True),
+        )
     elif model_type == "multi_embedding":
-        return JinaEmbedding(api_key=model_config.get("api_key", ""), base_url=model_config.get("base_url", ""), model_name=get_model_name_from_config(model_config) or "", embedding_dim=model_config.get("max_tokens", 1024))
+        return JinaEmbedding(
+            api_key=model_config.get("api_key", ""),
+            base_url=model_config.get("base_url", ""),
+            model_name=get_model_name_from_config(model_config) or "",
+            embedding_dim=model_config.get("max_tokens", 1024),
+            ssl_verify=model_config.get("ssl_verify", True),
+        )
     else:
         return None

@@ -997,7 +1009,7 @@ async def summary_index_name(self,
     ):
         """
         Generate a summary for the specified index using advanced Map-Reduce approach
-
+
         New implementation:
         1. Get documents and cluster them by semantic similarity
         2. Map: Summarize each document individually
@@ -1019,17 +1031,17 @@ async def summary_index_name(self,
         try:
             if not tenant_id:
                 raise Exception("Tenant ID is required for summary generation.")
-
+
             from utils.document_vector_utils import (
                 process_documents_for_clustering,
                 kmeans_cluster_documents,
                 summarize_clusters_map_reduce,
                 merge_cluster_summaries
             )
-
+
             # Use new Map-Reduce approach
             sample_count = min(batch_size // 5, 200)  # Sample reasonable number of documents
-
+
             # Define a helper function to run all blocking operations in a thread pool
             def _generate_summary_sync():
                 """Synchronous function that performs all blocking operations"""
@@ -1039,13 +1051,13 @@ def _generate_summary_sync():
                     vdb_core=vdb_core,
                     sample_doc_count=sample_count
                 )
-
+
                 if not document_samples:
                     raise Exception("No documents found in index.")
-
+
                 # Step 2: Cluster documents (CPU-intensive operation)
                 clusters = kmeans_cluster_documents(doc_embeddings, k=None)
-
+
                 # Step 3: Map-Reduce summarization (contains blocking LLM calls)
                 cluster_summaries = summarize_clusters_map_reduce(
                     document_samples=document_samples,
@@ -1056,11 +1068,11 @@ def _generate_summary_sync():
                     model_id=model_id,
                     tenant_id=tenant_id
                 )
-
+
                 # Step 4: Merge into final summary
                 final_summary = merge_cluster_summaries(cluster_summaries)
                 return final_summary
-
+
             # Run blocking operations in a thread pool to avoid blocking the event loop
             # Use get_running_loop() for better compatibility with modern asyncio
             try:
@@ -1069,7 +1081,7 @@ def _generate_summary_sync():
                 # Fallback for edge cases
                 loop = asyncio.get_event_loop()
             final_summary = await loop.run_in_executor(None, _generate_summary_sync)
-
+
             # Stream the result
             async def generate_summary():
                 try:
@@ -1080,12 +1092,12 @@ async def generate_summary():
                         yield "data: {\"status\": \"completed\"}\n\n"
                 except Exception as e:
                     yield f"data: {{\"status\": \"error\", \"message\": \"{e}\"}}\n\n"
-
+
             return StreamingResponse(
                 generate_summary(),
                 media_type="text/event-stream"
             )
-
+
         except Exception as e:
             logger.error(f"Knowledge base summary generation failed: {str(e)}", exc_info=True)
             raise Exception(f"Failed to generate summary: {str(e)}")

backend/utils/attachment_utils.py

Lines changed: 4 additions & 2 deletions
@@ -34,7 +34,8 @@ def convert_image_to_text(query: str, image_input: Union[str, BinaryIO], tenant_id: str
         temperature=0.7,
         top_p=0.7,
         frequency_penalty=0.5,
-        max_tokens=512
+        max_tokens=512,
+        ssl_verify=vlm_model_config.get("ssl_verify", True),
     )

     # Load prompts from yaml file
@@ -65,7 +66,8 @@ def convert_long_text_to_text(query: str, file_context: str, tenant_id: str, language
         model_id=get_model_name_from_config(main_model_config),
         api_base=main_model_config.get("base_url"),
         api_key=main_model_config.get("api_key"),
-        max_context_tokens=main_model_config.get("max_tokens")
+        max_context_tokens=main_model_config.get("max_tokens"),
+        ssl_verify=main_model_config.get("ssl_verify", True),
     )

     # Load prompts from yaml file

backend/utils/llm_utils.py

Lines changed: 4 additions & 2 deletions
@@ -71,6 +71,8 @@ def call_llm_for_system_prompt(
         api_key=llm_model_config.get("api_key", "") if llm_model_config else "",
         temperature=0.3,
         top_p=0.95,
+        model_factory=llm_model_config.get("model_factory") if llm_model_config else None,
+        ssl_verify=llm_model_config.get("ssl_verify", True) if llm_model_config else True,
     )
     messages = [
         {"role": MESSAGE_ROLE["SYSTEM"], "content": system_prompt},
@@ -108,15 +110,15 @@ def call_llm_for_system_prompt(
             token_join,
             callback,
         )
-
+
         result = "".join(token_join)
         if not result and content_tokens_seen > 0:
             logger.warning(
                 "Generated prompt is empty but %d content tokens were processed. "
                 "This suggests all content was filtered out.",
                 content_tokens_seen
             )
-
+
         return result
     except Exception as exc:
         logger.error("Failed to generate prompt from LLM: %s", str(exc))

backend/utils/str_utils.py

Lines changed: 1 addition & 1 deletion
@@ -5,4 +5,4 @@ def remove_think_blocks(text: str) -> str:
     """Remove <think>...</think> blocks including inner content."""
     if not text:
         return text
-    return re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL | re.IGNORECASE)
+    return re.sub(r"(?:<think>)?.*?</think>", "", text, flags=re.DOTALL | re.IGNORECASE)
