Skip to content

Commit 3fb447c

Browse files
authored
✨ Support select model in multi cases
2 parents 574c1ca + 75acbb7 commit 3fb447c

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

47 files changed

+2664
-3735
lines changed

backend/agents/create_agent_info.py

Lines changed: 29 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
from services.memory_config_service import build_memory_context
1616
from database.agent_db import search_agent_info_by_agent_id, query_sub_agents_id_list
1717
from database.tool_db import search_tools_for_sub_agent
18+
from database.model_management_db import get_model_records
19+
from utils.model_name_utils import add_repo_to_name
1820
from utils.prompt_template_utils import get_agent_prompt_template
1921
from utils.config_utils import tenant_config_manager, get_model_name_from_config
2022
from consts.const import LOCAL_MCP_SERVER, MODEL_CONFIG_MAPPING, LANGUAGE
@@ -24,21 +26,34 @@
2426

2527

2628
async def create_model_config_list(tenant_id):
29+
records = get_model_records({"model_type": "llm"}, tenant_id)
30+
model_list = []
31+
for record in records:
32+
model_list.append(
33+
ModelConfig(cite_name=record["display_name"],
34+
api_key=record.get("api_key", ""),
35+
model_name=add_repo_to_name(
36+
model_repo=record["model_repo"],
37+
model_name=record["model_name"],
38+
),
39+
url=record["base_url"]))
40+
# fit for old version, main_model and sub_model use default model
2741
main_model_config = tenant_config_manager.get_model_config(
2842
key=MODEL_CONFIG_MAPPING["llm"], tenant_id=tenant_id)
29-
sub_model_config = tenant_config_manager.get_model_config(
30-
key=MODEL_CONFIG_MAPPING["llmSecondary"], tenant_id=tenant_id)
31-
32-
return [ModelConfig(cite_name="main_model",
33-
api_key=main_model_config.get("api_key", ""),
34-
model_name=get_model_name_from_config(main_model_config) if main_model_config.get(
35-
"model_name") else "",
36-
url=main_model_config.get("base_url", "")),
37-
ModelConfig(cite_name="sub_model",
38-
api_key=sub_model_config.get("api_key", ""),
39-
model_name=get_model_name_from_config(sub_model_config) if sub_model_config.get(
40-
"model_name") else "",
41-
url=sub_model_config.get("base_url", ""))]
43+
model_list.append(
44+
ModelConfig(cite_name="main_model",
45+
api_key=main_model_config.get("api_key", ""),
46+
model_name=get_model_name_from_config(main_model_config) if main_model_config.get(
47+
"model_name") else "",
48+
url=main_model_config.get("base_url", "")))
49+
model_list.append(
50+
ModelConfig(cite_name="sub_model",
51+
api_key=main_model_config.get("api_key", ""),
52+
model_name=get_model_name_from_config(main_model_config) if main_model_config.get(
53+
"model_name") else "",
54+
url=main_model_config.get("base_url", "")))
55+
56+
return model_list
4257

4358

4459
async def create_agent_config(
@@ -336,8 +351,7 @@ async def create_agent_run_info(
336351
"remote_mcp_server": default_mcp_url,
337352
"status": True
338353
})
339-
remote_mcp_dict = {record["remote_mcp_server_name"]
340-
: record for record in remote_mcp_list if record["status"]}
354+
remote_mcp_dict = {record["remote_mcp_server_name"]: record for record in remote_mcp_list if record["status"]}
341355

342356
# Filter MCP servers and tools
343357
mcp_host = filter_mcp_servers_and_tools(agent_config, remote_mcp_dict)

backend/apps/knowledge_summary_app.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ async def auto_summary(
2020
description="Name of the index to get documents from"),
2121
batch_size: int = Query(
2222
1000, description="Number of documents to retrieve per batch"),
23+
model_id: Optional[int] = Query(
24+
None, description="Model ID to use for summary generation"),
2325
es_core: ElasticSearchCore = Depends(get_es_core),
2426
authorization: Optional[str] = Header(None)
2527
):
@@ -34,7 +36,8 @@ async def auto_summary(
3436
batch_size=batch_size,
3537
es_core=es_core,
3638
tenant_id=tenant_id,
37-
language=language
39+
language=language,
40+
model_id=model_id
3841
)
3942
except Exception as e:
4043
logger.error("Knowledge base summary generation failed", exc_info=True)

backend/apps/model_managment_app.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
batch_update_models_for_tenant,
3939
delete_model_for_tenant,
4040
list_models_for_tenant,
41+
list_llm_models_for_tenant,
4142
)
4243
from utils.auth_utils import get_current_user_id
4344

