Skip to content

Commit 3aad19d

Browse files
authored
♻️ All models should use nexent.core.models instead of smolagents #1971
2 parents 3f39432 + 96251c4 commit 3aad19d

13 files changed

+236
-70
lines changed

backend/services/conversation_management_service.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from typing import Any, Dict, List, Optional
66

77
from jinja2 import StrictUndefined, Template
8-
from smolagents import OpenAIServerModel
98

109
from consts.const import LANGUAGE, MODEL_CONFIG_MAPPING, MESSAGE_ROLE, DEFAULT_EN_TITLE, DEFAULT_ZH_TITLE
1110
from consts.model import AgentRequest, ConversationResponse, MessageRequest, MessageUnit
@@ -27,7 +26,8 @@
2726
rename_conversation,
2827
update_message_opinion
2928
)
30-
from nexent.core.utils.observer import ProcessType
29+
from nexent.core.utils.observer import MessageObserver, ProcessType
30+
from nexent.core.models import OpenAIModel
3131
from utils.config_utils import get_model_name_from_config, tenant_config_manager
3232
from utils.prompt_template_utils import get_generate_title_prompt_template
3333
from utils.str_utils import remove_think_blocks
@@ -262,8 +262,8 @@ def call_llm_for_title(content: str, tenant_id: str, language: str = LANGUAGE["Z
262262
model_config = tenant_config_manager.get_model_config(
263263
key=MODEL_CONFIG_MAPPING["llm"], tenant_id=tenant_id)
264264

265-
# Create OpenAIServerModel instance
266-
llm = OpenAIServerModel(
265+
# Create OpenAIModel instance
266+
llm = OpenAIModel(
267267
model_id=get_model_name_from_config(model_config) if model_config.get("model_name") else "",
268268
api_base=model_config.get("base_url", ""),
269269
api_key=model_config.get("api_key", ""),

backend/services/vectordatabase_service.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -760,16 +760,16 @@ async def summary_index_name(self,
760760
StreamingResponse containing the generated summary
761761
"""
762762
try:
763+
if not tenant_id:
764+
raise Exception("Tenant ID is required for summary generation.")
765+
763766
from utils.document_vector_utils import (
764767
process_documents_for_clustering,
765768
kmeans_cluster_documents,
766769
summarize_clusters_map_reduce,
767770
merge_cluster_summaries
768771
)
769772

770-
if not tenant_id:
771-
raise Exception("Tenant ID is required for summary generation.")
772-
773773
# Use new Map-Reduce approach
774774
sample_count = min(batch_size // 5, 200) # Sample reasonable number of documents
775775

backend/utils/document_vector_utils.py

Lines changed: 23 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,16 @@
1414

1515
import numpy as np
1616
from jinja2 import Template, StrictUndefined
17-
from nexent.vector_database.base import VectorDatabaseCore
1817
from sklearn.cluster import KMeans
1918
from sklearn.metrics import silhouette_score
2019
from sklearn.metrics.pairwise import cosine_similarity
2120

2221
from consts.const import LANGUAGE
22+
from database.model_management_db import get_model_by_model_id
23+
from nexent.core.utils.observer import MessageObserver
24+
from nexent.core.models import OpenAIModel
25+
from nexent.vector_database.base import VectorDatabaseCore
26+
from utils.llm_utils import call_llm_for_system_prompt
2327
from utils.prompt_template_utils import (
2428
get_document_summary_prompt_template,
2529
get_cluster_summary_reduce_prompt_template,
@@ -568,37 +572,22 @@ def summarize_document(document_content: str, filename: str, language: str = LAN
568572

569573
# Call LLM if model_id and tenant_id are provided
570574
if model_id and tenant_id:
571-
from smolagents import OpenAIServerModel
572-
from database.model_management_db import get_model_by_model_id
573-
from utils.config_utils import get_model_name_from_config
574-
from consts.const import MESSAGE_ROLE
575-
575+
576576
# Get model configuration
577577
llm_model_config = get_model_by_model_id(model_id=model_id, tenant_id=tenant_id)
578578
if not llm_model_config:
579579
logger.warning(f"No model configuration found for model_id: {model_id}, tenant_id: {tenant_id}")
580580
return f"[Document Summary: {filename}] (max {max_words} words) - Content: {document_content[:200]}..."
581-
582-
# Create LLM instance
583-
llm = OpenAIServerModel(
584-
model_id=get_model_name_from_config(llm_model_config) if llm_model_config else "",
585-
api_base=llm_model_config.get("base_url", ""),
586-
api_key=llm_model_config.get("api_key", ""),
587-
temperature=0.3,
588-
top_p=0.95
581+
582+
document_summary = call_llm_for_system_prompt(
583+
model_id=model_id,
584+
user_prompt=user_prompt,
585+
system_prompt=system_prompt,
586+
callback=None,
587+
tenant_id=tenant_id
589588
)
590-
591-
# Build messages
592-
messages = [
593-
{"role": MESSAGE_ROLE["SYSTEM"], "content": system_prompt},
594-
{"role": MESSAGE_ROLE["USER"], "content": user_prompt}
595-
]
596-
597-
# Call LLM, allow more tokens for generation
598-
response = llm(messages, max_tokens=max_words * 2)
599-
if not response or not response.content:
600-
return ""
601-
return response.content.strip()
589+
590+
return (document_summary or "").strip()
602591
else:
603592
# Fallback to placeholder if no model configuration
604593
logger.warning("No model_id or tenant_id provided, using placeholder summary")
@@ -642,10 +631,6 @@ def summarize_cluster(document_summaries: List[str], language: str = LANGUAGE["Z
642631

643632
# Call LLM if model_id and tenant_id are provided
644633
if model_id and tenant_id:
645-
from smolagents import OpenAIServerModel
646-
from database.model_management_db import get_model_by_model_id
647-
from utils.config_utils import get_model_name_from_config
648-
from consts.const import MESSAGE_ROLE
649634

650635
# Get model configuration
651636
llm_model_config = get_model_by_model_id(model_id=model_id, tenant_id=tenant_id)
@@ -654,25 +639,15 @@ def summarize_cluster(document_summaries: List[str], language: str = LANGUAGE["Z
654639
return f"[Cluster Summary] (max {max_words} words) - Based on {len(document_summaries)} documents"
655640

656641
# Generate the cluster summary via call_llm_for_system_prompt
657-
llm = OpenAIServerModel(
658-
model_id=get_model_name_from_config(llm_model_config) if llm_model_config else "",
659-
api_base=llm_model_config.get("base_url", ""),
660-
api_key=llm_model_config.get("api_key", ""),
661-
temperature=0.3,
662-
top_p=0.95
642+
cluster_summary = call_llm_for_system_prompt(
643+
model_id=model_id,
644+
user_prompt=user_prompt,
645+
system_prompt=system_prompt,
646+
callback=None,
647+
tenant_id=tenant_id
663648
)
664-
665-
# Build messages
666-
messages = [
667-
{"role": MESSAGE_ROLE["SYSTEM"], "content": system_prompt},
668-
{"role": MESSAGE_ROLE["USER"], "content": user_prompt}
669-
]
670-
671-
# Call LLM
672-
response = llm(messages, max_tokens=max_words * 2) # Allow more tokens for generation
673-
if not response or not response.content:
674-
return ""
675-
return response.content.strip()
649+
650+
return (cluster_summary or "").strip()
676651
else:
677652
# Fallback to placeholder if no model configuration
678653
logger.warning("No model_id or tenant_id provided, using placeholder summary")

backend/utils/llm_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
import logging
22
from typing import Callable, List, Optional
33

4-
from smolagents import OpenAIServerModel
5-
64
from consts.const import MESSAGE_ROLE, THINK_END_PATTERN, THINK_START_PATTERN
75
from database.model_management_db import get_model_by_model_id
6+
from nexent.core.utils.observer import MessageObserver
7+
from nexent.core.models import OpenAIModel
88
from utils.config_utils import get_model_name_from_config
99

1010
logger = logging.getLogger("llm_utils")
@@ -44,7 +44,7 @@ def call_llm_for_system_prompt(
4444
"""
4545
llm_model_config = get_model_by_model_id(model_id=model_id, tenant_id=tenant_id)
4646

47-
llm = OpenAIServerModel(
47+
llm = OpenAIModel(
4848
model_id=get_model_name_from_config(llm_model_config) if llm_model_config else "",
4949
api_base=llm_model_config.get("base_url", ""),
5050
api_key=llm_model_config.get("api_key", ""),

sdk/nexent/core/models/openai_llm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
logger = logging.getLogger("openai_llm")
1515

1616
class OpenAIModel(OpenAIServerModel):
17-
def __init__(self, observer: MessageObserver, temperature=0.2, top_p=0.95,
17+
def __init__(self, observer: MessageObserver = MessageObserver, temperature=0.2, top_p=0.95,
1818
ssl_verify=True, *args, **kwargs):
1919
"""
2020
Initialize OpenAI Model with observer and SSL verification option.

test/backend/services/test_conversation_management_service.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,7 @@ def test_extract_user_messages(self):
327327
self.assertIn("Give me examples of AI applications", result)
328328
self.assertIn("AI stands for Artificial Intelligence.", result)
329329

330-
@patch('backend.services.conversation_management_service.OpenAIServerModel')
330+
@patch('backend.services.conversation_management_service.OpenAIModel')
331331
@patch('backend.services.conversation_management_service.get_generate_title_prompt_template')
332332
@patch('backend.services.conversation_management_service.tenant_config_manager.get_model_config')
333333
def test_call_llm_for_title(self, mock_get_model_config, mock_get_prompt_template, mock_openai):
@@ -360,7 +360,7 @@ def test_call_llm_for_title(self, mock_get_model_config, mock_get_prompt_templat
360360
mock_llm_instance.generate.assert_called_once()
361361
mock_get_prompt_template.assert_called_once_with(language='zh')
362362

363-
@patch('backend.services.conversation_management_service.OpenAIServerModel')
363+
@patch('backend.services.conversation_management_service.OpenAIModel')
364364
@patch('backend.services.conversation_management_service.get_generate_title_prompt_template')
365365
@patch('backend.services.conversation_management_service.tenant_config_manager.get_model_config')
366366
def test_call_llm_for_title_response_none_zh(self, mock_get_model_config, mock_get_prompt_template, mock_openai):
@@ -392,7 +392,7 @@ def test_call_llm_for_title_response_none_zh(self, mock_get_model_config, mock_g
392392
mock_llm_instance.generate.assert_called_once()
393393
mock_get_prompt_template.assert_called_once_with(language='zh')
394394

395-
@patch('backend.services.conversation_management_service.OpenAIServerModel')
395+
@patch('backend.services.conversation_management_service.OpenAIModel')
396396
@patch('backend.services.conversation_management_service.get_generate_title_prompt_template')
397397
@patch('backend.services.conversation_management_service.tenant_config_manager.get_model_config')
398398
def test_call_llm_for_title_response_none_en(self, mock_get_model_config, mock_get_prompt_template, mock_openai):

test/backend/services/test_vectordatabase_service.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,19 @@ def _create_package_mock(name: str) -> MagicMock:
3535
sys.modules['nexent.core'] = _create_package_mock('nexent.core')
3636
sys.modules['nexent.core.agents'] = _create_package_mock('nexent.core.agents')
3737
sys.modules['nexent.core.agents.agent_model'] = MagicMock()
38-
sys.modules['nexent.core.models'] = _create_package_mock('nexent.core.models')
38+
# Mock nexent.core.models with OpenAIModel
39+
openai_model_module = ModuleType('nexent.core.models')
40+
openai_model_module.OpenAIModel = MagicMock
41+
sys.modules['nexent.core.models'] = openai_model_module
3942
sys.modules['nexent.core.models.embedding_model'] = MagicMock()
4043
sys.modules['nexent.core.models.stt_model'] = MagicMock()
4144
sys.modules['nexent.core.nlp'] = _create_package_mock('nexent.core.nlp')
4245
sys.modules['nexent.core.nlp.tokenizer'] = MagicMock()
46+
# Mock nexent.core.utils and observer module
47+
sys.modules['nexent.core.utils'] = _create_package_mock('nexent.core.utils')
48+
observer_module = ModuleType('nexent.core.utils.observer')
49+
observer_module.MessageObserver = MagicMock
50+
sys.modules['nexent.core.utils.observer'] = observer_module
4351
sys.modules['nexent.vector_database'] = _create_package_mock('nexent.vector_database')
4452
vector_db_base_module = ModuleType('nexent.vector_database.base')
4553

@@ -96,6 +104,8 @@ class _VectorDatabaseCore:
96104
# Apply the patches before importing the module being tested
97105
with patch('botocore.client.BaseClient._make_api_call'), \
98106
patch('elasticsearch.Elasticsearch', return_value=MagicMock()):
107+
# Import utils.document_vector_utils to ensure it's available for patching
108+
import utils.document_vector_utils
99109
from backend.services.vectordatabase_service import ElasticSearchService, check_knowledge_base_exist_impl
100110

101111

test/backend/test_cluster_summarization.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,39 @@
1010
import numpy as np
1111
import pytest
1212

13-
# Add backend to path
13+
# Mock consts module before patching backend.database.client to avoid ImportError
14+
# backend.database.client imports from consts.const, so we need to mock it first
15+
consts_mock = MagicMock()
16+
consts_const_mock = MagicMock()
17+
# Set required constants that backend.database.client might use
18+
consts_const_mock.MINIO_ENDPOINT = "http://localhost:9000"
19+
consts_const_mock.MINIO_ACCESS_KEY = "test_access_key"
20+
consts_const_mock.MINIO_SECRET_KEY = "test_secret_key"
21+
consts_const_mock.MINIO_REGION = "us-east-1"
22+
consts_const_mock.MINIO_DEFAULT_BUCKET = "test-bucket"
23+
consts_const_mock.POSTGRES_HOST = "localhost"
24+
consts_const_mock.POSTGRES_USER = "test_user"
25+
consts_const_mock.NEXENT_POSTGRES_PASSWORD = "test_password"
26+
consts_const_mock.POSTGRES_DB = "test_db"
27+
consts_const_mock.POSTGRES_PORT = 5432
28+
consts_const_mock.LANGUAGE = {"ZH": "zh", "EN": "en"}
29+
consts_mock.const = consts_const_mock
30+
sys.modules['consts'] = consts_mock
31+
sys.modules['consts.const'] = consts_const_mock
32+
33+
# Add backend to path before patching backend modules
1434
current_dir = os.path.dirname(os.path.abspath(__file__))
1535
backend_dir = os.path.abspath(os.path.join(current_dir, "../../backend"))
1636
sys.path.insert(0, backend_dir)
1737

38+
# Patch storage factory and MinIO config validation to avoid errors during initialization
39+
# These patches must be started before any imports that use MinioClient
40+
storage_client_mock = MagicMock()
41+
minio_client_mock = MagicMock()
42+
patch('nexent.storage.storage_client_factory.create_storage_client_from_config', return_value=storage_client_mock).start()
43+
patch('nexent.storage.minio_config.MinIOStorageConfig.validate', lambda self: None).start()
44+
patch('backend.database.client.MinioClient', return_value=minio_client_mock).start()
45+
1846
from backend.utils.document_vector_utils import (
1947
extract_cluster_content,
2048
summarize_cluster,

test/backend/test_document_vector_integration.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,38 @@
1111
import numpy as np
1212
import pytest
1313

14-
# Add backend to path
14+
# Mock consts module before patching backend.database.client to avoid ImportError
15+
# backend.database.client imports from consts.const, so we need to mock it first
16+
consts_mock = MagicMock()
17+
consts_const_mock = MagicMock()
18+
# Set required constants that backend.database.client might use
19+
consts_const_mock.MINIO_ENDPOINT = "http://localhost:9000"
20+
consts_const_mock.MINIO_ACCESS_KEY = "test_access_key"
21+
consts_const_mock.MINIO_SECRET_KEY = "test_secret_key"
22+
consts_const_mock.MINIO_REGION = "us-east-1"
23+
consts_const_mock.MINIO_DEFAULT_BUCKET = "test-bucket"
24+
consts_const_mock.POSTGRES_HOST = "localhost"
25+
consts_const_mock.POSTGRES_USER = "test_user"
26+
consts_const_mock.NEXENT_POSTGRES_PASSWORD = "test_password"
27+
consts_const_mock.POSTGRES_DB = "test_db"
28+
consts_const_mock.POSTGRES_PORT = 5432
29+
consts_mock.const = consts_const_mock
30+
sys.modules['consts'] = consts_mock
31+
sys.modules['consts.const'] = consts_const_mock
32+
33+
# Add backend to path before patching backend modules
1534
current_dir = os.path.dirname(os.path.abspath(__file__))
1635
backend_dir = os.path.abspath(os.path.join(current_dir, "../../backend"))
1736
sys.path.insert(0, backend_dir)
1837

38+
# Patch storage factory and MinIO config validation to avoid errors during initialization
39+
# These patches must be started before any imports that use MinioClient
40+
storage_client_mock = MagicMock()
41+
minio_client_mock = MagicMock()
42+
patch('nexent.storage.storage_client_factory.create_storage_client_from_config', return_value=storage_client_mock).start()
43+
patch('nexent.storage.minio_config.MinIOStorageConfig.validate', lambda self: None).start()
44+
patch('backend.database.client.MinioClient', return_value=minio_client_mock).start()
45+
1946
from backend.utils.document_vector_utils import (
2047
calculate_document_embedding,
2148
auto_determine_k,

0 commit comments

Comments
 (0)