@@ -258,6 +259,22 @@ async def get_model_list(authorization: Optional[str] = Header(None)):
258259
detail="Failed to retrieve model list")
259260

260261

262+
@router.get("/llm_list")
263+
async def get_llm_model_list(authorization: Optional[str] = Header(None)):
264+
"""Get list of LLM models for the current tenant."""
265+
try:
266+
_, tenant_id = get_current_user_id(authorization)
267+
llm_list = await list_llm_models_for_tenant(tenant_id)
268+
return JSONResponse(status_code=HTTPStatus.OK, content={
269+
"message": "Successfully retrieved LLM list",
270+
"data": jsonable_encoder(llm_list)
271+
})
272+
except Exception as e:
273+
logging.error(f"Failed to retrieve LLM list: {str(e)}")
274+
raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
275+
detail="Failed to retrieve LLM list")
276+
277+
261278
@router.post("/healthcheck")
262279
async def check_model_health(
263280
display_name: str = Query(..., description="Display name to check"),

backend/apps/prompt_app.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ async def generate_and_save_system_prompt_api(
2323
authorization, http_request)
2424
return StreamingResponse(gen_system_prompt_streamable(
2525
agent_id=prompt_request.agent_id,
26+
model_id=prompt_request.model_id,
2627
task_description=prompt_request.task_description,
2728
user_id=user_id,
2829
tenant_id=tenant_id,

backend/consts/const.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,6 @@
214214

215215
MODEL_CONFIG_MAPPING = {
216216
"llm": "LLM_ID",
217-
"llmSecondary": "LLM_SECONDARY_ID",
218217
"embedding": "EMBEDDING_ID",
219218
"multiEmbedding": "MULTI_EMBEDDING_ID",
220219
"rerank": "RERANK_ID",

backend/consts/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,6 @@ class SingleModelConfig(BaseModel):
8787

8888
class ModelConfig(BaseModel):
8989
llm: SingleModelConfig
90-
llmSecondary: SingleModelConfig
9190
embedding: SingleModelConfig
9291
multiEmbedding: SingleModelConfig
9392
rerank: SingleModelConfig
@@ -189,6 +188,7 @@ class OpinionRequest(BaseModel):
189188
class GeneratePromptRequest(BaseModel):
190189
task_description: str
191190
agent_id: int
191+
model_id: int
192192

193193

194194
class GenerateTitleRequest(BaseModel):

backend/database/agent_db.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,6 @@ def create_agent(agent_info, tenant_id: str, user_id: str):
7979
"tenant_id": tenant_id,
8080
"created_by": user_id,
8181
"updated_by": user_id,
82-
"model_name": "main_model",
8382
"max_steps": 5
8483
})
8584
with get_db_session() as session:

backend/services/agent_service.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1049,4 +1049,4 @@ def get_sub_agents_recursive(parent_agent_id: int, depth: int = 0, max_depth: in
10491049
except Exception as e:
10501050
logger.exception(
10511051
f"Failed to get agent call relationship for agent {agent_id}: {str(e)}")
1052-
raise ValueError(f"Failed to get agent call relationship: {str(e)}")
1052+
raise ValueError(f"Failed to get agent call relationship: {str(e)}")

backend/services/config_sync_service.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,8 @@ async def load_config_impl(language: str, tenant_id: str):
130130

131131
def build_app_config(language: str, tenant_id: str) -> dict:
132132
default_app_name = DEFAULT_APP_NAME_ZH if language == LANGUAGE["ZH"] else DEFAULT_APP_NAME_EN
133-
default_app_description = DEFAULT_APP_DESCRIPTION_ZH if language == LANGUAGE["ZH"] else DEFAULT_APP_DESCRIPTION_EN
133+
default_app_description = DEFAULT_APP_DESCRIPTION_ZH if language == LANGUAGE[
134+
"ZH"] else DEFAULT_APP_DESCRIPTION_EN
134135

135136
return {
136137
"name": tenant_config_manager.get_app_config(APP_NAME, tenant_id=tenant_id) or default_app_name,

backend/services/elasticsearch_service.py

Lines changed: 66 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
from datetime import datetime, timezone
1717
from typing import Any, Dict, Generator, List, Optional
1818

19-
from dotenv import load_dotenv
2019
from fastapi import Body, Depends, Path, Query
2120
from fastapi.responses import StreamingResponse
2221
from jinja2 import Template, StrictUndefined
@@ -43,7 +42,9 @@
4342
logger = logging.getLogger("elasticsearch_service")
4443

4544

46-
def generate_knowledge_summary_stream(keywords: str, language: str, tenant_id: str) -> Generator:
45+
46+
47+
def generate_knowledge_summary_stream(keywords: str, language: str, tenant_id: str, model_id: Optional[int] = None) -> Generator:
4748
"""
4849
Generate a knowledge base summary based on keywords
4950
@@ -55,9 +56,6 @@ def generate_knowledge_summary_stream(keywords: str, language: str, tenant_id: s
5556
Returns:
5657
str: Generate a knowledge base summary
5758
"""
58-
# Load environment variables
59-
load_dotenv()
60-
6159
# Load prompt words based on language
6260
prompts = get_knowledge_summary_prompt_template(language)
6361

@@ -73,20 +71,47 @@ def generate_knowledge_summary_stream(keywords: str, language: str, tenant_id: s
7371
{"role": MESSAGE_ROLE["USER"], "content": user_prompt}
7472
]
7573

76-
# Get model configuration from tenant config manager
77-
model_config = tenant_config_manager.get_model_config(
78-
key=MODEL_CONFIG_MAPPING["llmSecondary"], tenant_id=tenant_id)
74+
# Get model configuration
75+
if model_id:
76+
try:
77+
from database.model_management_db import get_model_by_model_id
78+
model_info = get_model_by_model_id(model_id, tenant_id)
79+
if model_info:
80+
model_config = {
81+
'api_key': model_info.get('api_key', ''),
82+
'base_url': model_info.get('base_url', ''),
83+
'model_name': model_info.get('model_name', ''),
84+
'model_repo': model_info.get('model_repo', '')
85+
}
86+
else:
87+
# Fallback to default model if specified model not found
88+
logger.warning(f"Specified model {model_id} not found, falling back to default LLM.")
89+
model_config = tenant_config_manager.get_model_config(
90+
key=MODEL_CONFIG_MAPPING["llm"], tenant_id=tenant_id)
91+
except Exception as e:
92+
logger.warning(f"Failed to get model {model_id}, using default model: {e}")
93+
model_config = tenant_config_manager.get_model_config(
94+
key=MODEL_CONFIG_MAPPING["llm"], tenant_id=tenant_id)
95+
else:
96+
# Use default model configuration
97+
model_config = tenant_config_manager.get_model_config(
98+
key=MODEL_CONFIG_MAPPING["llm"], tenant_id=tenant_id)
7999

80100
# initialize OpenAI client
81101
client = OpenAI(api_key=model_config.get('api_key', ""),
82102
base_url=model_config.get('base_url', ""))
83103

84104
try:
85105
# Create stream chat completion request
86-
max_tokens = KNOWLEDGE_SUMMARY_MAX_TOKENS_ZH if language == LANGUAGE["ZH"] else KNOWLEDGE_SUMMARY_MAX_TOKENS_EN
106+
max_tokens = KNOWLEDGE_SUMMARY_MAX_TOKENS_ZH if language == LANGUAGE[
107+
"ZH"] else KNOWLEDGE_SUMMARY_MAX_TOKENS_EN
108+
# Get model name for the request
109+
model_name_for_request = model_config.get("model_name", "")
110+
if model_config.get("model_repo"):
111+
model_name_for_request = f"{model_config['model_repo']}/{model_name_for_request}"
112+
87113
stream = client.chat.completions.create(
88-
model=get_model_name_from_config(model_config) if model_config.get(
89-
"model_name") else "", # use model name from config
114+
model=model_name_for_request,
90115
messages=messages,
91116
max_tokens=max_tokens, # add max_tokens limit
92117
stream=True # enable stream output
@@ -385,7 +410,8 @@ async def delete_index(
385410
}
386411
success = delete_knowledge_record(update_data)
387412
if not success:
388-
raise Exception(f"Error deleting knowledge record for index {index_name}")
413+
raise Exception(
414+
f"Error deleting knowledge record for index {index_name}")
389415

390416
return {"status": "success", "message": f"Index {index_name} and associated files deleted successfully"}
391417
except Exception as e:
@@ -397,8 +423,10 @@ def list_indices(
397423
"*", description="Pattern to match index names"),
398424
include_stats: bool = Query(
399425
False, description="Whether to include index stats"),
400-
tenant_id: str = Body(description="ID of the tenant listing the knowledge base"),
401-
user_id: str = Body(description="ID of the user listing the knowledge base"),
426+
tenant_id: str = Body(
427+
description="ID of the tenant listing the knowledge base"),
428+
user_id: str = Body(
429+
description="ID of the user listing the knowledge base"),
402430
es_core: ElasticSearchCore = Depends(get_es_core)
403431
):
404432
"""
@@ -424,7 +452,8 @@ def list_indices(
424452
for record in db_record:
425453
# async PG database to sync ES, remove the data that is not in ES
426454
if record["index_name"] not in all_indices_list:
427-
delete_knowledge_record({"index_name": record["index_name"], "user_id": user_id})
455+
delete_knowledge_record(
456+
{"index_name": record["index_name"], "user_id": user_id})
428457
continue
429458
if record["embedding_model_name"] is None:
430459
model_name_is_none_list.append(record["index_name"])
@@ -449,8 +478,9 @@ def list_indices(
449478
"stats": index_stats
450479
})
451480
if index_name in model_name_is_none_list:
452-
update_model_name_by_index_name(index_name,
453-
index_stats.get("base_info", {}).get("embedding_model", ""),
481+
update_model_name_by_index_name(index_name,
482+
index_stats.get("base_info", {}).get(
483+
"embedding_model", ""),
454484
tenant_id, user_id)
455485
response["indices_info"] = stats_info
456486

@@ -514,11 +544,14 @@ def get_index_name(
514544
error_msg = str(e)
515545
# Check if it's an ElasticSearch connection issue
516546
if "503" in error_msg or "search_phase_execution_exception" in error_msg:
517-
raise Exception(f"ElasticSearch service unavailable for index {index_name}: {error_msg}")
547+
raise Exception(
548+
f"ElasticSearch service unavailable for index {index_name}: {error_msg}")
518549
elif "ApiError" in error_msg:
519-
raise Exception(f"ElasticSearch API error for index {index_name}: {error_msg}")
550+
raise Exception(
551+
f"ElasticSearch API error for index {index_name}: {error_msg}")
520552
else:
521-
raise Exception(f"Error getting info for index {index_name}: {error_msg}")
553+
raise Exception(
554+
f"Error getting info for index {index_name}: {error_msg}")
522555

523556
@staticmethod
524557
def index_documents(
@@ -551,7 +584,8 @@ def index_documents(
551584
index_name, es_core=es_core)
552585
logger.info(f"Created new index {index_name}")
553586
except Exception as create_error:
554-
raise Exception(f"Failed to create index {index_name}: {str(create_error)}")
587+
raise Exception(
588+
f"Failed to create index {index_name}: {str(create_error)}")
555589

556590
# Transform indexing request results to documents
557591
documents = []
@@ -783,7 +817,8 @@ async def list_files(
783817
return {"files": files}
784818

785819
except Exception as e:
786-
raise Exception(f"Error getting file list for index {index_name}: {str(e)}")
820+
raise Exception(
821+
f"Error getting file list for index {index_name}: {str(e)}")
787822

788823
@staticmethod
789824
def delete_documents(
@@ -828,9 +863,12 @@ async def summary_index_name(self,
828863
1000, description="Number of documents to retrieve per batch"),
829864
es_core: ElasticSearchCore = Depends(
830865
get_es_core),
866+
user_id: Optional[str] = Body(
867+
None, description="ID of the user delete the knowledge base"),
831868
tenant_id: Optional[str] = Body(
832869
None, description="ID of the tenant"),
833-
language: str = LANGUAGE["ZH"]
870+
language: str = LANGUAGE["ZH"],
871+
model_id: Optional[int] = None
834872
):
835873
"""
836874
Generate a summary for the specified index based on its content
@@ -848,7 +886,8 @@ async def summary_index_name(self,
848886
try:
849887
# Get all documents
850888
if not tenant_id:
851-
raise Exception("Tenant ID is required for summary generation.")
889+
raise Exception(
890+
"Tenant ID is required for summary generation.")
852891
all_documents = ElasticSearchService.get_random_documents(
853892
index_name, batch_size, es_core)
854893
all_chunks = self._clean_chunks_for_summary(all_documents)
@@ -860,7 +899,7 @@ async def summary_index_name(self,
860899
async def generate_summary():
861900
token_join = []
862901
try:
863-
for new_token in generate_knowledge_summary_stream(keywords_for_summary, language, tenant_id):
902+
for new_token in generate_knowledge_summary_stream(keywords_for_summary, language, tenant_id, model_id):
864903
if new_token == "END":
865904
break
866905
else:
@@ -947,7 +986,8 @@ def get_random_documents(
947986
}
948987

949988
except Exception as e:
950-
raise Exception(f"Error retrieving random documents from index {index_name}: {str(e)}")
989+
raise Exception(
990+
f"Error retrieving random documents from index {index_name}: {str(e)}")
951991

952992
@staticmethod
953993
def change_summary(

0 commit comments

Comments
 (0)