diff --git a/backend/agents/create_agent_info.py b/backend/agents/create_agent_info.py index 558e2010f..d09029a97 100644 --- a/backend/agents/create_agent_info.py +++ b/backend/agents/create_agent_info.py @@ -150,6 +150,8 @@ async def create_agent_config( tenant_id=tenant_id, user_id=user_id) if knowledge_info_list: for knowledge_info in knowledge_info_list: + if knowledge_info.get('knowledge_sources') != 'elasticsearch': + continue knowledge_name = knowledge_info.get("index_name") try: message = ElasticSearchService().get_summary(index_name=knowledge_name) @@ -239,13 +241,22 @@ async def create_tool_config_list(agent_id, tenant_id, user_id): knowledge_info_list = get_selected_knowledge_list( tenant_id=tenant_id, user_id=user_id) index_names = [knowledge_info.get( - "index_name") for knowledge_info in knowledge_info_list] + "index_name") for knowledge_info in knowledge_info_list if knowledge_info.get('knowledge_sources') == 'elasticsearch'] tool_config.metadata = { "index_names": index_names, "vdb_core": get_vector_db_core(), "embedding_model": get_embedding_model(tenant_id=tenant_id), "name_resolver": build_knowledge_name_mapping(tenant_id=tenant_id, user_id=user_id), } + elif tool_config.class_name == "DataMateSearchTool": + knowledge_info_list = get_selected_knowledge_list( + tenant_id=tenant_id, user_id=user_id) + index_names = [knowledge_info.get( + "index_name") for knowledge_info in knowledge_info_list if + knowledge_info.get('knowledge_sources') == 'datamate'] + tool_config.metadata = { + "index_names": index_names, + } elif tool_config.class_name == "AnalyzeTextFileTool": tool_config.metadata = { "llm_model": get_llm_model(tenant_id=tenant_id), diff --git a/backend/apps/config_app.py b/backend/apps/config_app.py index eb2c824c1..e6977683a 100644 --- a/backend/apps/config_app.py +++ b/backend/apps/config_app.py @@ -6,6 +6,7 @@ from apps.agent_app import agent_config_router as agent_router from apps.config_sync_app import router as config_sync_router +from apps.datamate_app import router as datamate_router from apps.vectordatabase_app import router as vectordatabase_router from apps.file_management_app import file_management_config_router as file_manager_router from apps.image_app import router as proxy_router @@ -42,6 +43,7 @@ app.include_router(config_sync_router) app.include_router(agent_router) app.include_router(vectordatabase_router) +app.include_router(datamate_router) app.include_router(voice_router) app.include_router(file_manager_router) app.include_router(proxy_router) diff --git a/backend/apps/datamate_app.py b/backend/apps/datamate_app.py new file mode 100644 index 000000000..129ba91c5 --- /dev/null +++ b/backend/apps/datamate_app.py @@ -0,0 +1,48 @@ +import logging +from typing import Optional + +from fastapi import APIRouter, Header, HTTPException, Path +from fastapi.responses import JSONResponse +from http import HTTPStatus + +from services.datamate_service import ( + sync_datamate_knowledge_bases_and_create_records, + fetch_datamate_knowledge_base_file_list +) +from utils.auth_utils import get_current_user_id + +router = APIRouter(prefix="/datamate") +logger = logging.getLogger("datamate_app") + + + + +@router.post("/sync_and_create_records") +async def sync_datamate_and_create_records_endpoint( + authorization: Optional[str] = Header(None) +): + """Sync DataMate knowledge bases and create knowledge records in local database.""" + try: + user_id, tenant_id = get_current_user_id(authorization) + + return await sync_datamate_knowledge_bases_and_create_records( + 
tenant_id=tenant_id, + user_id=user_id + ) + except Exception as e: + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail=f"Error syncing DataMate knowledge bases and creating records: {str(e)}") + + +@router.get("/{knowledge_base_id}/files") +async def get_datamate_knowledge_base_files_endpoint( + knowledge_base_id: str = Path(..., description="ID of the DataMate knowledge base"), + authorization: Optional[str] = Header(None) +): + """Get all files from a DataMate knowledge base.""" + try: + result = await fetch_datamate_knowledge_base_file_list(knowledge_base_id) + return JSONResponse(status_code=HTTPStatus.OK, content=result) + except Exception as e: + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail=f"Error fetching DataMate knowledge base files: {str(e)}") diff --git a/backend/apps/tenant_config_app.py b/backend/apps/tenant_config_app.py index e5dfd0481..e2a490e1c 100644 --- a/backend/apps/tenant_config_app.py +++ b/backend/apps/tenant_config_app.py @@ -6,6 +6,7 @@ from fastapi.responses import JSONResponse from consts.const import DEPLOYMENT_VERSION, APP_VERSION +from consts.model import UpdateKnowledgeListRequest from services.tenant_config_service import get_selected_knowledge_list, update_selected_knowledge from utils.auth_utils import get_current_user_id @@ -61,16 +62,37 @@ def load_knowledge_list( @router.post("/update_knowledge_list") def update_knowledge_list( authorization: Optional[str] = Header(None), - knowledge_list: List[str] = Body(None) + request: UpdateKnowledgeListRequest = Body(...) ): try: user_id, tenant_id = get_current_user_id(authorization) + + # Convert grouped request to flat lists + knowledge_list = [] + knowledge_sources = [] + + if request.nexent: + knowledge_list.extend(request.nexent) + knowledge_sources.extend(["nexent"] * len(request.nexent)) + + if request.datamate: + knowledge_list.extend(request.datamate) + knowledge_sources.extend(["datamate"] * len(request.datamate)) + result = update_selected_knowledge( - tenant_id=tenant_id, user_id=user_id, index_name_list=knowledge_list) + tenant_id=tenant_id, user_id=user_id, index_name_list=knowledge_list, knowledge_sources=knowledge_sources) if result: + # 获取更新后的知识库信息 + selected_knowledge_info = get_selected_knowledge_list( + tenant_id=tenant_id, user_id=user_id) + + content = {"selectedKbNames": [item["index_name"] for item in selected_knowledge_info], + "selectedKbModels": [item["embedding_model_name"] for item in selected_knowledge_info], + "selectedKbSources": [item["knowledge_sources"] for item in selected_knowledge_info]} + return JSONResponse( status_code=HTTPStatus.OK, - content={"message": "update success", "status": "success"} + content={"content": content, "message": "update success", "status": "success"} ) else: raise HTTPException( diff --git a/backend/consts/const.py b/backend/consts/const.py index 003e01057..6d73b95d6 100644 --- a/backend/consts/const.py +++ b/backend/consts/const.py @@ -14,6 +14,7 @@ # Vector database providers class VectorDatabaseType(str, Enum): ELASTICSEARCH = "elasticsearch" + DATAMATE = "datamate" # ModelEngine Configuration @@ -28,6 +29,10 @@ class VectorDatabaseType(str, Enum): ES_USERNAME = "elastic" ELASTICSEARCH_SERVICE = os.getenv("ELASTICSEARCH_SERVICE") +# DataMate Configuration +#todo +DATAMATE_BASE_URL = os.getenv("DATAMATE_BASE_URL", "http://1.94.5.242:30000/") + # Data Processing Service Configuration DATA_PROCESS_SERVICE = os.getenv("DATA_PROCESS_SERVICE") diff --git a/backend/consts/model.py 
b/backend/consts/model.py index 1e0f5b5e0..df9191fe3 100644 --- a/backend/consts/model.py +++ b/backend/consts/model.py @@ -448,4 +448,12 @@ class MCPServerConfig(BaseModel): class MCPConfigRequest(BaseModel): """Request model for adding MCP servers from configuration""" mcpServers: Dict[str, MCPServerConfig] = Field( - ..., description="Dictionary of MCP server configurations") \ No newline at end of file + ..., description="Dictionary of MCP server configurations") + + +class UpdateKnowledgeListRequest(BaseModel): + """Request model for updating user's selected knowledge base list grouped by source""" + nexent: Optional[List[str]] = Field( + None, description="List of knowledge base index names from nexent source") + datamate: Optional[List[str]] = Field( + None, description="List of knowledge base index names from datamate source") \ No newline at end of file diff --git a/backend/database/db_models.py b/backend/database/db_models.py index 301dd64aa..51bb0f2e5 100644 --- a/backend/database/db_models.py +++ b/backend/database/db_models.py @@ -1,4 +1,4 @@ -from sqlalchemy import Boolean, Column, Integer, JSON, Numeric, Sequence, String, Text, TIMESTAMP +from sqlalchemy import BigInteger, Boolean, Column, Integer, JSON, Numeric, Sequence, String, Text, TIMESTAMP from sqlalchemy.orm import DeclarativeBase from sqlalchemy.sql import func @@ -243,7 +243,7 @@ class KnowledgeRecord(TableBase): __tablename__ = "knowledge_record_t" __table_args__ = {"schema": "nexent"} - knowledge_id = Column(Integer, Sequence("knowledge_record_t_knowledge_id_seq", schema="nexent"), + knowledge_id = Column(BigInteger, Sequence("knowledge_record_t_knowledge_id_seq", schema="nexent"), primary_key=True, nullable=False, doc="Knowledge base ID, unique primary key") index_name = Column(String(100), doc="Internal Elasticsearch index name") knowledge_name = Column(String(100), doc="User-facing knowledge base name") diff --git a/backend/database/knowledge_db.py b/backend/database/knowledge_db.py index 6faccdafa..3538add59 100644 --- a/backend/database/knowledge_db.py +++ b/backend/database/knowledge_db.py @@ -79,6 +79,59 @@ def create_knowledge_record(query: Dict[str, Any]) -> Dict[str, Any]: raise e +def upsert_knowledge_record(query: Dict[str, Any]) -> Dict[str, Any]: + """ + Create or update a knowledge base record (upsert operation). + If a record with the same index_name and tenant_id exists, update it. + Otherwise, create a new record. 
+ + Args: + query: Dictionary containing knowledge base data, must include: + - index_name: Knowledge base name (used as unique identifier) + - tenant_id: Tenant ID + - knowledge_name: User-facing knowledge base name + - knowledge_describe: Knowledge base description + - knowledge_sources: Knowledge base sources (optional, default 'elasticsearch') + - embedding_model_name: Embedding model name + - user_id: User ID for created_by and updated_by fields + + Returns: + Dict[str, Any]: Dictionary with 'knowledge_id' and 'index_name' + """ + try: + with get_db_session() as session: + # Check if record exists + existing_record = session.query(KnowledgeRecord).filter( + KnowledgeRecord.index_name == query['index_name'], + KnowledgeRecord.tenant_id == query['tenant_id'], + KnowledgeRecord.delete_flag != 'Y' + ).first() + + if existing_record: + # Update existing record + existing_record.knowledge_name = query.get('knowledge_name') or query.get('index_name') + existing_record.knowledge_describe = query.get('knowledge_describe', '') + existing_record.knowledge_sources = query.get('knowledge_sources', 'elasticsearch') + existing_record.embedding_model_name = query.get('embedding_model_name') + existing_record.updated_by = query.get('user_id') + existing_record.update_time = func.current_timestamp() + + session.flush() + session.commit() + return { + "knowledge_id": existing_record.knowledge_id, + "index_name": existing_record.index_name, + "knowledge_name": existing_record.knowledge_name, + } + else: + # Create new record + return create_knowledge_record(query) + + except SQLAlchemyError as e: + session.rollback() + raise e + + def update_knowledge_record(query: Dict[str, Any]) -> bool: """ Update a knowledge base record @@ -230,6 +283,29 @@ def get_knowledge_info_by_tenant_id(tenant_id: str) -> List[Dict[str, Any]]: raise e +def get_knowledge_info_by_tenant_and_source(tenant_id: str, knowledge_sources: str) -> List[Dict[str, Any]]: + """ + Get knowledge base records by tenant ID and knowledge sources. + + Args: + tenant_id: Tenant ID to filter by + knowledge_sources: Knowledge sources to filter by (e.g., 'datamate') + + Returns: + List[Dict[str, Any]]: List of knowledge base record dictionaries + """ + try: + with get_db_session() as session: + result = session.query(KnowledgeRecord).filter( + KnowledgeRecord.tenant_id == tenant_id, + KnowledgeRecord.knowledge_sources == knowledge_sources, + KnowledgeRecord.delete_flag != 'Y' + ).all() + return [as_dict(item) for item in result] + except SQLAlchemyError as e: + raise e + + def update_model_name_by_index_name(index_name: str, embedding_model_name: str, tenant_id: str, user_id: str) -> bool: try: with get_db_session() as session: diff --git a/backend/services/agent_service.py b/backend/services/agent_service.py index aafa38ba6..6849f83a5 100644 --- a/backend/services/agent_service.py +++ b/backend/services/agent_service.py @@ -97,45 +97,45 @@ def _resolve_model_with_fallback( ) -> str | None: """ Resolve model_id from model_display_name with fallback to quick config LLM model. 
- + Args: model_display_name: Display name of the model to lookup exported_model_id: Original model_id from export (for logging only) model_label: Label for logging (e.g., "Model", "Business logic model") tenant_id: Tenant ID for model lookup - + Returns: Resolved model_id or None if not found and no fallback available """ if not model_display_name: return None - + # Try to find model by display name in current tenant resolved_id = get_model_id_by_display_name(model_display_name, tenant_id) - + if resolved_id: logger.info( f"{model_label} '{model_display_name}' found in tenant {tenant_id}, " f"mapped to model_id: {resolved_id} (exported model_id was: {exported_model_id})") return resolved_id - + # Model not found, try fallback to quick config LLM model logger.warning( f"{model_label} '{model_display_name}' (exported model_id: {exported_model_id}) " f"not found in tenant {tenant_id}, falling back to quick config LLM model.") - + quick_config_model = tenant_config_manager.get_model_config( key=MODEL_CONFIG_MAPPING["llm"], tenant_id=tenant_id ) - + if quick_config_model: fallback_id = quick_config_model.get("model_id") logger.info( f"Using quick config LLM model for {model_label.lower()}: " f"{quick_config_model.get('display_name')} (model_id: {fallback_id})") return fallback_id - + logger.warning(f"No quick config LLM model found for tenant {tenant_id}") return None @@ -998,7 +998,7 @@ async def export_agent_by_agent_id(agent_id: int, tenant_id: str, user_id: str) # Check if any tool is KnowledgeBaseSearchTool and set its metadata to empty dict for tool in tool_list: - if tool.class_name in ["KnowledgeBaseSearchTool", "AnalyzeTextFileTool", "AnalyzeImageTool"]: + if tool.class_name in ["KnowledgeBaseSearchTool", "AnalyzeTextFileTool", "AnalyzeImageTool", "DataMateSearchTool"]: tool.metadata = {} # Get model_id and model display name from agent_info @@ -1132,7 +1132,7 @@ async def import_agent_by_agent_id( if not import_agent_info.name.isidentifier(): raise ValueError( f"Invalid agent name: {import_agent_info.name}. agent name must be a valid python variable name.") - + # Resolve model IDs with fallback # Note: We use model_display_name for cross-tenant compatibility # The exported model_id is kept for reference/debugging only @@ -1142,7 +1142,7 @@ async def import_agent_by_agent_id( model_label="Model", tenant_id=tenant_id ) - + business_logic_model_id = _resolve_model_with_fallback( model_display_name=import_agent_info.business_logic_model_name, exported_model_id=import_agent_info.business_logic_model_id, @@ -1344,28 +1344,28 @@ def check_agent_availability( ) -> tuple[bool, list[str]]: """ Check if an agent is available based on its tools and model configuration. 
- + Args: agent_id: The agent ID to check tenant_id: The tenant ID agent_info: Optional pre-fetched agent info (to avoid duplicate DB queries) model_cache: Optional model cache for performance optimization - + Returns: tuple: (is_available: bool, unavailable_reasons: list[str]) """ unavailable_reasons: list[str] = [] - + if model_cache is None: model_cache = {} - + # Fetch agent info if not provided if agent_info is None: agent_info = search_agent_info_by_agent_id(agent_id, tenant_id) - + if not agent_info: return False, ["agent_not_found"] - + # Check tool availability tool_info = search_tools_for_sub_agent(agent_id=agent_id, tenant_id=tenant_id) tool_id_list = [tool["tool_id"] for tool in tool_info if tool.get("tool_id") is not None] @@ -1373,7 +1373,7 @@ def check_agent_availability( tool_statuses = check_tool_is_available(tool_id_list) if not all(tool_statuses): unavailable_reasons.append("tool_unavailable") - + # Check model availability model_reasons = _collect_model_availability_reasons( agent=agent_info, @@ -1381,7 +1381,7 @@ def check_agent_availability( model_cache=model_cache ) unavailable_reasons.extend(model_reasons) - + is_available = len(unavailable_reasons) == 0 return is_available, unavailable_reasons @@ -1935,4 +1935,4 @@ def get_sub_agents_recursive(parent_agent_id: int, depth: int = 0, max_depth: in except Exception as e: logger.exception( f"Failed to get agent call relationship for agent {agent_id}: {str(e)}") - raise ValueError(f"Failed to get agent call relationship: {str(e)}") \ No newline at end of file + raise ValueError(f"Failed to get agent call relationship: {str(e)}") diff --git a/backend/services/datamate_service.py b/backend/services/datamate_service.py new file mode 100644 index 000000000..269d9e2cb --- /dev/null +++ b/backend/services/datamate_service.py @@ -0,0 +1,226 @@ +""" +Service layer for DataMate knowledge base integration. +Handles API calls to DataMate to fetch knowledge bases and their files. + +This service layer uses the DataMate SDK client to interact with DataMate APIs. +""" +import json +import logging +from typing import Dict, List, Optional, Any +import asyncio + +from consts.const import DATAMATE_BASE_URL +from database.knowledge_db import upsert_knowledge_record, get_knowledge_info_by_tenant_and_source, delete_knowledge_record +from nexent.vector_database.datamate_core import DataMateCore + +logger = logging.getLogger("datamate_service") + + +async def _create_datamate_knowledge_records(knowledge_base_ids: List[str], + knowledge_base_names: List[str], + embedding_model_names: List[str], + tenant_id: str, + user_id: str) -> List[Dict[str, Any]]: + """ + Create knowledge records in local database for DataMate knowledge bases. 
+ + Args: + knowledge_base_ids: List of DataMate knowledge base IDs + knowledge_base_names: List of DataMate knowledge base names + embedding_model_names: List of DataMate embedding model names + tenant_id: Tenant ID for the knowledge records + user_id: User ID for the knowledge records + + Returns: + List of created knowledge record dictionaries + """ + created_records = [] + + for i, kb_id in enumerate(knowledge_base_ids): + try: + # Get knowledge base name, fallback to ID if not available + knowledge_name = knowledge_base_names[i] if i < len(knowledge_base_names) else kb_id + + # Create or update knowledge record in local database + record_data = { + "index_name": kb_id, + "knowledge_name": knowledge_name, + "knowledge_describe": f"DataMate knowledge base: {knowledge_name}", + "knowledge_sources": "datamate", # Mark source as datamate + "tenant_id": tenant_id, + "user_id": user_id, + "embedding_model_name": embedding_model_names[i] # Use datamate as embedding model name + } + + # Run synchronous database operation in executor to avoid blocking + loop = asyncio.get_event_loop() + created_record = await loop.run_in_executor( + None, + upsert_knowledge_record, + record_data + ) + + created_records.append(created_record) + logger.info(f"Created knowledge record for DataMate KB '{knowledge_name}': {created_record}") + + except Exception as e: + logger.error(f"Failed to create knowledge record for DataMate KB '{kb_id}': {str(e)}") + # Continue with other knowledge bases even if one fails + continue + + return created_records + + +def _get_datamate_core() -> DataMateCore: + """Get DataMate core instance.""" + return DataMateCore(base_url=DATAMATE_BASE_URL) + + + + +async def fetch_datamate_knowledge_base_files(knowledge_base_id: str) -> List[Dict[str, Any]]: + """ + Fetch file list for a specific DataMate knowledge base. + + Args: + knowledge_base_id: The ID of the knowledge base. + + Returns: + List of file dictionaries with name, status, size, upload_date, etc. + """ + try: + core = _get_datamate_core() + # Run synchronous SDK call in executor to avoid blocking + loop = asyncio.get_event_loop() + result = await loop.run_in_executor( + None, + core.get_index_chunks, + knowledge_base_id + ) + return result["chunks"] + except Exception as e: + logger.error(f"Error fetching files for knowledge base {knowledge_base_id}: {str(e)}") + raise RuntimeError(f"Failed to fetch files for knowledge base {knowledge_base_id}: {str(e)}") + + +async def fetch_datamate_knowledge_base_file_list(knowledge_base_id: str) -> Dict[str, Any]: + """ + Fetch file list for a specific DataMate knowledge base. + + Args: + knowledge_base_id: The ID of the knowledge base. + + Returns: + Dictionary containing file list with status, files array, etc. + """ + try: + core = _get_datamate_core() + # Run synchronous SDK call in executor to avoid blocking + loop = asyncio.get_event_loop() + files = await loop.run_in_executor( + None, + core.get_documents_detail, + knowledge_base_id + ) + + # Transform to match vectordatabase files endpoint format + return { + "status": "success", + "files": files + } + except Exception as e: + logger.error(f"Error fetching file list for knowledge base {knowledge_base_id}: {str(e)}") + raise RuntimeError(f"Failed to fetch file list for knowledge base {knowledge_base_id}: {str(e)}") + + +async def sync_datamate_knowledge_bases_and_create_records(tenant_id: str, user_id: str) -> Dict[str, Any]: + """ + Sync all DataMate knowledge bases and create knowledge records in local database. 
+ + Args: + tenant_id: Tenant ID for creating knowledge records + user_id: User ID for creating knowledge records + + Returns: + Dictionary containing knowledge bases list and created records. + """ + try: + core = _get_datamate_core() + + # Step 1: Get knowledge base id + knowledge_base_ids = core.get_user_indices() + if not knowledge_base_ids: + return { + "indices": [], + "count": 0, + } + + # Step 2: Get detailed information for all knowledge bases + details, knowledge_base_names = core.get_indices_detail(knowledge_base_ids) + + response = { + "indices": knowledge_base_names, + "count": len(knowledge_base_names), + } + + embedding_model_names = [detail['base_info']['embedding_model'] for detail in details.values()] + + # Add indices_info for consistency with list_indices method + indices_info = [] + for i, kb_id in enumerate(knowledge_base_ids): + if kb_id in details: + kb_detail = details[kb_id] + knowledge_base_name = knowledge_base_names[i] if i < len(knowledge_base_names) else kb_id + indices_info.append({ + "name": kb_id, # Internal index name (used as ID) + "display_name": knowledge_base_name, # User-facing knowledge base name + "stats": kb_detail, + }) + response["indices_info"] = indices_info + + # Create knowledge records in local database + created_records = await _create_datamate_knowledge_records( + knowledge_base_ids, knowledge_base_names, embedding_model_names, tenant_id, user_id + ) + + # Step 3: Handle deleted knowledge bases (soft delete) + # Get all existing DataMate records for this tenant + loop = asyncio.get_event_loop() + existing_records = await loop.run_in_executor( + None, + get_knowledge_info_by_tenant_and_source, + tenant_id, + "datamate" + ) + + # Find records that exist in DB but not in API response + existing_index_names = {record['index_name'] for record in existing_records} + api_index_names = set(knowledge_base_ids) + + # Records to delete (exist in DB but not in API) + records_to_delete = existing_index_names - api_index_names + + # Soft delete records that are no longer in DataMate + for index_name in records_to_delete: + try: + delete_result = await loop.run_in_executor( + None, + delete_knowledge_record, + {"index_name": index_name, "user_id": user_id} + ) + if delete_result: + logger.info(f"Soft deleted DataMate knowledge base record: {index_name}") + else: + logger.warning(f"Failed to soft delete DataMate knowledge base record: {index_name}") + except Exception as e: + logger.error(f"Error soft deleting DataMate knowledge base record {index_name}: {str(e)}") + # Continue with other records even if one fails + + return response + except Exception as e: + logger.error(f"Error syncing DataMate knowledge bases and creating records: {str(e)}") + return { + "indices": [], + "count": 0, + } + diff --git a/backend/services/tenant_config_service.py b/backend/services/tenant_config_service.py index 30524677c..c0e4d4afb 100644 --- a/backend/services/tenant_config_service.py +++ b/backend/services/tenant_config_service.py @@ -1,5 +1,5 @@ import logging -from typing import List +from typing import List, Optional from database.knowledge_db import get_knowledge_info_by_knowledge_ids, get_knowledge_ids_by_index_names from database.tenant_config_db import get_tenant_config_info, insert_config, delete_config_by_tenant_config_id @@ -17,7 +17,17 @@ def get_selected_knowledge_list(tenant_id: str, user_id: str): return knowledge_info -def update_selected_knowledge(tenant_id: str, user_id: str, index_name_list: List[str]): +def update_selected_knowledge(tenant_id: str, 
user_id: str, index_name_list: List[str], knowledge_sources: Optional[List[str]] = None): + # Validate that knowledge_sources length matches index_name_list if provided + if knowledge_sources and len(knowledge_sources) != len(index_name_list): + logger.error( + f"Knowledge sources length mismatch: sources={len(knowledge_sources)}, names={len(index_name_list)}") + return False + + logger.info( + f"Updating knowledge list for tenant {tenant_id}, user {user_id}: " + f"names={index_name_list}, sources={knowledge_sources}") + knowledge_ids = get_knowledge_ids_by_index_names(index_name_list) record_list = get_tenant_config_info( tenant_id=tenant_id, user_id=user_id, select_key="selected_knowledge_id") diff --git a/backend/services/tool_configuration_service.py b/backend/services/tool_configuration_service.py index 24ca69ce5..e171f6f9b 100644 --- a/backend/services/tool_configuration_service.py +++ b/backend/services/tool_configuration_service.py @@ -502,7 +502,7 @@ def _validate_local_tool( user_id: User ID for knowledge base tools (optional) Returns: - Dict[str, Any]: The actual result returned by the tool's forward method, + Dict[str, Any]: The actual result returned by the tool's forward method, serving as proof that the tool works correctly Raises: @@ -541,8 +541,7 @@ def _validate_local_tool( raise ToolExecutionException(f"Tenant ID and User ID are required for {tool_name} validation") knowledge_info_list = get_selected_knowledge_list( tenant_id=tenant_id, user_id=user_id) - index_names = [knowledge_info.get("index_name") - for knowledge_info in knowledge_info_list] + index_names = [knowledge_info.get("index_name") for knowledge_info in knowledge_info_list if knowledge_info.get('knowledge_sources') == 'elasticsearch'] name_resolver = build_knowledge_name_mapping( tenant_id=tenant_id, user_id=user_id) @@ -573,6 +572,19 @@ def _validate_local_tool( 'embedding_model': embedding_model, } tool_instance = tool_class(**params) + elif tool_name == "datamate_search_tool": + if not tenant_id or not user_id: + raise ToolExecutionException(f"Tenant ID and User ID are required for {tool_name} validation") + knowledge_info_list = get_selected_knowledge_list( + tenant_id=tenant_id, user_id=user_id) + index_names = [knowledge_info.get("index_name") for knowledge_info in knowledge_info_list if + knowledge_info.get('knowledge_sources') == 'datamate'] + + params = { + **instantiation_params, + 'index_names': index_names, + } + tool_instance = tool_class(**params) elif tool_name == "analyze_image": if not tenant_id or not user_id: raise ToolExecutionException(f"Tenant ID and User ID are required for {tool_name} validation") diff --git a/backend/services/vectordatabase_service.py b/backend/services/vectordatabase_service.py index 18203847e..46196436e 100644 --- a/backend/services/vectordatabase_service.py +++ b/backend/services/vectordatabase_service.py @@ -23,8 +23,9 @@ from nexent.core.models.embedding_model import OpenAICompatibleEmbedding, JinaEmbedding, BaseEmbedding from nexent.vector_database.base import VectorDatabaseCore from nexent.vector_database.elasticsearch_core import ElasticSearchCore +from nexent.vector_database.datamate_core import DataMateCore -from consts.const import ES_API_KEY, ES_HOST, LANGUAGE, VectorDatabaseType +from consts.const import DATAMATE_BASE_URL, ES_API_KEY, ES_HOST, LANGUAGE, VectorDatabaseType from consts.model import ChunkCreateRequest, ChunkUpdateRequest from database.attachment_db import delete_file from database.knowledge_db import ( @@ -107,6 +108,9 @@ def 
get_vector_db_core( ssl_show_warn=False, ) + if db_type == VectorDatabaseType.DATAMATE: + return DataMateCore(base_url=DATAMATE_BASE_URL) + raise ValueError(f"Unsupported vector database type: {db_type}") @@ -507,6 +511,8 @@ def list_indices( filtered_indices_list = [] model_name_is_none_list = [] for record in db_record: + if record['knowledge_sources'] == 'datamate': + continue # async PG database to sync ES, remove the data that is not in ES if record["index_name"] not in all_indices_list: delete_knowledge_record( diff --git a/frontend/app/[locale]/knowledges/KnowledgeBaseConfiguration.tsx b/frontend/app/[locale]/knowledges/KnowledgeBaseConfiguration.tsx index 283730e0f..6f86e5724 100644 --- a/frontend/app/[locale]/knowledges/KnowledgeBaseConfiguration.tsx +++ b/frontend/app/[locale]/knowledges/KnowledgeBaseConfiguration.tsx @@ -134,6 +134,7 @@ function DataConfig({ isActive }: DataConfigProps) { setActiveKnowledgeBase, isKnowledgeBaseSelectable, refreshKnowledgeBaseData, + refreshKnowledgeBaseDataWithDataMate, loadUserSelectedKnowledgeBases, saveUserSelectedKnowledgeBases, dispatch: kbDispatch, @@ -173,7 +174,7 @@ function DataConfig({ isActive }: DataConfigProps) { setIsCreatingMode(false); setHasClickedUpload(false); setActiveKnowledgeBase(knowledgeBase); - fetchDocuments(knowledgeBase.id); + fetchDocuments(knowledgeBase.id, false, knowledgeBase.source); } }; @@ -271,9 +272,18 @@ function DataConfig({ isActive }: DataConfigProps) { // When component unmounts, if previously active and user has interacted, execute save if (prevIsActiveRef.current === true && hasUserInteractedRef.current) { // Use saved state instead of current potentially cleared state - const selectedKbNames = savedKnowledgeBasesRef.current - .filter((kb) => savedSelectedIdsRef.current.includes(kb.id)) - .map((kb) => kb.id); + const selectedKnowledgeBases = savedKnowledgeBasesRef.current + .filter((kb) => savedSelectedIdsRef.current.includes(kb.id)); + + // Group knowledge bases by source + const knowledgeBySource: { nexent?: string[]; datamate?: string[] } = {}; + selectedKnowledgeBases.forEach((kb) => { + const source = kb.source as keyof typeof knowledgeBySource; + if (!knowledgeBySource[source]) { + knowledgeBySource[source] = []; + } + knowledgeBySource[source]!.push(kb.id); + }); try { // Use fetch with keepalive to ensure request can be sent during page unload @@ -283,7 +293,7 @@ function DataConfig({ isActive }: DataConfigProps) { "Content-Type": "application/json", ...getAuthHeaders(), }, - body: JSON.stringify(selectedKbNames), + body: JSON.stringify(knowledgeBySource), keepalive: true, }).catch((error) => { log.error("卸载时保存失败:", error); @@ -371,7 +381,6 @@ function DataConfig({ isActive }: DataConfigProps) { }, [ isActive, kbState.isLoading, - kbState.selectedIds, kbState.knowledgeBases, modelConfig?.embedding?.modelName, modelConfig?.multiEmbedding?.modelName, @@ -439,7 +448,7 @@ function DataConfig({ isActive }: DataConfigProps) { }); // Get latest document data - const documents = await knowledgeBaseService.getAllFiles(kb.id); + const documents = await knowledgeBaseService.getAllFiles(kb.id, kb.source); // Trigger document update event knowledgeBasePollingService.triggerDocumentsUpdate(kb.id, documents); @@ -447,8 +456,8 @@ function DataConfig({ isActive }: DataConfigProps) { // Background update knowledge base statistics, but don't duplicate document fetching setTimeout(async () => { try { - // Directly call fetchKnowledgeBases to update knowledge base list data - await fetchKnowledgeBases(false, true); 
+ // Directly call fetchKnowledgeBases to update knowledge base list data, but don't reload user selections + await fetchKnowledgeBases(false, false); } catch (error) { log.error("获取知识库最新数据失败:", error); } @@ -517,20 +526,43 @@ function DataConfig({ isActive }: DataConfigProps) { }); }; - // Handle knowledge base sync - const handleSync = () => { - // When manually syncing, force fetch latest data from server - refreshKnowledgeBaseData(true) - .then(() => { - message.success(t("knowledgeBase.message.syncSuccess")); - }) - .catch((error) => { - message.error( - t("knowledgeBase.message.syncError", { - error: error.message || t("common.unknownError"), - }) - ); - }); + // Handle knowledge base sync (includes both indices and DataMate sync and create records) + const handleSync = async () => { + // Set sync loading state + kbDispatch({ type: KNOWLEDGE_BASE_ACTION_TYPES.SET_SYNC_LOADING, payload: true }); + + try { + // refreshKnowledgeBaseDataWithDataMate calls syncDataMateAndCreateRecords which syncs DataMate and creates local records + await refreshKnowledgeBaseDataWithDataMate(); + message.success(t("knowledgeBase.message.syncDataMateSuccess")); + } catch (error) { + message.error( + t("knowledgeBase.message.syncDataMateError", { + error: (error as Error)?.message || t("common.unknownError"), + }) + ); + } finally { + // Clear sync loading state + kbDispatch({ type: KNOWLEDGE_BASE_ACTION_TYPES.SET_SYNC_LOADING, payload: false }); + } + }; + + // Handle DataMate configuration + const [showDataMateConfigModal, setShowDataMateConfigModal] = useState(false); + const [dataMateUrl, setDataMateUrl] = useState(""); + + const handleDataMateConfig = () => { + setShowDataMateConfigModal(true); + }; + + const handleDataMateConfigSave = async () => { + try { + // TODO: Implement DataMate URL configuration save logic + message.success(t("knowledgeBase.message.dataMateConfigSaved")); + setShowDataMateConfigModal(false); + } catch (error) { + message.error(t("knowledgeBase.message.dataMateConfigError")); + } }; // Handle new knowledge base creation @@ -856,6 +888,37 @@ function DataConfig({ isActive }: DataConfigProps) { + + {/* DataMate Configuration Modal */} + setShowDataMateConfigModal(false)} + okText={t("common.save")} + cancelText={t("common.cancel")} + centered + getContainer={() => contentRef.current || document.body} + > +
+          <label>{t("knowledgeBase.modal.dataMateConfig.urlLabel")}</label>
+          <input
+            value={dataMateUrl}
+            onChange={(e) => setDataMateUrl(e.target.value)}
+            placeholder={t("knowledgeBase.modal.dataMateConfig.urlPlaceholder")}
+            className="w-full px-3 py-2 border border-gray-300 rounded-md focus:outline-none focus:ring-2 focus:ring-blue-500 focus:border-transparent"
+          />
+          <div>{t("knowledgeBase.modal.dataMateConfig.description")}</div>
+        </div>
+      </Modal>
+ modelId} containerHeight={SETUP_PAGE_CONTAINER.MAIN_CONTENT_HEIGHT} diff --git a/frontend/app/[locale]/knowledges/components/knowledge/KnowledgeBaseList.tsx b/frontend/app/[locale]/knowledges/components/knowledge/KnowledgeBaseList.tsx index d49883d1c..29d4a29f1 100644 --- a/frontend/app/[locale]/knowledges/components/knowledge/KnowledgeBaseList.tsx +++ b/frontend/app/[locale]/knowledges/components/knowledge/KnowledgeBaseList.tsx @@ -2,7 +2,7 @@ import React from 'react' import { useTranslation } from 'react-i18next' import { Button, Checkbox, ConfigProvider } from 'antd' -import { SyncOutlined, PlusOutlined } from '@ant-design/icons' +import { SyncOutlined, PlusOutlined, SettingOutlined } from '@ant-design/icons' import { KnowledgeBase } from '@/types/knowledgeBase' @@ -43,11 +43,13 @@ interface KnowledgeBaseListProps { activeKnowledgeBase: KnowledgeBase | null currentEmbeddingModel: string | null isLoading?: boolean + syncLoading?: boolean onSelect: (id: string) => void onClick: (kb: KnowledgeBase) => void onDelete: (id: string) => void onSync: () => void onCreateNew: () => void + onDataMateConfig?: () => void isSelectable: (kb: KnowledgeBase) => boolean getModelDisplayName: (modelId: string) => string containerHeight?: string // Container total height, consistent with DocumentList @@ -60,11 +62,13 @@ const KnowledgeBaseList: React.FC = ({ activeKnowledgeBase, currentEmbeddingModel, isLoading = false, + syncLoading = false, onSelect, onClick, onDelete, onSync, onCreateNew, + onDataMateConfig, isSelectable, getModelDisplayName, containerHeight = '70vh', // Default container height consistent with DocumentList @@ -162,10 +166,28 @@ const KnowledgeBaseList: React.FC = ({ height: "14px", }} > - + {t("knowledgeBase.button.sync")} + diff --git a/frontend/app/[locale]/knowledges/contexts/DocumentContext.tsx b/frontend/app/[locale]/knowledges/contexts/DocumentContext.tsx index 4e0b33967..b956dd919 100644 --- a/frontend/app/[locale]/knowledges/contexts/DocumentContext.tsx +++ b/frontend/app/[locale]/knowledges/contexts/DocumentContext.tsx @@ -111,7 +111,7 @@ const documentReducer = (state: DocumentState, action: DocumentAction): Document export const DocumentContext = createContext<{ state: DocumentState; dispatch: React.Dispatch; - fetchDocuments: (kbId: string, forceRefresh?: boolean) => Promise; + fetchDocuments: (kbId: string, forceRefresh?: boolean, kbSource?: string) => Promise; uploadDocuments: (kbId: string, files: File[]) => Promise; deleteDocument: (kbId: string, docId: string) => Promise; }>({ @@ -175,23 +175,23 @@ export const DocumentProvider: React.FC = ({ children }) }, []); // Fetch documents for a knowledge base - const fetchDocuments = useCallback(async (kbId: string, forceRefresh?: boolean) => { + const fetchDocuments = useCallback(async (kbId: string, forceRefresh?: boolean, kbSource?: string) => { // Skip if already loading this kb if (state.loadingKbIds.has(kbId)) return; - + // If forceRefresh is false and we have cached data, return directly if (!forceRefresh && state.documentsMap[kbId] && state.documentsMap[kbId].length > 0) { return; // If we have cached data and don't need force refresh, return directly without server request } - + dispatch({ type: DOCUMENT_ACTION_TYPES.SET_LOADING_KB_ID, payload: { kbId, isLoading: true } }); - + try { // Use getAllFiles() to get documents including those not yet in ES - const documents = await knowledgeBaseService.getAllFiles(kbId); - dispatch({ - type: DOCUMENT_ACTION_TYPES.FETCH_SUCCESS, - payload: { kbId, documents } + const 
documents = await knowledgeBaseService.getAllFiles(kbId, kbSource); + dispatch({ + type: DOCUMENT_ACTION_TYPES.FETCH_SUCCESS, + payload: { kbId, documents } }); } catch (error) { log.error(t('document.error.fetch'), error); diff --git a/frontend/app/[locale]/knowledges/contexts/KnowledgeBaseContext.tsx b/frontend/app/[locale]/knowledges/contexts/KnowledgeBaseContext.tsx index c866600fd..9445ed907 100644 --- a/frontend/app/[locale]/knowledges/contexts/KnowledgeBaseContext.tsx +++ b/frontend/app/[locale]/knowledges/contexts/KnowledgeBaseContext.tsx @@ -58,6 +58,11 @@ const knowledgeBaseReducer = (state: KnowledgeBaseState, action: KnowledgeBaseAc ...state, isLoading: action.payload }; + case KNOWLEDGE_BASE_ACTION_TYPES.SET_SYNC_LOADING: + return { + ...state, + syncLoading: action.payload + }; case KNOWLEDGE_BASE_ACTION_TYPES.ERROR: return { ...state, @@ -79,6 +84,7 @@ export const KnowledgeBaseContext = createContext<{ setActiveKnowledgeBase: (kb: KnowledgeBase) => void; isKnowledgeBaseSelectable: (kb: KnowledgeBase) => boolean; refreshKnowledgeBaseData: (forceRefresh?: boolean) => Promise; + refreshKnowledgeBaseDataWithDataMate: () => Promise; loadUserSelectedKnowledgeBases: () => Promise; saveUserSelectedKnowledgeBases: () => Promise; }>({ @@ -88,6 +94,7 @@ export const KnowledgeBaseContext = createContext<{ activeKnowledgeBase: null, currentEmbeddingModel: null, isLoading: false, + syncLoading: false, error: null }, dispatch: () => {}, @@ -98,6 +105,7 @@ export const KnowledgeBaseContext = createContext<{ setActiveKnowledgeBase: () => {}, isKnowledgeBaseSelectable: () => false, refreshKnowledgeBaseData: async () => {}, + refreshKnowledgeBaseDataWithDataMate: async () => {}, loadUserSelectedKnowledgeBases: async () => {}, saveUserSelectedKnowledgeBases: async () => false, }); @@ -118,9 +126,10 @@ export const KnowledgeBaseProvider: React.FC = ({ ch activeKnowledgeBase: null, currentEmbeddingModel: null, isLoading: false, + syncLoading: false, error: null }); - + // Check if knowledge base is selectable - memoized with useCallback const isKnowledgeBaseSelectable = useCallback((kb: KnowledgeBase): boolean => { // If no current embedding model is set, not selectable @@ -131,8 +140,48 @@ export const KnowledgeBaseProvider: React.FC = ({ ch return kb.embeddingModel === "unknown" || kb.embeddingModel === state.currentEmbeddingModel; }, [state.currentEmbeddingModel]); + // Load user selected knowledge bases from backend + const loadUserSelectedKnowledgeBases = useCallback(async () => { + try { + const userConfig = await userConfigService.loadKnowledgeList(); + if (userConfig) { + let allSelectedNames: string[] = []; + + // Handle new format (selectedKbNames array) + if (userConfig.selectedKbNames && userConfig.selectedKbNames.length > 0) { + allSelectedNames = userConfig.selectedKbNames; + } + // Fallback to legacy grouped format for backward compatibility + else if (userConfig.nexent || userConfig.datamate) { + allSelectedNames = [ + ...(userConfig.nexent || []), + ...(userConfig.datamate || []) + ]; + } + + if (allSelectedNames.length > 0) { + // Find matching knowledge base IDs based on index names + const selectedIds = state.knowledgeBases + .filter((kb) => allSelectedNames.includes(kb.id)) + .map((kb) => kb.id); + + dispatch({ + type: KNOWLEDGE_BASE_ACTION_TYPES.SELECT_KNOWLEDGE_BASE, + payload: selectedIds, + }); + } + } + } catch (error) { + log.error(t("knowledgeBase.error.loadSelected"), error); + dispatch({ + type: KNOWLEDGE_BASE_ACTION_TYPES.ERROR, + payload: 
t("knowledgeBase.error.loadSelectedRetry"), + }); + } + }, [state.knowledgeBases]); + // Load knowledge base data (supports force fetch from server and load selected status) - optimized with useCallback - const fetchKnowledgeBases = useCallback(async (skipHealthCheck = true) => { + const fetchKnowledgeBases = useCallback(async (skipHealthCheck = true, shouldLoadSelected = true) => { // If already loading, return directly if (state.isLoading) { return; @@ -146,16 +195,21 @@ export const KnowledgeBaseProvider: React.FC = ({ ch // Get knowledge base list data directly from server const kbs = await knowledgeBaseService.getKnowledgeBasesInfo(skipHealthCheck); - + dispatch({ type: KNOWLEDGE_BASE_ACTION_TYPES.FETCH_SUCCESS, payload: kbs }); - + + // After loading knowledge bases, automatically load user's selected knowledge bases if requested + if (shouldLoadSelected && kbs.length > 0) { + await loadUserSelectedKnowledgeBases(); + } + } catch (error) { log.error(t('knowledgeBase.error.fetchList'), error); dispatch({ type: KNOWLEDGE_BASE_ACTION_TYPES.ERROR, payload: t('knowledgeBase.error.fetchListRetry') }); } finally { dispatch({ type: KNOWLEDGE_BASE_ACTION_TYPES.LOADING, payload: false }); } - }, [state.isLoading, t]); + }, [state.isLoading, t, loadUserSelectedKnowledgeBases]); // Select knowledge base - memoized with useCallback const selectKnowledgeBase = useCallback((id: string) => { @@ -233,48 +287,32 @@ export const KnowledgeBaseProvider: React.FC = ({ ch } }, [state.knowledgeBases, state.selectedIds, state.activeKnowledgeBase]); - // Load user selected knowledge bases from backend - const loadUserSelectedKnowledgeBases = useCallback(async () => { + // Save user selected knowledge bases to backend + const saveUserSelectedKnowledgeBases = useCallback(async () => { try { - const userConfig = await userConfigService.loadKnowledgeList(); - if (userConfig && userConfig.selectedKbNames.length > 0) { - // Find matching knowledge base IDs based on index names - const selectedIds = state.knowledgeBases - .filter((kb) => userConfig.selectedKbNames.includes(kb.id)) - .map((kb) => kb.id); + // Get selected knowledge bases grouped by source + const selectedKnowledgeBases = state.knowledgeBases + .filter((kb) => state.selectedIds.includes(kb.id)); - dispatch({ - type: KNOWLEDGE_BASE_ACTION_TYPES.SELECT_KNOWLEDGE_BASE, - payload: selectedIds, - }); - } - } catch (error) { - log.error(t("knowledgeBase.error.loadSelected"), error); - dispatch({ - type: KNOWLEDGE_BASE_ACTION_TYPES.ERROR, - payload: t("knowledgeBase.error.loadSelectedRetry"), + // Group knowledge bases by source + const knowledgeBySource: { nexent?: string[]; datamate?: string[] } = {}; + selectedKnowledgeBases.forEach((kb) => { + const source = kb.source as keyof typeof knowledgeBySource; + if (!knowledgeBySource[source]) { + knowledgeBySource[source] = []; + } + knowledgeBySource[source]!.push(kb.id); }); - } - }, [state.knowledgeBases]); - // Save user selected knowledge bases to backend - const saveUserSelectedKnowledgeBases = useCallback(async () => { - try { - // Get selected knowledge base index names (globally unique identifiers) - const selectedKbNames = state.knowledgeBases - .filter((kb) => state.selectedIds.includes(kb.id)) - .map((kb) => kb.id); - - const success = await userConfigService.updateKnowledgeList( - selectedKbNames - ); - if (!success) { + const result = await userConfigService.updateKnowledgeList(knowledgeBySource); + if (!result) { dispatch({ type: KNOWLEDGE_BASE_ACTION_TYPES.ERROR, payload: 
t("knowledgeBase.error.saveSelected"), }); + return false; } - return success; + return true; } catch (error) { log.error(t("knowledgeBase.error.saveSelected"), error); dispatch({ @@ -286,16 +324,16 @@ export const KnowledgeBaseProvider: React.FC = ({ ch }, [state.knowledgeBases, state.selectedIds, t]); // Add a function to refresh the knowledge base data - const refreshKnowledgeBaseData = useCallback(async () => { + const refreshKnowledgeBaseData = useCallback(async (forceRefresh = false) => { try { - // Get latest knowledge base data directly from server - await fetchKnowledgeBases(false); + // Get latest knowledge base data directly from server, but don't reload user selections + await fetchKnowledgeBases(false, false); // If there is an active knowledge base, also refresh its document information if (state.activeKnowledgeBase) { // Publish document update event to notify document list component to refresh document data try { - const documents = await knowledgeBaseService.getAllFiles(state.activeKnowledgeBase.id); + const documents = await knowledgeBaseService.getAllFiles(state.activeKnowledgeBase.id, state.activeKnowledgeBase.source); log.log("documents", documents); window.dispatchEvent(new CustomEvent('documentsUpdated', { detail: { @@ -313,6 +351,37 @@ export const KnowledgeBaseProvider: React.FC = ({ ch } }, [fetchKnowledgeBases, state.activeKnowledgeBase]); + // Add a function to refresh the knowledge base data with DataMate sync and create records + const refreshKnowledgeBaseDataWithDataMate = useCallback(async () => { + try { + // First sync DataMate and create records + await knowledgeBaseService.syncDataMateAndCreateRecords(); + + // Then get latest knowledge base data directly from server, but don't reload user selections during sync + await fetchKnowledgeBases(false, false); + + // If there is an active knowledge base, also refresh its document information + if (state.activeKnowledgeBase) { + // Publish document update event to notify document list component to refresh document data + try { + const documents = await knowledgeBaseService.getAllFiles(state.activeKnowledgeBase.id, state.activeKnowledgeBase.source); + log.log("documents", documents); + window.dispatchEvent(new CustomEvent('documentsUpdated', { + detail: { + kbId: state.activeKnowledgeBase.id, + documents + } + })); + } catch (error) { + log.error("Failed to refresh document information:", error); + } + } + } catch (error) { + log.error("Failed to refresh knowledge base data with DataMate:", error); + dispatch({ type: KNOWLEDGE_BASE_ACTION_TYPES.ERROR, payload: 'Failed to refresh knowledge base data with DataMate' }); + } + }, [fetchKnowledgeBases, state.activeKnowledgeBase]); + // Initial data loading - with optimized dependencies useEffect(() => { // Use ref to track if data has been loaded to avoid duplicate loading @@ -360,10 +429,11 @@ export const KnowledgeBaseProvider: React.FC = ({ ch // Check if need to force fetch data from server const customEvent = e as CustomEvent; const forceRefresh = customEvent.detail?.forceRefresh === true; - + // If first time loading data or force refresh, get from server if (!initialDataLoaded || forceRefresh) { - fetchKnowledgeBases(false); + // For force refresh, don't reload user selections to preserve current state + fetchKnowledgeBases(false, !forceRefresh); initialDataLoaded = true; } }; @@ -390,6 +460,7 @@ export const KnowledgeBaseProvider: React.FC = ({ ch setActiveKnowledgeBase, isKnowledgeBaseSelectable, refreshKnowledgeBaseData, + 
refreshKnowledgeBaseDataWithDataMate, loadUserSelectedKnowledgeBases, saveUserSelectedKnowledgeBases }), [ @@ -401,6 +472,7 @@ export const KnowledgeBaseProvider: React.FC = ({ ch setActiveKnowledgeBase, isKnowledgeBaseSelectable, refreshKnowledgeBaseData, + refreshKnowledgeBaseDataWithDataMate, loadUserSelectedKnowledgeBases, saveUserSelectedKnowledgeBases ]); diff --git a/frontend/const/knowledgeBase.ts b/frontend/const/knowledgeBase.ts index afac4cab1..3ed72bd0f 100644 --- a/frontend/const/knowledgeBase.ts +++ b/frontend/const/knowledgeBase.ts @@ -43,6 +43,7 @@ export const KNOWLEDGE_BASE_ACTION_TYPES = { DELETE_KNOWLEDGE_BASE: "DELETE_KNOWLEDGE_BASE", ADD_KNOWLEDGE_BASE: "ADD_KNOWLEDGE_BASE", LOADING: "LOADING", + SET_SYNC_LOADING: "SET_SYNC_LOADING", ERROR: "ERROR" } as const; diff --git a/frontend/public/locales/en/common.json b/frontend/public/locales/en/common.json index 8702bfc97..565d42a2e 100644 --- a/frontend/public/locales/en/common.json +++ b/frontend/public/locales/en/common.json @@ -454,6 +454,7 @@ "knowledgeBase.list.title": "Knowledge Base List", "knowledgeBase.button.create": "Create", "knowledgeBase.button.sync": "Sync", + "knowledgeBase.button.syncDataMate": "Sync DataMate Knowledge Bases", "knowledgeBase.selected.prefix": "Selected", "knowledgeBase.selected.suffix": "knowledge bases for retrieval", "knowledgeBase.button.removeKb": "Remove knowledge base {{name}}", @@ -471,6 +472,15 @@ "knowledgeBase.message.deleteError": "Failed to delete knowledge base", "knowledgeBase.message.syncSuccess": "Knowledge base synchronized successfully", "knowledgeBase.message.syncError": "Failed to synchronize knowledge base: {{error}}", + "knowledgeBase.message.syncDataMateSuccess": "DataMate knowledge bases synchronized successfully", + "knowledgeBase.message.syncDataMateError": "Failed to synchronize DataMate knowledge bases: {{error}}", + "knowledgeBase.button.dataMateConfig": "DataMate Config", + "knowledgeBase.message.dataMateConfigSaved": "DataMate configuration saved successfully", + "knowledgeBase.message.dataMateConfigError": "Failed to save DataMate configuration", + "knowledgeBase.modal.dataMateConfig.title": "DataMate Configuration", + "knowledgeBase.modal.dataMateConfig.urlLabel": "DataMate URL", + "knowledgeBase.modal.dataMateConfig.urlPlaceholder": "Enter DataMate server address", + "knowledgeBase.modal.dataMateConfig.description": "Configure the DataMate server address for synchronizing external knowledge base data.", "knowledgeBase.message.nameRequired": "Please enter knowledge base name", "knowledgeBase.message.nameExists": "Knowledge base {{name}} already exists, please use a different name", "knowledgeBase.error.nameExistsInOtherTenant": "Knowledge base {{name}} is used by another tenant, please use a different name", diff --git a/frontend/public/locales/zh/common.json b/frontend/public/locales/zh/common.json index 4cf620d1d..e5852a60d 100644 --- a/frontend/public/locales/zh/common.json +++ b/frontend/public/locales/zh/common.json @@ -455,6 +455,7 @@ "knowledgeBase.list.title": "知识库列表", "knowledgeBase.button.create": "创建知识库", "knowledgeBase.button.sync": "同步知识库", + "knowledgeBase.button.syncDataMate": "同步DataMate知识库", "knowledgeBase.selected.prefix": "已选择", "knowledgeBase.selected.suffix": "个知识库用于知识检索", "knowledgeBase.button.removeKb": "移除知识库 {{name}}", @@ -472,6 +473,15 @@ "knowledgeBase.message.deleteError": "删除知识库失败", "knowledgeBase.message.syncSuccess": "同步知识库成功", "knowledgeBase.message.syncError": "同步知识库失败:{{error}}", + 
"knowledgeBase.message.syncDataMateSuccess": "同步DataMate知识库成功", + "knowledgeBase.message.syncDataMateError": "同步DataMate知识库失败:{{error}}", + "knowledgeBase.button.dataMateConfig": "DataMate配置", + "knowledgeBase.message.dataMateConfigSaved": "DataMate配置已保存", + "knowledgeBase.message.dataMateConfigError": "DataMate配置保存失败", + "knowledgeBase.modal.dataMateConfig.title": "DataMate配置", + "knowledgeBase.modal.dataMateConfig.urlLabel": "DataMate URL", + "knowledgeBase.modal.dataMateConfig.urlPlaceholder": "请输入DataMate服务器地址", + "knowledgeBase.modal.dataMateConfig.description": "配置DataMate服务器地址,用于同步外部知识库数据。", "knowledgeBase.message.nameRequired": "请输入知识库名称", "knowledgeBase.message.nameExists": "知识库 {{name}} 已存在,请更换名称", "knowledgeBase.error.nameExistsInOtherTenant": "知识库 {{name}} 已被其他租户使用,请更换名称", diff --git a/frontend/services/api.ts b/frontend/services/api.ts index 81856de87..261dbb82f 100644 --- a/frontend/services/api.ts +++ b/frontend/services/api.ts @@ -150,6 +150,10 @@ export const API_ENDPOINTS = { pathOrUrl )}/error-info`, }, + datamate: { + syncAndCreateRecords: `${API_BASE_URL}/datamate/sync_and_create_records`, + files: (knowledgeBaseId: string) => `${API_BASE_URL}/datamate/${knowledgeBaseId}/files`, + }, config: { save: `${API_BASE_URL}/config/save_config`, load: `${API_BASE_URL}/config/load_config`, diff --git a/frontend/services/knowledgeBaseService.ts b/frontend/services/knowledgeBaseService.ts index 0ea443081..3a0150294 100644 --- a/frontend/services/knowledgeBaseService.ts +++ b/frontend/services/knowledgeBaseService.ts @@ -41,46 +41,107 @@ class KnowledgeBaseService { } } - // Get knowledge bases with stats (very slow, don't use it) + // Sync DataMate knowledge bases and create local records + async syncDataMateAndCreateRecords(): Promise<{ + indices: string[]; + count: number; + indices_info: any[]; + created_records: any[]; + }> { + try { + const response = await fetch(API_ENDPOINTS.datamate.syncAndCreateRecords, { + method: "POST", + headers: getAuthHeaders(), + }); + + const data = await response.json(); + + if (!response.ok) { + throw new Error(data.detail || "Failed to sync DataMate knowledge bases and create records"); + } + + return data; + } catch (error) { + log.error("Failed to sync DataMate knowledge bases and create records:", error); + throw error; + } + } + + + // Get knowledge bases with stats from all sources (very slow, don't use it) async getKnowledgeBasesInfo( skipHealthCheck = false ): Promise { try { - // First check Elasticsearch health (unless skipped) - if (!skipHealthCheck) { - const isElasticsearchHealthy = await this.checkHealth(); - if (!isElasticsearchHealthy) { - log.warn("Elasticsearch service unavailable"); - return []; - } - } - - let knowledgeBases: KnowledgeBase[] = []; + const knowledgeBases: KnowledgeBase[] = []; // Get knowledge bases from Elasticsearch try { - const response = await fetch( - `${API_ENDPOINTS.knowledgeBase.indices}?include_stats=true`, - { - headers: getAuthHeaders(), + // First check Elasticsearch health (unless skipped) + if (!skipHealthCheck) { + const isElasticsearchHealthy = await this.checkHealth(); + if (!isElasticsearchHealthy) { + log.warn("Elasticsearch service unavailable"); + } else { + const response = await fetch( + `${API_ENDPOINTS.knowledgeBase.indices}?include_stats=true`, + { + headers: getAuthHeaders(), + } + ); + const data = await response.json(); + + if (data.indices && data.indices_info) { + // Convert Elasticsearch indices to knowledge base format + const esKnowledgeBases = 
data.indices_info.map((indexInfo: any) => { + const stats = indexInfo.stats?.base_info || {}; + // Backend now returns: + // - name: internal index_name + // - display_name: user-facing knowledge_name (fallback to index_name) + const kbId = indexInfo.name; + const kbName = indexInfo.display_name || indexInfo.name; + + return { + id: kbId, + name: kbName, + description: "Elasticsearch index", + documentCount: stats.doc_count || 0, + chunkCount: stats.chunk_count || 0, + createdAt: stats.creation_date || null, + updatedAt: stats.update_date || stats.creation_date || null, + embeddingModel: stats.embedding_model || "unknown", + avatar: "", + chunkNum: 0, + language: "", + nickname: "", + parserId: "", + permission: "", + tokenNum: 0, + source: "nexent", + }; + }); + knowledgeBases.push(...esKnowledgeBases); + } } - ); - const data = await response.json(); + } + } catch (error) { + log.error("Failed to get Elasticsearch indices:", error); + } - if (data.indices && data.indices_info) { - // Convert Elasticsearch indices to knowledge base format - knowledgeBases = data.indices_info.map((indexInfo: any) => { + // Sync DataMate knowledge bases and get the synced data + try { + const syncResult = await this.syncDataMateAndCreateRecords(); + if (syncResult.indices_info) { + // Convert synced DataMate indices to knowledge base format + const datamateKnowledgeBases: KnowledgeBase[] = syncResult.indices_info.map((indexInfo: any) => { const stats = indexInfo.stats?.base_info || {}; - // Backend now returns: - // - name: internal index_name - // - display_name: user-facing knowledge_name (fallback to index_name) const kbId = indexInfo.name; const kbName = indexInfo.display_name || indexInfo.name; return { id: kbId, name: kbName, - description: "Elasticsearch index", + description: "DataMate knowledge base", documentCount: stats.doc_count || 0, chunkCount: stats.chunk_count || 0, createdAt: stats.creation_date || null, @@ -93,12 +154,13 @@ class KnowledgeBaseService { parserId: "", permission: "", tokenNum: 0, - source: "elasticsearch", + source: "datamate", }; }); + knowledgeBases.push(...datamateKnowledgeBases); } } catch (error) { - log.error("Failed to get Elasticsearch indices:", error); + log.error("Failed to sync DataMate knowledge bases:", error); } return knowledgeBases; @@ -252,15 +314,31 @@ class KnowledgeBaseService { } // Get all files from a knowledge base, regardless of the existence of index - async getAllFiles(kbId: string): Promise { + async getAllFiles(kbId: string, kbSource?: string): Promise { try { - const response = await fetch( - API_ENDPOINTS.knowledgeBase.listFiles(kbId), - { - headers: getAuthHeaders(), - } - ); - const result = await response.json(); + let response: Response; + let result: any; + + // Determine which API to call based on knowledge base source + if (kbSource === "datamate") { + // Call DataMate files API + response = await fetch( + API_ENDPOINTS.datamate.files(kbId), + { + headers: getAuthHeaders(), + } + ); + result = await response.json(); + } else { + // Call Elasticsearch files API (default behavior) + response = await fetch( + API_ENDPOINTS.knowledgeBase.listFiles(kbId), + { + headers: getAuthHeaders(), + } + ); + result = await response.json(); + } if (result.status !== "success") { throw new Error("Failed to get file list"); diff --git a/frontend/services/userConfigService.ts b/frontend/services/userConfigService.ts index 76a3deeaa..99f4d70c0 100644 --- a/frontend/services/userConfigService.ts +++ b/frontend/services/userConfigService.ts @@ -1,5 +1,5 @@ 
import { API_ENDPOINTS } from './api'; -import { UserKnowledgeConfig } from '../types/knowledgeBase'; +import { UserKnowledgeConfig, UpdateKnowledgeListRequest } from '../types/knowledgeBase'; import { fetchWithAuth, getAuthHeaders } from '@/lib/auth'; // @ts-ignore @@ -29,25 +29,28 @@ export class UserConfigService { } // Update user selected knowledge base list - async updateKnowledgeList(knowledgeList: string[]): Promise { + async updateKnowledgeList(request: UpdateKnowledgeListRequest): Promise { try { const response = await fetch( API_ENDPOINTS.tenantConfig.updateKnowledgeList, { method: "POST", headers: getAuthHeaders(), - body: JSON.stringify(knowledgeList), + body: JSON.stringify(request), } ); if (!response.ok) { - return false; + return null; } const result = await response.json(); - return result.status === 'success'; + if (result.status === 'success') { + return result.content; + } + return null; } catch (error) { - return false; + return null; } } } diff --git a/frontend/types/knowledgeBase.ts b/frontend/types/knowledgeBase.ts index e04f145c7..b170660bc 100644 --- a/frontend/types/knowledgeBase.ts +++ b/frontend/types/knowledgeBase.ts @@ -82,11 +82,12 @@ export interface KnowledgeBaseState { activeKnowledgeBase: KnowledgeBase | null; currentEmbeddingModel: string | null; isLoading: boolean; + syncLoading: boolean; error: string | null; } // Knowledge base action type -export type KnowledgeBaseAction = +export type KnowledgeBaseAction = | { type: typeof KNOWLEDGE_BASE_ACTION_TYPES.FETCH_SUCCESS, payload: KnowledgeBase[] } | { type: typeof KNOWLEDGE_BASE_ACTION_TYPES.SELECT_KNOWLEDGE_BASE, payload: string[] } | { type: typeof KNOWLEDGE_BASE_ACTION_TYPES.SET_ACTIVE, payload: KnowledgeBase | null } @@ -94,6 +95,7 @@ export type KnowledgeBaseAction = | { type: typeof KNOWLEDGE_BASE_ACTION_TYPES.DELETE_KNOWLEDGE_BASE, payload: string } | { type: typeof KNOWLEDGE_BASE_ACTION_TYPES.ADD_KNOWLEDGE_BASE, payload: KnowledgeBase } | { type: typeof KNOWLEDGE_BASE_ACTION_TYPES.LOADING, payload: boolean } + | { type: typeof KNOWLEDGE_BASE_ACTION_TYPES.SET_SYNC_LOADING, payload: boolean } | { type: typeof KNOWLEDGE_BASE_ACTION_TYPES.ERROR, payload: string }; // UI state interface @@ -123,7 +125,16 @@ export interface AbortableError extends Error { // User selected knowledge base configuration type export interface UserKnowledgeConfig { - selectedKbNames: string[]; - selectedKbModels: string[]; - selectedKbSources: string[]; + selectedKbNames?: string[]; + selectedKbModels?: string[]; + selectedKbSources?: string[]; + // Legacy support for grouped format + nexent?: string[]; + datamate?: string[]; +} + +// Update knowledge list request type +export interface UpdateKnowledgeListRequest { + nexent?: string[]; + datamate?: string[]; } diff --git a/sdk/nexent/__init__.py b/sdk/nexent/__init__.py index a7242e554..63423081e 100644 --- a/sdk/nexent/__init__.py +++ b/sdk/nexent/__init__.py @@ -1,9 +1,10 @@ from .core import * from .data_process import * +from .datamate import * from .memory import * from .storage import * from .vector_database import * from .container import * -__all__ = ["core", "data_process", "memory", "storage", "vector_database", "container"] \ No newline at end of file +__all__ = ["core", "data_process", "datamate","memory", "storage", "vector_database", "container"] diff --git a/sdk/nexent/core/agents/nexent_agent.py b/sdk/nexent/core/agents/nexent_agent.py index 290dfb45e..12d7737df 100644 --- a/sdk/nexent/core/agents/nexent_agent.py +++ b/sdk/nexent/core/agents/nexent_agent.py 
@@ -89,6 +89,12 @@ def create_local_tool(self, tool_config: ToolConfig): name_resolver = tool_config.metadata.get( "name_resolver", None) if tool_config.metadata else None tools_obj.name_resolver = {} if name_resolver is None else name_resolver + elif class_name == "DataMateSearchTool": + tools_obj = tool_class(**params) + tools_obj.observer = self.observer + index_names = tool_config.metadata.get( + "index_names", None) if tool_config.metadata else None + tools_obj.index_names = [] if index_names is None else index_names elif class_name == "AnalyzeTextFileTool": tools_obj = tool_class(observer=self.observer, llm_model=tool_config.metadata.get("llm_model", []), diff --git a/sdk/nexent/core/tools/__init__.py b/sdk/nexent/core/tools/__init__.py index aaa0a0049..e88be78b7 100644 --- a/sdk/nexent/core/tools/__init__.py +++ b/sdk/nexent/core/tools/__init__.py @@ -20,12 +20,12 @@ "ExaSearchTool", "KnowledgeBaseSearchTool", "DataMateSearchTool", - "SendEmailTool", - "GetEmailTool", - "TavilySearchTool", + "SendEmailTool", + "GetEmailTool", + "TavilySearchTool", "LinkupSearchTool", "CreateFileTool", - "ReadFileTool", + "ReadFileTool", "DeleteFileTool", "CreateDirectoryTool", "DeleteDirectoryTool", diff --git a/sdk/nexent/core/tools/analyze_text_file_tool.py b/sdk/nexent/core/tools/analyze_text_file_tool.py index 43cecb742..78b78543d 100644 --- a/sdk/nexent/core/tools/analyze_text_file_tool.py +++ b/sdk/nexent/core/tools/analyze_text_file_tool.py @@ -26,14 +26,14 @@ class AnalyzeTextFileTool(Tool): """Tool for analyzing text file content using a large language model""" - + name = "analyze_text_file" description = ( "Extract content from text files and analyze them using a large language model based on your query. " "Supports multiple files from S3 URLs (s3://bucket/key or /bucket/key), HTTP, and HTTPS URLs. " "The tool will extract the text content from each file and return an analysis based on your question." ) - + inputs = { "file_url_list": { "type": "array", @@ -75,6 +75,7 @@ def __init__( self.llm_model = llm_model self.data_process_service_url = data_process_service_url self.mm = LoadSaveObjectManager(storage_client=self.storage_client) + self.time_out = 60 * 5 self.running_prompt_zh = "正在分析文件..." self.running_prompt_en = "Analyzing file..." 
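# A minimal sketch of the shared-timeout pattern introduced by the new
# self.time_out attribute (the following hunk switches the hard-coded 60s
# httpx timeout over to it). The endpoint URL and the 'file' form field name
# below are placeholders, not the tool's real values:
import httpx

class TimeoutSyncedUploader:
    """Keeps the multipart 'timeout' form field and the HTTP client timeout in sync."""

    def __init__(self, timeout_seconds: int = 60 * 5):
        # Single source of truth, mirroring AnalyzeTextFileTool.time_out.
        self.time_out = timeout_seconds

    def process_text_file(self, api_url: str, filename: str, file_content: bytes) -> str:
        files = {"file": (filename, file_content)}  # placeholder field name
        data = {"chunking_strategy": "basic", "timeout": self.time_out}
        with httpx.Client(timeout=self.time_out) as client:
            response = client.post(api_url, files=files, data=data)
        response.raise_for_status()
        return response.text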
@@ -137,7 +138,7 @@ def _forward_impl( analysis_results.append(str(analysis_error)) return analysis_results - + except Exception as e: logger.error(f"Error analyzing text file: {str(e)}", exc_info=True) error_msg = f"Error analyzing text file: {str(e)}" @@ -160,9 +161,9 @@ def process_text_file(self, filename: str, file_content: bytes,) -> str: } data = { 'chunking_strategy': 'basic', - 'timeout': 60 + 'timeout': self.time_out, } - with httpx.Client(timeout=60) as client: + with httpx.Client(timeout=self.time_out) as client: response = client.post(api_url, files=files, data=data) if response.status_code == 200: diff --git a/sdk/nexent/core/tools/datamate_search_tool.py b/sdk/nexent/core/tools/datamate_search_tool.py index bf1009269..296de9845 100644 --- a/sdk/nexent/core/tools/datamate_search_tool.py +++ b/sdk/nexent/core/tools/datamate_search_tool.py @@ -1,15 +1,14 @@ import json import logging -from typing import List, Optional +from typing import Optional, List, Union -import httpx from pydantic import Field from smolagents.tools import Tool +from ...vector_database import DataMateCore from ..utils.observer import MessageObserver, ProcessType from ..utils.tools_common_message import SearchResultTextMessage, ToolCategory, ToolSign - # Get logger instance logger = logging.getLogger("datamate_search_tool") @@ -41,6 +40,11 @@ class DataMateSearchTool(Tool): "default": 0.2, "nullable": True, }, + "index_names": { + "type": "array", + "description": "The list of knowledge base names to search (supports user-facing knowledge_name or internal index_name). If not provided, will search all available knowledge bases.", + "nullable": True, + }, "kb_page": { "type": "integer", "description": "Page index when listing knowledge bases from DataMate.", @@ -64,6 +68,7 @@ def __init__( self, server_ip: str = Field(description="DataMate server IP or hostname"), server_port: int = Field(description="DataMate server port"), + index_names: List[str] = Field(description="The list of index names to search", default=None, exclude=True), observer: MessageObserver = Field(description="Message observer", default=None, exclude=True), ): """Initialize the DataMateSearchTool. @@ -84,10 +89,14 @@ def __init__( # Store raw host and port self.server_ip = server_ip.strip() self.server_port = server_port + self.index_names = [] if index_names is None else index_names # Build base URL: http://host:port self.server_base_url = f"http://{self.server_ip}:{self.server_port}".rstrip("/") + # Initialize DataMate vector database core + self.datamate_core = DataMateCore(base_url=self.server_base_url) + self.kb_page = 0 self.kb_page_size = 20 self.observer = observer @@ -96,11 +105,20 @@ def __init__( self.running_prompt_zh = "DataMate知识库检索中..." self.running_prompt_en = "Searching the DataMate knowledge base..." + def _normalize_index_names(self, index_names: Optional[Union[str, List[str]]]) -> List[str]: + """Normalize index_names to list; accept single string and keep None as empty list.""" + if index_names is None: + return [] + if isinstance(index_names, str): + return [index_names] + return list(index_names) + def forward( self, query: str, top_k: int = 10, threshold: float = 0.2, + index_names: Union[str, List[str], None] = None, kb_page: int = 0, kb_page_size: int = 20, ) -> str: @@ -110,6 +128,7 @@ def forward( query: Search query text. top_k: Optional override for maximum number of search results. threshold: Optional override for similarity threshold. + index_names: Optional list of index names to search in. 
If not provided, all available indexes will be queried. kb_page: Optional override for knowledge base list page index. kb_page_size: Optional override for knowledge base list page size. """ @@ -126,21 +145,55 @@ logger.info( f"DataMateSearchTool called with query: '{query}', base_url: '{self.server_base_url}', " - f"top_k: {top_k}, threshold: {threshold}" + f"top_k: {top_k}, threshold: {threshold}, index_names: {index_names}" ) + # Use provided index_names if available, otherwise fall back to the tool's configured defaults + knowledge_base_ids = self._normalize_index_names( + index_names if index_names is not None else self.index_names) + # TODO: resolve user-facing knowledge names to DataMate knowledge base IDs + # search_index_names = self._resolve_names(search_index_names) try: - # Step 1: Get knowledge base list - knowledge_base_ids = self._get_knowledge_base_list() + # Step 1: If no knowledge bases were configured, fall back to all available ones + if not knowledge_base_ids: + knowledge_bases = self.datamate_core.client.list_knowledge_bases( + page=kb_page, + size=kb_page_size + ) + knowledge_base_ids = [str(kb.get("id")) for kb in knowledge_bases + if kb.get("id") and kb.get("chunkCount")] + if not knowledge_base_ids: return json.dumps("No knowledge base found. No relevant information found.", ensure_ascii=False) - - # Step 2: Retrieve knowledge base content - kb_search_results = self._retrieve_knowledge_base_content(query, knowledge_base_ids, top_k, threshold - ) - - if not kb_search_results: - raise Exception("No results found! Try a less restrictive/shorter query.") + # Step 2: Retrieve knowledge base content using DataMateCore hybrid search + kb_search_results = [] + for knowledge_base_id in knowledge_base_ids: + kb_search = self.datamate_core.hybrid_search( + query_text=query, + index_names=[knowledge_base_id], + top_k=top_k, + weight_accurate=threshold, + ) + kb_search_results.extend(kb_search) + if not kb_search_results: + raise Exception("No results found! Try a less restrictive/shorter query.") # Format search results search_results_json = [] # Organize search results into a unified format @@ -151,7 +204,7 @@ metadata = self._parse_metadata(entity_data.get("metadata")) dataset_id = self._extract_dataset_id(metadata.get("absolute_directory_path", "")) file_id = metadata.get("original_file_id") - download_url = self._build_file_download_url(dataset_id, file_id) + download_url = self.datamate_core.client.build_file_download_url(dataset_id, file_id) score_details = entity_data.get("scoreDetails", {}) or {} score_details.update({ @@ -168,7 +221,7 @@ url=download_url, filename=metadata.get("file_name", ""), published_date=entity_data.get("createTime", ""), - score=entity_data.get("score", "0"), + score=single_search_result.get("score", "0"), score_details=score_details, cite_index=self.record_ops + index, search_type=self.name, @@ -191,100 +244,6 @@ logger.error(error_msg) raise Exception(error_msg) - def _get_knowledge_base_list(self) -> List[str]: - """Get knowledge base list from DataMate API. - - Returns: - List[str]: List of knowledge base IDs.
- """ - try: - url = f"{self.server_base_url}/api/knowledge-base/list" - payload = {"page": self.kb_page, "size": self.kb_page_size} - - with httpx.Client(timeout=30) as client: - response = client.post(url, json=payload) - - if response.status_code != 200: - error_detail = ( - response.json().get("detail", "unknown error") - if response.headers.get("content-type", "").startswith("application/json") - else response.text - ) - raise Exception(f"Failed to get knowledge base list (status {response.status_code}): {error_detail}") - - result = response.json() - # Extract knowledge base IDs from response - # Assuming the response structure contains a list of knowledge bases with 'id' field - data = result.get("data", {}) - knowledge_bases = data.get("content", []) if data else [] - - knowledge_base_ids = [] - for kb in knowledge_bases: - kb_id = kb.get("id") - chunk_count = kb.get("chunkCount") - if kb_id and chunk_count: - knowledge_base_ids.append(str(kb_id)) - - logger.info(f"Retrieved {len(knowledge_base_ids)} knowledge base(s): {knowledge_base_ids}") - return knowledge_base_ids - - except httpx.TimeoutException: - raise Exception("Timeout while getting knowledge base list from DataMate API") - except httpx.RequestError as e: - raise Exception(f"Request error while getting knowledge base list: {str(e)}") - except Exception as e: - raise Exception(f"Error getting knowledge base list: {str(e)}") - - def _retrieve_knowledge_base_content( - self, query: str, knowledge_base_ids: List[str], top_k: int, threshold: float - ) -> List[dict]: - """Retrieve knowledge base content from DataMate API. - - Args: - query (str): Search query. - knowledge_base_ids (List[str]): List of knowledge base IDs to search. - top_k (int): Maximum number of results to return. - threshold (float): Similarity threshold. - - Returns: - List[dict]: List of search results. 
- """ - search_results = [] - for knowledge_base_id in knowledge_base_ids: - try: - url = f"{self.server_base_url}/api/knowledge-base/retrieve" - payload = { - "query": query, - "topK": top_k, - "threshold": threshold, - "knowledgeBaseIds": [knowledge_base_id], - } - - with httpx.Client(timeout=60) as client: - response = client.post(url, json=payload) - - if response.status_code != 200: - error_detail = ( - response.json().get("detail", "unknown error") - if response.headers.get("content-type", "").startswith("application/json") - else response.text - ) - raise Exception( - f"Failed to retrieve knowledge base content (status {response.status_code}): {error_detail}") - - result = response.json() - # Extract search results from response - for data in result.get("data", {}): - search_results.append(data) - except httpx.TimeoutException: - raise Exception("Timeout while retrieving knowledge base content from DataMate API") - except httpx.RequestError as e: - raise Exception(f"Request error while retrieving knowledge base content: {str(e)}") - except Exception as e: - raise Exception(f"Error retrieving knowledge base content: {str(e)}") - logger.info(f"Retrieved {len(search_results)} search result(s)") - return search_results - @staticmethod def _parse_metadata(metadata_raw: Optional[str]) -> dict: """Parse metadata payload safely.""" @@ -305,9 +264,3 @@ def _extract_dataset_id(absolute_path: str) -> str: return "" segments = [segment for segment in absolute_path.strip("/").split("/") if segment] return segments[-1] if segments else "" - - def _build_file_download_url(self, dataset_id: str, file_id: str) -> str: - """Build the download URL for a dataset file.""" - if not (self.server_base_url and dataset_id and file_id): - return "" - return f"{self.server_base_url}/api/data-management/datasets/{dataset_id}/files/{file_id}/download" \ No newline at end of file diff --git a/sdk/nexent/datamate/__init__.py b/sdk/nexent/datamate/__init__.py new file mode 100644 index 000000000..c5a345632 --- /dev/null +++ b/sdk/nexent/datamate/__init__.py @@ -0,0 +1,7 @@ +""" +DataMate SDK client for interacting with DataMate knowledge base APIs. +""" +from .datamate_client import DataMateClient + +__all__ = ["DataMateClient"] + diff --git a/sdk/nexent/datamate/datamate_client.py b/sdk/nexent/datamate/datamate_client.py new file mode 100644 index 000000000..aaeba7712 --- /dev/null +++ b/sdk/nexent/datamate/datamate_client.py @@ -0,0 +1,378 @@ +""" +DataMate API client for datamate knowledge base operations. + +This SDK provides a unified interface for interacting with DataMate knowledge base APIs, +including listing knowledge bases, retrieving files, and retrieving content. +""" +import logging +from typing import Dict, List, Optional, Any +import httpx + +logger = logging.getLogger("datamate_client") + + +class DataMateClient: + """ + Client for interacting with DataMate knowledge base APIs. + + This client encapsulates all DataMate API calls and provides a clean interface + for datamate knowledge base operations. + """ + + def __init__(self, base_url: str, timeout: float = 30.0): + """ + Initialize DataMate client. 
+ + Args: + base_url: Base URL of DataMate server (e.g., "http://jasonwang.site:30000") + timeout: Request timeout in seconds (default: 30.0) + """ + self.base_url = base_url.rstrip("/") + self.timeout = timeout + logger.info(f"Initialized DataMateClient with base_url: {self.base_url}") + + def _build_url(self, path: str) -> str: + """Build full URL from path.""" + if path.startswith("/"): + return f"{self.base_url}{path}" + return f"{self.base_url}/{path}" + + def _build_headers(self, authorization: Optional[str] = None) -> Dict[str, str]: + """ + Build request headers with optional authorization. + + Args: + authorization: Optional authorization header value + + Returns: + Dictionary of headers + """ + headers = {} + if authorization: + headers["Authorization"] = authorization + return headers + + def _handle_error_response(self, response: httpx.Response, error_message: str) -> None: + """ + Handle error response and raise appropriate exception. + + Args: + response: HTTP response object + error_message: Base error message to include in exception (e.g., "Failed to get knowledge base list") + + Raises: + Exception: With detailed error message + """ + error_detail = ( + response.json().get("detail", "unknown error") + if response.headers.get("content-type", "").startswith("application/json") + else response.text + ) + raise Exception(f"{error_message} (status {response.status_code}): {error_detail}") + + def _make_request( + self, + method: str, + url: str, + headers: Dict[str, str], + json: Optional[Dict[str, Any]] = None, + timeout: Optional[float] = None, + error_message: str = "Request failed" + ) -> httpx.Response: + """ + Make HTTP request with error handling. + + Args: + method: HTTP method ("GET" or "POST") + url: Request URL + headers: Request headers + json: Optional JSON payload for POST requests + timeout: Optional timeout override + error_message: Error message to use if request fails + + Returns: + HTTP response object + + Raises: + Exception: If the request fails (with detailed error message) + """ + request_timeout = timeout if timeout is not None else self.timeout + + with httpx.Client(timeout=request_timeout) as client: + if method.upper() == "GET": + response = client.get(url, headers=headers) + elif method.upper() == "POST": + response = client.post(url, json=json, headers=headers) + else: + raise ValueError(f"Unsupported HTTP method: {method}") + + if response.status_code != 200: + self._handle_error_response(response, error_message) + + return response + + def list_knowledge_bases( + self, + page: int = 0, + size: int = 20, + authorization: Optional[str] = None + ) -> List[Dict[str, Any]]: + """ + Get list of knowledge bases from DataMate. + + Args: + page: Page index (default: 0) + size: Page size (default: 20) + authorization: Optional authorization header + + Returns: + List of knowledge base dictionaries with their IDs and metadata. 
+ + Raises: + RuntimeError: If the API request fails + """ + try: + url = self._build_url("/api/knowledge-base/list") + payload = {"page": page, "size": size} + headers = self._build_headers(authorization) + + logger.info(f"Fetching DataMate knowledge bases from: {url}, page={page}, size={size}") + + response = self._make_request("POST", url, headers, json=payload, error_message="Failed to get knowledge base list") + data = response.json() + + # Extract knowledge base list from response + knowledge_bases = [] + if data.get("data"): + knowledge_bases = data.get("data").get("content", []) + + logger.info(f"Successfully fetched {len(knowledge_bases)} knowledge bases from DataMate") + return knowledge_bases + + except httpx.HTTPError as e: + logger.error(f"HTTP error while fetching DataMate knowledge bases: {str(e)}") + raise RuntimeError(f"Failed to fetch DataMate knowledge bases: {str(e)}") + except Exception as e: + logger.error(f"Unexpected error while fetching DataMate knowledge bases: {str(e)}") + raise RuntimeError(f"Failed to fetch DataMate knowledge bases: {str(e)}") + + def get_knowledge_base_files( + self, + knowledge_base_id: str, + authorization: Optional[str] = None + ) -> List[Dict[str, Any]]: + """ + Get file list for a specific DataMate knowledge base. + + Args: + knowledge_base_id: The ID of the knowledge base + authorization: Optional authorization header + + Returns: + List of file dictionaries with name, status, size, upload_date, etc. + + Raises: + RuntimeError: If the API request fails + """ + try: + url = self._build_url(f"/api/knowledge-base/{knowledge_base_id}/files") + logger.info(f"Fetching files for DataMate knowledge base {knowledge_base_id} from: {url}") + + headers = self._build_headers(authorization) + response = self._make_request("GET", url, headers, error_message="Failed to get knowledge base files") + data = response.json() + + # Extract file list from response + files = [] + if data.get("data"): + files = data.get("data").get("content", []) + + logger.info(f"Successfully fetched {len(files)} files for datamate knowledge base {knowledge_base_id}") + return files + + except httpx.HTTPError as e: + logger.error(f"HTTP error while fetching files for datamate knowledge base {knowledge_base_id}: {str(e)}") + raise RuntimeError(f"Failed to fetch files for datamate knowledge base {knowledge_base_id}: {str(e)}") + except Exception as e: + logger.error(f"Unexpected error while fetching files for datamate knowledge base {knowledge_base_id}: {str(e)}") + raise RuntimeError(f"Failed to fetch files for datamate knowledge base {knowledge_base_id}: {str(e)}") + + def get_knowledge_base_info( + self, + knowledge_base_id: str, + authorization: Optional[str] = None + ) -> Dict[str, Any]: + """ + Get details for a specific DataMate knowledge base. + + Args: + knowledge_base_id: The ID of the knowledge base + authorization: Optional authorization header + + Returns: + Dictionary containing knowledge base details. 
+ + Raises: + RuntimeError: If the API request fails + """ + try: + url = self._build_url(f"/api/knowledge-base/{knowledge_base_id}") + logger.info(f"Fetching details for DataMate knowledge base {knowledge_base_id} from: {url}") + + headers = self._build_headers(authorization) + response = self._make_request("GET", url, headers, error_message="Failed to get knowledge base details") + data = response.json() + + # Extract knowledge base details from response + knowledge_base = data.get("data", {}) + + logger.info(f"Successfully fetched details for datamate knowledge base {knowledge_base_id}") + return knowledge_base + + except httpx.HTTPError as e: + logger.error(f"HTTP error while fetching details for datamate knowledge base {knowledge_base_id}: {str(e)}") + raise RuntimeError(f"Failed to fetch details for datamate knowledge base {knowledge_base_id}: {str(e)}") + except Exception as e: + logger.error(f"Unexpected error while fetching details for datamate knowledge base {knowledge_base_id}: {str(e)}") + raise RuntimeError(f"Failed to fetch details for datamate knowledge base {knowledge_base_id}: {str(e)}") + + def retrieve_knowledge_base( + self, + query: str, + knowledge_base_ids: List[str], + top_k: int = 10, + threshold: float = 0.2, + authorization: Optional[str] = None + ) -> List[Dict[str, Any]]: + """ + Retrieve content in DataMate knowledge bases. + + Args: + query: Retrieve query text + knowledge_base_ids: List of knowledge base IDs to retrieve + top_k: Maximum number of results to return (default: 10) + threshold: Similarity threshold (default: 0.2) + authorization: Optional authorization header + + Returns: + List of retrieve result dictionaries + + Raises: + RuntimeError: If the API request fails + """ + try: + url = self._build_url("/api/knowledge-base/retrieve") + payload = { + "query": query, + "topK": top_k, + "threshold": threshold, + "knowledgeBaseIds": knowledge_base_ids, + } + + headers = self._build_headers(authorization) + + logger.info( + f"Retrieving DataMate knowledge bases: query='{query}', " + f"knowledge_base_ids={knowledge_base_ids}, top_k={top_k}, threshold={threshold}" + ) + + # Longer timeout for retrieve operation + response = self._make_request( + "POST", url, headers, json=payload, timeout=self.timeout * 2, + error_message="Failed to retrieve knowledge base content" + ) + + search_results = [] + data = response.json() + # Extract search results from response + for result in data.get("data", {}): + search_results.append(result) + + logger.info(f"Successfully retrieved {len(search_results)} retrieve result(s)") + return search_results + + except httpx.HTTPError as e: + logger.error(f"HTTP error while retrieving DataMate knowledge bases: {str(e)}") + raise RuntimeError(f"Failed to retrieve DataMate knowledge bases: {str(e)}") + except Exception as e: + logger.error(f"Unexpected error while retrieving DataMate knowledge bases: {str(e)}") + raise RuntimeError(f"Failed to retrieve DataMate knowledge bases: {str(e)}") + + def build_file_download_url(self, dataset_id: str, file_id: str) -> str: + """ + Build download URL for a DataMate file. + + Args: + dataset_id: Dataset ID + file_id: File ID + + Returns: + Full download URL for the file + """ + if not (dataset_id and file_id): + return "" + return f"{self.base_url}/api/data-management/datasets/{dataset_id}/files/{file_id}/download" + + def sync_all_knowledge_bases( + self, + authorization: Optional[str] = None + ) -> Dict[str, Any]: + """ + Sync all DataMate knowledge bases and their files. 
+ + Args: + authorization: Optional authorization header + + Returns: + Dictionary containing knowledge bases with their file lists. + Format: { + "success": bool, + "knowledge_bases": [ + { + "knowledge_base": {...}, + "files": [...], + "error": str (optional) + } + ], + "total_count": int + } + """ + try: + # Fetch all knowledge bases + knowledge_bases = self.list_knowledge_bases(authorization=authorization) + + # Fetch files for each knowledge base + result = [] + for kb in knowledge_bases: + kb_id = kb.get("id") + + try: + files = self.get_knowledge_base_files(str(kb_id), authorization=authorization) + result.append({ + "knowledge_base": kb, + "files": files, + }) + except Exception as e: + logger.error(f"Failed to fetch files for datamate knowledge base {kb_id}: {str(e)}") + # Continue with other knowledge bases even if one fails + result.append({ + "knowledge_base": kb, + "files": [], + "error": str(e), + }) + + return { + "success": True, + "knowledge_bases": result, + "total_count": len(result), + } + + except Exception as e: + logger.error(f"Error syncing DataMate knowledge bases: {str(e)}") + return { + "success": False, + "error": str(e), + "knowledge_bases": [], + "total_count": 0, + } + diff --git a/sdk/nexent/vector_database/__init__.py b/sdk/nexent/vector_database/__init__.py index e69de29bb..9c811f9c6 100644 --- a/sdk/nexent/vector_database/__init__.py +++ b/sdk/nexent/vector_database/__init__.py @@ -0,0 +1,5 @@ +"""Vector database SDK public exports.""" + +from .datamate_core import DataMateCore + +__all__ = ["DataMateCore"] diff --git a/sdk/nexent/vector_database/datamate_core.py b/sdk/nexent/vector_database/datamate_core.py new file mode 100644 index 000000000..20da8ffb3 --- /dev/null +++ b/sdk/nexent/vector_database/datamate_core.py @@ -0,0 +1,251 @@ +""" +DataMate adapter implementing the VectorDatabaseCore interface. + +Not all operations are supported by the DataMate HTTP API. Unsupported methods +raise NotImplementedError to make limitations explicit. +""" +import logging +from datetime import datetime +from typing import Any, Dict, List, Optional, Callable, Tuple + +from .base import VectorDatabaseCore +from ..datamate.datamate_client import DataMateClient +from ..core.models.embedding_model import BaseEmbedding + +logger = logging.getLogger("datamate_core") + + +def _parse_timestamp(timestamp: Any, default: int = 0) -> int: + """ + Parse timestamp from various formats to milliseconds since epoch. 
+ + Args: + timestamp: Timestamp value (int, str, or None) + default: Default value if parsing fails + + Returns: + Timestamp in milliseconds since epoch + """ + if timestamp is None: + return default + + if isinstance(timestamp, int): + # If already an int, assume it's in milliseconds (or seconds if < 1e10) + if timestamp < 1e10: + return timestamp * 1000 + return timestamp + + if isinstance(timestamp, str): + try: + # Try ISO format + dt = datetime.fromisoformat(timestamp.replace("Z", "+00:00")) + return int(dt.timestamp() * 1000) + except Exception: + try: + # Try as integer string + ts_int = int(timestamp) + if ts_int < 1e10: + return ts_int * 1000 + return ts_int + except Exception: + return default + + return default + + +class DataMateCore(VectorDatabaseCore): + """VectorDatabaseCore implementation backed by the DataMate REST API.""" + + def __init__(self, base_url: str, timeout: float = 30.0): + self.client = DataMateClient(base_url=base_url, timeout=timeout) + + # ---- INDEX MANAGEMENT ---- + def create_index(self, index_name: str, embedding_dim: Optional[int] = None) -> bool: + """DataMate API does not support index creation via SDK.""" + _ = embedding_dim + raise NotImplementedError("DataMate SDK does not support creating indices.") + + def delete_index(self, index_name: str) -> bool: + """DataMate API does not support deleting indices via SDK.""" + raise NotImplementedError("DataMate SDK does not support deleting indices.") + + def get_user_indices(self, index_pattern: str = "*") -> List[str]: + """Return DataMate knowledge base IDs as index identifiers.""" + _ = index_pattern + knowledge_bases = self.client.list_knowledge_bases() + return [str(kb.get("id")) for kb in knowledge_bases if kb.get("id") is not None] + + def check_index_exists(self, index_name: str) -> bool: + """Check existence by knowledge base id.""" + return index_name in self.get_user_indices() + + # ---- DOCUMENT OPERATIONS ---- + def vectorize_documents( + self, + index_name: str, + embedding_model: BaseEmbedding, + documents: List[Dict[str, Any]], + batch_size: int = 64, + content_field: str = "content", + embedding_batch_size: int = 10, + progress_callback: Optional[Callable[[int, int], None]] = None, + ) -> int: + _ = ( + index_name, + embedding_model, + documents, + batch_size, + content_field, + embedding_batch_size, + progress_callback, + ) + raise NotImplementedError("DataMate SDK does not support direct document ingestion.") + + def delete_documents(self, index_name: str, path_or_url: str) -> int: + _ = (index_name, path_or_url) + raise NotImplementedError("DataMate SDK does not support deleting documents.") + + def get_index_chunks( + self, + index_name: str, + page: Optional[int] = None, + page_size: Optional[int] = None, + path_or_url: Optional[str] = None, + ) -> Dict[str, Any]: + _ = (page, page_size, path_or_url) + files = self.client.get_knowledge_base_files(index_name) + return { + "chunks": files, + "total": len(files), + "page": page, + "page_size": page_size, + } + + def create_chunk(self, index_name: str, chunk: Dict[str, Any]) -> Dict[str, Any]: + _ = (index_name, chunk) + raise NotImplementedError("DataMate SDK does not support creating individual chunks.") + + def update_chunk(self, index_name: str, chunk_id: str, chunk_updates: Dict[str, Any]) -> Dict[str, Any]: + _ = (index_name, chunk_id, chunk_updates) + raise NotImplementedError("DataMate SDK does not support updating chunks.") + + def delete_chunk(self, index_name: str, chunk_id: str) -> bool: + _ = (index_name, chunk_id) + raise 
NotImplementedError("DataMate SDK does not support deleting chunks.") + + def count_documents(self, index_name: str) -> int: + files = self.client.get_knowledge_base_files(index_name) + return len(files) + + # ---- SEARCH OPERATIONS ---- + def search(self, index_name: str, query: Dict[str, Any]) -> Dict[str, Any]: + _ = (index_name, query) + raise NotImplementedError("DataMate SDK does not support raw search API.") + + def multi_search(self, body: List[Dict[str, Any]], index_name: str) -> Dict[str, Any]: + _ = (body, index_name) + raise NotImplementedError("DataMate SDK does not support multi search API.") + + def accurate_search(self, index_names: List[str], query_text: str, top_k: int = 5) -> List[Dict[str, Any]]: + _ = (index_names, query_text, top_k) + raise NotImplementedError("DataMate SDK does not support accurate search API.") + + def semantic_search( + self, index_names: List[str], query_text: str, embedding_model: BaseEmbedding, top_k: int = 5 + ) -> List[Dict[str, Any]]: + _ = (index_names, query_text, embedding_model, top_k) + raise NotImplementedError("DataMate SDK does not support semantic search API.") + + # ---- SEARCH OPERATIONS ---- + def hybrid_search( + self, + index_names: List[str], + query_text: str, + embedding_model: Optional[BaseEmbedding] = None, + top_k: int = 10, + weight_accurate: float = 0.2, + ) -> List[Dict[str, Any]]: + """ + Retrieve content in DataMate knowledge bases. + + Args: + index_names: List of knowledge base IDs to retrieve + query_text: Retrieve query text + embedding_model: Optional embedding model + top_k: Maximum number of results to return (default: 10) + weight_accurate: Similarity threshold (default: 0.2) + + Returns: + List of retrieve result dictionaries + + Raises: + RuntimeError: If the API request fails + """ + _ = embedding_model # Explicitly ignored + retrieve_knowledge = self.client.retrieve_knowledge_base(query_text, index_names, top_k, weight_accurate) + return retrieve_knowledge + + # ---- STATISTICS AND MONITORING ---- + def get_documents_detail(self, index_name: str) -> List[Dict[str, Any]]: + files_list = self.client.get_knowledge_base_files(index_name) + results = [] + for info in files_list: + file_info = { + "path_or_url": info.get("path_or_url", ""), + "file": info.get("fileName", ""), + "file_size": info.get("fileSize", ""), + "create_time": _parse_timestamp(info.get("createdAt", "")), + "chunk_count": info.get("chunkCount", ""), + "status": "COMPLETED", + "latest_task_id": "", + "error_reason": info.get("errMsg", ""), + "has_error_info": False, + "processed_chunk_num": None, + "total_chunk_num": None, + "chunks": [] + } + results.append(file_info) + return results + + def get_indices_detail(self, index_names: List[str], embedding_dim: Optional[int] = None) -> Tuple[Dict[ + str, Dict[str, Any]], List[str]]: + details: Dict[str, Dict[str, Any]] = {} + knowledge_base_names = [] + for kb_id in index_names: + try: + # Get knowledge base info and files + kb_info = self.client.get_knowledge_base_info(kb_id) + + # Extract data from knowledge base info + doc_count = kb_info.get("fileCount") # Number of unique documents (files) + knowledge_base_name = kb_info.get("name") + knowledge_base_names.append(knowledge_base_name) + chunk_count = kb_info.get("chunkCount") + store_size = kb_info.get("storeSize", "") + process_source = kb_info.get("processSource", "Unstructured") + embedding_model = kb_info.get("embedding").get("modelName") + + # Parse timestamps + creation_date = _parse_timestamp(kb_info.get("createdAt")) + update_date = 
_parse_timestamp(kb_info.get("updatedAt")) + + # Build base_info dict + base_info = { + "doc_count": doc_count, + "chunk_count": chunk_count, + "store_size": str(store_size), + "process_source": str(process_source), + "embedding_model": str(embedding_model), + "embedding_dim": embedding_dim or 1024, + "creation_date": creation_date, + "update_date": update_date, + } + + # Build performance dict (DataMate API may not provide search stats) + performance = {"total_search_count": 0, "hit_count": 0} + + details[kb_id] = {"base_info": base_info, "search_performance": performance} + except Exception as exc: + logger.error(f"Error getting stats for knowledge base {kb_id}: {str(exc)}") + details[kb_id] = {"error": str(exc)} + return details, knowledge_base_names diff --git a/test/backend/app/test_knowledge_summary_app.py b/test/backend/app/test_knowledge_summary_app.py index 7fa1ace12..8b49e079b 100644 --- a/test/backend/app/test_knowledge_summary_app.py +++ b/test/backend/app/test_knowledge_summary_app.py @@ -49,6 +49,11 @@ def __init__(self, *args, **kwargs): sys.modules['nexent.vector_database'] = vector_db_module sys.modules['nexent.vector_database.base'] = vector_db_base_module sys.modules['nexent.vector_database.elasticsearch_core'] = MagicMock() +# Provide datamate_core module with DataMateCore to satisfy imports like +# `from nexent.vector_database.datamate_core import DataMateCore` +datamate_core_module = types.ModuleType("nexent.vector_database.datamate_core") +datamate_core_module.DataMateCore = MagicMock() +sys.modules['nexent.vector_database.datamate_core'] = datamate_core_module # Mock specific classes that are imported class MockToolConfig: diff --git a/test/backend/app/test_tenant_config_app.py b/test/backend/app/test_tenant_config_app.py index 0ba7dd314..d79e71295 100644 --- a/test/backend/app/test_tenant_config_app.py +++ b/test/backend/app/test_tenant_config_app.py @@ -202,35 +202,46 @@ def test_load_knowledge_list_missing_model_name(self): def test_update_knowledge_list_success(self): """Test successful knowledge list update""" - knowledge_list = ["kb1", "kb3"] + request_data = { + "nexent": ["kb1"], + "datamate": ["kb2"] + } response = self.client.post( "/tenant_config/update_knowledge_list", headers={"authorization": "Bearer test-token"}, - json=knowledge_list + json=request_data ) self.assertEqual(response.status_code, HTTPStatus.OK) data = response.json() self.assertEqual(data["status"], "success") self.assertEqual(data["message"], "update success") + self.assertIn("content", data) + self.assertIn("selectedKbNames", data["content"]) + self.assertIn("selectedKbModels", data["content"]) + self.assertIn("selectedKbSources", data["content"]) - # Verify the mock was called with correct parameters + # Verify the mock was called with correct parameters (flattened) self.mock_update_knowledge.assert_called_once_with( tenant_id="test_tenant", user_id="test_user", - index_name_list=knowledge_list + index_name_list=["kb1", "kb2"], + knowledge_sources=["nexent", "datamate"] ) def test_update_knowledge_list_failure(self): """Test knowledge list update failure""" self.mock_update_knowledge.return_value = False - knowledge_list = ["kb1", "kb3"] + request_data = { + "nexent": ["kb1"], + "datamate": ["kb2"] + } response = self.client.post( "/tenant_config/update_knowledge_list", headers={"authorization": "Bearer test-token"}, - json=knowledge_list + json=request_data ) self.assertEqual(response.status_code, @@ -241,12 +252,15 @@ def test_update_knowledge_list_failure(self): def 
test_update_knowledge_list_auth_error(self): """Test knowledge list update with authentication error""" self.mock_get_user_id.side_effect = Exception("Authentication failed") - knowledge_list = ["kb1", "kb3"] + request_data = { + "nexent": ["kb1"], + "datamate": ["kb2"] + } response = self.client.post( "/tenant_config/update_knowledge_list", headers={"authorization": "Bearer invalid-token"}, - json=knowledge_list + json=request_data ) self.assertEqual(response.status_code, @@ -257,12 +271,15 @@ def test_update_knowledge_list_auth_error(self): def test_update_knowledge_list_service_error(self): """Test knowledge list update with service error""" self.mock_update_knowledge.side_effect = Exception("Database error") - knowledge_list = ["kb1", "kb3"] + request_data = { + "nexent": ["kb1"], + "datamate": ["kb2"] + } response = self.client.post( "/tenant_config/update_knowledge_list", headers={"authorization": "Bearer test-token"}, - json=knowledge_list + json=request_data ) self.assertEqual(response.status_code, @@ -272,12 +289,15 @@ def test_update_knowledge_list_service_error(self): def test_update_knowledge_list_empty_list(self): """Test updating with empty knowledge list""" - knowledge_list = [] + request_data = { + "nexent": [], + "datamate": [] + } response = self.client.post( "/tenant_config/update_knowledge_list", headers={"authorization": "Bearer test-token"}, - json=knowledge_list + json=request_data ) self.assertEqual(response.status_code, HTTPStatus.OK) @@ -292,17 +312,10 @@ def test_update_knowledge_list_no_body(self): headers={"authorization": "Bearer test-token"} ) - # When no body is provided, FastAPI will pass None to the knowledge_list parameter - self.assertEqual(response.status_code, HTTPStatus.OK) + # When no body is provided, Pydantic will raise validation error + self.assertEqual(response.status_code, 422) # Unprocessable Entity data = response.json() - self.assertEqual(data["status"], "success") - - # Verify the mock was called with None - self.mock_update_knowledge.assert_called_once_with( - tenant_id="test_tenant", - user_id="test_user", - index_name_list=None - ) + self.assertIn("detail", data) def test_get_deployment_version_success(self): """Test successful retrieval of deployment version""" @@ -326,11 +339,14 @@ def test_load_knowledge_list_no_auth_header(self): def test_update_knowledge_list_no_auth_header(self): """Test updating knowledge list without authorization header""" - knowledge_list = ["kb1", "kb2"] + request_data = { + "nexent": ["kb1"], + "datamate": ["kb2"] + } response = self.client.post( "/tenant_config/update_knowledge_list", - json=knowledge_list + json=request_data ) # This should still work as the authorization parameter is Optional diff --git a/test/backend/database/test_client.py b/test/backend/database/test_client.py index 91ee388ed..09136a8c4 100644 --- a/test/backend/database/test_client.py +++ b/test/backend/database/test_client.py @@ -100,7 +100,7 @@ def test_postgres_client_init(self, mock_sessionmaker, mock_create_engine): """Test PostgresClient initialization""" # Reset singleton instance PostgresClient._instance = None - + mock_engine = MagicMock() mock_create_engine.return_value = mock_engine mock_session = MagicMock() @@ -120,7 +120,7 @@ def test_postgres_client_singleton(self): """Test PostgresClient is a singleton""" # Reset singleton instance PostgresClient._instance = None - + client1 = PostgresClient() client2 = PostgresClient() @@ -166,7 +166,7 @@ def test_minio_client_init(self, mock_config_class, mock_create_client): """Test 
MinioClient initialization""" # Reset singleton instance MinioClient._instance = None - + mock_config = MagicMock() mock_config.default_bucket = 'test-bucket' mock_config_class.return_value = mock_config @@ -184,7 +184,7 @@ def test_minio_client_singleton(self): """Test MinioClient is a singleton""" # Reset singleton instance MinioClient._instance = None - + with patch('backend.database.client.create_storage_client_from_config'), \ patch('backend.database.client.MinIOStorageConfig'): client1 = MinioClient() @@ -197,7 +197,7 @@ def test_minio_client_singleton(self): def test_minio_client_upload_file(self, mock_config_class, mock_create_client): """Test MinioClient.upload_file delegates to storage client""" MinioClient._instance = None - + mock_storage_client = MagicMock() mock_storage_client.upload_file.return_value = (True, '/bucket/file.txt') mock_create_client.return_value = mock_storage_client @@ -215,7 +215,7 @@ def test_minio_client_upload_file(self, mock_config_class, mock_create_client): def test_minio_client_upload_fileobj(self, mock_config_class, mock_create_client): """Test MinioClient.upload_fileobj delegates to storage client""" MinioClient._instance = None - + from io import BytesIO mock_storage_client = MagicMock() mock_storage_client.upload_fileobj.return_value = (True, '/bucket/file.txt') @@ -235,7 +235,7 @@ def test_minio_client_upload_fileobj(self, mock_config_class, mock_create_client def test_minio_client_download_file(self, mock_config_class, mock_create_client): """Test MinioClient.download_file delegates to storage client""" MinioClient._instance = None - + mock_storage_client = MagicMock() mock_storage_client.download_file.return_value = (True, 'Downloaded successfully') mock_create_client.return_value = mock_storage_client @@ -253,7 +253,7 @@ def test_minio_client_download_file(self, mock_config_class, mock_create_client) def test_minio_client_get_file_url(self, mock_config_class, mock_create_client): """Test MinioClient.get_file_url delegates to storage client""" MinioClient._instance = None - + mock_storage_client = MagicMock() mock_storage_client.get_file_url.return_value = (True, 'http://example.com/file.txt') mock_create_client.return_value = mock_storage_client @@ -271,7 +271,7 @@ def test_minio_client_get_file_url(self, mock_config_class, mock_create_client): def test_minio_client_get_file_size(self, mock_config_class, mock_create_client): """Test MinioClient.get_file_size delegates to storage client""" MinioClient._instance = None - + mock_storage_client = MagicMock() mock_storage_client.get_file_size.return_value = 1024 mock_create_client.return_value = mock_storage_client @@ -288,7 +288,7 @@ def test_minio_client_get_file_size(self, mock_config_class, mock_create_client) def test_minio_client_list_files(self, mock_config_class, mock_create_client): """Test MinioClient.list_files delegates to storage client""" MinioClient._instance = None - + mock_storage_client = MagicMock() mock_storage_client.list_files.return_value = [ {'key': 'file1.txt', 'size': 100}, @@ -309,7 +309,7 @@ def test_minio_client_list_files(self, mock_config_class, mock_create_client): def test_minio_client_delete_file(self, mock_config_class, mock_create_client): """Test MinioClient.delete_file delegates to storage client""" MinioClient._instance = None - + mock_storage_client = MagicMock() mock_storage_client.delete_file.return_value = (True, 'Deleted successfully') mock_create_client.return_value = mock_storage_client @@ -327,7 +327,7 @@ def test_minio_client_delete_file(self, 
mock_config_class, mock_create_client): def test_minio_client_get_file_stream(self, mock_config_class, mock_create_client): """Test MinioClient.get_file_stream delegates to storage client""" MinioClient._instance = None - + from io import BytesIO mock_storage_client = MagicMock() mock_stream = BytesIO(b'test data') @@ -350,7 +350,7 @@ def test_get_db_session_with_new_session(self): """Test get_db_session creates and manages a new session""" mock_session = MagicMock() mock_session_maker = MagicMock(return_value=mock_session) - + # Mock db_client with patch('backend.database.client.db_client') as mock_db_client: mock_db_client.session_maker = mock_session_maker @@ -377,7 +377,7 @@ def test_get_db_session_rollback_on_exception(self): """Test get_db_session rolls back on exception""" mock_session = MagicMock() mock_session_maker = MagicMock(return_value=mock_session) - + with patch('backend.database.client.db_client') as mock_db_client: mock_db_client.session_maker = mock_session_maker diff --git a/test/backend/services/test_conversation_management_service.py b/test/backend/services/test_conversation_management_service.py index 2d690938a..bcd306e7b 100644 --- a/test/backend/services/test_conversation_management_service.py +++ b/test/backend/services/test_conversation_management_service.py @@ -1,5 +1,15 @@ -import sys import types +import unittest +import json +import sys +import asyncio +import os +from datetime import datetime +from unittest.mock import patch, MagicMock +import types as _types +import importlib + +from backend.consts.model import MessageRequest, AgentRequest, MessageUnit def _stub_nexent_openai_model(): # Provide a simple OpenAIModel stub for import-time safety @@ -42,6 +52,83 @@ def render(self, ctx): # # Stub consts.model to avoid pydantic/email-validator heavy imports during tests. consts_model_mod = types.ModuleType("consts.model") + +# Patch environment variables before any imports that might use them +os.environ.setdefault('MINIO_ENDPOINT', 'http://localhost:9000') +os.environ.setdefault('MINIO_ACCESS_KEY', 'minioadmin') +os.environ.setdefault('MINIO_SECRET_KEY', 'minioadmin') +os.environ.setdefault('MINIO_REGION', 'us-east-1') +os.environ.setdefault('MINIO_DEFAULT_BUCKET', 'test-bucket') + +# Mock boto3 and minio client before importing the module under test +boto3_mock = MagicMock() +sys.modules['boto3'] = boto3_mock + +# Patch storage factory and MinIO config validation to avoid errors during initialization +# These patches must be started before any imports that use MinioClient +storage_client_mock = MagicMock() +minio_client_mock = MagicMock() +# Ensure minimal `nexent.storage` stubs exist so `patch('nexent.storage...')` doesn't +# trigger importing the installed `nexent` package which may have heavy imports. 
+if 'nexent' not in sys.modules: + + _nexent_mod = _types.ModuleType('nexent') + _nexent_storage = _types.ModuleType('nexent.storage') + _storage_factory = _types.ModuleType('nexent.storage.storage_client_factory') + # provide a simple factory function that returns our storage_client_mock + _storage_factory.create_storage_client_from_config = lambda cfg: storage_client_mock + _minio_conf = _types.ModuleType('nexent.storage.minio_config') + class _MinIOStorageConfigStub: + def __init__(self, endpoint=None, access_key=None, secret_key=None, region=None, default_bucket=None, secure=None, **kwargs): + # Store constructor parameters to mimic real config object attributes + self.endpoint = endpoint + self.access_key = access_key + self.secret_key = secret_key + self.region = region + self.default_bucket = default_bucket + self.secure = secure + + def validate(self): + return None + _minio_conf.MinIOStorageConfig = _MinIOStorageConfigStub + # Also expose MinIOStorageConfig on the storage_client_factory module + _storage_factory.MinIOStorageConfig = _MinIOStorageConfigStub + # attach hierarchy and register in sys.modules + _nexent_mod.storage = _nexent_storage + _nexent_storage.storage_client_factory = _storage_factory + _nexent_storage.minio_config = _minio_conf + sys.modules['nexent'] = _nexent_mod + sys.modules['nexent.storage'] = _nexent_storage + sys.modules['nexent.storage.storage_client_factory'] = _storage_factory + sys.modules['nexent.storage.minio_config'] = _minio_conf + +# Now safe to patch (patch will import from sys.modules instead of site-packages) +patch('nexent.storage.storage_client_factory.create_storage_client_from_config', return_value=storage_client_mock).start() +patch('nexent.storage.minio_config.MinIOStorageConfig.validate', lambda self: None).start() + +importlib.import_module("backend.database.client") +patch('backend.database.client.MinioClient', return_value=minio_client_mock).start() + +with patch('backend.database.client.MinioClient', return_value=minio_client_mock): + from backend.services.conversation_management_service import ( + save_message, + save_conversation_user, + save_conversation_assistant, + extract_user_messages, + call_llm_for_title, + update_conversation_title, + create_new_conversation, + get_conversation_list_service, + rename_conversation_service, + delete_conversation_service, + get_conversation_history_service, + get_sources_service, + generate_conversation_title_service, + update_message_opinion_service, + get_message_id_by_index_impl + ) + + class AgentRequest: def __init__(self, **kwargs): for k, v in kwargs.items(): @@ -126,53 +213,6 @@ def test_call_llm_for_title_flattening(monkeypatch): title = call_llm_for_title("some conversation content", tenant_id="t", language="zh") assert title == "The Title" -from backend.consts.model import MessageRequest, AgentRequest, MessageUnit -import unittest -import json -import asyncio -import os -from datetime import datetime -from unittest.mock import patch, MagicMock - -# Patch environment variables before any imports that might use them -os.environ.setdefault('MINIO_ENDPOINT', 'http://localhost:9000') -os.environ.setdefault('MINIO_ACCESS_KEY', 'minioadmin') -os.environ.setdefault('MINIO_SECRET_KEY', 'minioadmin') -os.environ.setdefault('MINIO_REGION', 'us-east-1') -os.environ.setdefault('MINIO_DEFAULT_BUCKET', 'test-bucket') - -# Mock boto3 and minio client before importing the module under test -import sys -boto3_mock = MagicMock() -sys.modules['boto3'] = boto3_mock - -# Patch storage factory and MinIO 
config validation to avoid errors during initialization -# These patches must be started before any imports that use MinioClient -storage_client_mock = MagicMock() -minio_client_mock = MagicMock() -patch('nexent.storage.storage_client_factory.create_storage_client_from_config', return_value=storage_client_mock).start() -patch('nexent.storage.minio_config.MinIOStorageConfig.validate', lambda self: None).start() -patch('backend.database.client.MinioClient', return_value=minio_client_mock).start() - -with patch('backend.database.client.MinioClient', return_value=minio_client_mock): - from backend.services.conversation_management_service import ( - save_message, - save_conversation_user, - save_conversation_assistant, - extract_user_messages, - call_llm_for_title, - update_conversation_title, - create_new_conversation, - get_conversation_list_service, - rename_conversation_service, - delete_conversation_service, - get_conversation_history_service, - get_sources_service, - generate_conversation_title_service, - update_message_opinion_service, - get_message_id_by_index_impl - ) - class TestConversationManagementService(unittest.TestCase): def setUp(self): diff --git a/test/backend/services/test_datamate_service.py b/test/backend/services/test_datamate_service.py new file mode 100644 index 000000000..a7aa0765d --- /dev/null +++ b/test/backend/services/test_datamate_service.py @@ -0,0 +1,43 @@ +import pytest + +from backend.services import datamate_service + + +class FakeClient: + def __init__(self, base_url=None): + self.base_url = base_url + + def list_knowledge_bases(self): + return [{"id": "kb1", "name": "KB1"}] + + def get_knowledge_base_files(self, knowledge_base_id): + return [{"name": "file1", "size": 123, "knowledge_base_id": knowledge_base_id}] + + def sync_all_knowledge_bases(self): + return {"success": True, "knowledge_bases": [{"id": "kb1"}], "total_count": 1} + + + + +@pytest.mark.asyncio +async def test_fetch_datamate_knowledge_base_files_success(monkeypatch): + monkeypatch.setattr(datamate_service, "_get_datamate_client", lambda: FakeClient()) + files = await datamate_service.fetch_datamate_knowledge_base_files("kb1") + assert isinstance(files, list) + assert files[0]["knowledge_base_id"] == "kb1" + + +@pytest.mark.asyncio +async def test_fetch_datamate_knowledge_base_files_failure(monkeypatch): + class BadClient(FakeClient): + def get_knowledge_base_files(self, knowledge_base_id): + raise Exception("boom") + + monkeypatch.setattr(datamate_service, "_get_datamate_client", lambda: BadClient()) + with pytest.raises(RuntimeError) as excinfo: + await datamate_service.fetch_datamate_knowledge_base_files("kb1") + assert "Failed to fetch files for knowledge base kb1" in str(excinfo.value) + + + + diff --git a/test/backend/services/test_tenant_config_service.py b/test/backend/services/test_tenant_config_service.py index 3e6df7676..e2263ea59 100644 --- a/test/backend/services/test_tenant_config_service.py +++ b/test/backend/services/test_tenant_config_service.py @@ -14,6 +14,7 @@ update_selected_knowledge, delete_selected_knowledge_by_index_name, ) +from consts.model import UpdateKnowledgeListRequest class TestTenantConfigService(unittest.TestCase): @@ -55,48 +56,55 @@ def test_get_selected_knowledge_list_with_records( ) mock_get_knowledge_info.assert_called_once_with([self.knowledge_id]) + @patch("backend.services.tenant_config_service.get_selected_knowledge_list") @patch("backend.services.tenant_config_service.delete_config_by_tenant_config_id") 
@patch("backend.services.tenant_config_service.insert_config") @patch("backend.services.tenant_config_service.get_tenant_config_info") @patch("backend.services.tenant_config_service.get_knowledge_ids_by_index_names") def test_update_selected_knowledge_add_only( - self, mock_get_ids, mock_get_config, mock_insert, mock_delete + self, mock_get_ids, mock_get_config, mock_insert, mock_delete, mock_get_list ): mock_get_ids.return_value = self.knowledge_ids mock_get_config.return_value = [] mock_insert.return_value = True + mock_get_list.return_value = [] + request = UpdateKnowledgeListRequest(nexent=self.index_name_list) result = update_selected_knowledge( - self.tenant_id, self.user_id, self.index_name_list + self.tenant_id, self.user_id, request ) - self.assertTrue(result) + self.assertIsNotNone(result) self.assertEqual(mock_insert.call_count, 2) mock_delete.assert_not_called() + @patch("backend.services.tenant_config_service.get_selected_knowledge_list") @patch("backend.services.tenant_config_service.delete_config_by_tenant_config_id") @patch("backend.services.tenant_config_service.insert_config") @patch("backend.services.tenant_config_service.get_tenant_config_info") @patch("backend.services.tenant_config_service.get_knowledge_ids_by_index_names") def test_update_selected_knowledge_remove_only( - self, mock_get_ids, mock_get_config, mock_insert, mock_delete + self, mock_get_ids, mock_get_config, mock_insert, mock_delete, mock_get_list ): mock_get_ids.return_value = [] mock_get_config.return_value = [ {"config_value": self.knowledge_id, "tenant_config_id": self.tenant_config_id} ] mock_delete.return_value = True + mock_get_list.return_value = [] - result = update_selected_knowledge(self.tenant_id, self.user_id, []) - self.assertTrue(result) + request = UpdateKnowledgeListRequest() + result = update_selected_knowledge(self.tenant_id, self.user_id, request) + self.assertIsNotNone(result) mock_insert.assert_not_called() mock_delete.assert_called_once_with(self.tenant_config_id) + @patch("backend.services.tenant_config_service.get_selected_knowledge_list") @patch("backend.services.tenant_config_service.delete_config_by_tenant_config_id") @patch("backend.services.tenant_config_service.insert_config") @patch("backend.services.tenant_config_service.get_tenant_config_info") @patch("backend.services.tenant_config_service.get_knowledge_ids_by_index_names") def test_update_selected_knowledge_add_and_remove( - self, mock_get_ids, mock_get_config, mock_insert, mock_delete + self, mock_get_ids, mock_get_config, mock_insert, mock_delete, mock_get_list ): mock_get_ids.return_value = ["knowledge_id_2"] mock_get_config.return_value = [ @@ -104,35 +112,40 @@ def test_update_selected_knowledge_add_and_remove( ] mock_insert.return_value = True mock_delete.return_value = True + mock_get_list.return_value = [] - result = update_selected_knowledge(self.tenant_id, self.user_id, ["new_index"]) - self.assertTrue(result) + request = UpdateKnowledgeListRequest(nexent=["new_index"]) + result = update_selected_knowledge(self.tenant_id, self.user_id, request) + self.assertIsNotNone(result) mock_insert.assert_called_once() mock_delete.assert_called_once_with("tenant_config_id_1") + @patch("backend.services.tenant_config_service.get_selected_knowledge_list") @patch("backend.services.tenant_config_service.delete_config_by_tenant_config_id") @patch("backend.services.tenant_config_service.insert_config") @patch("backend.services.tenant_config_service.get_tenant_config_info") 
@patch("backend.services.tenant_config_service.get_knowledge_ids_by_index_names") def test_update_selected_knowledge_insert_failure( - self, mock_get_ids, mock_get_config, mock_insert, mock_delete + self, mock_get_ids, mock_get_config, mock_insert, mock_delete, mock_get_list ): mock_get_ids.return_value = self.knowledge_ids mock_get_config.return_value = [] mock_insert.return_value = False + request = UpdateKnowledgeListRequest(nexent=self.index_name_list) result = update_selected_knowledge( - self.tenant_id, self.user_id, self.index_name_list + self.tenant_id, self.user_id, request ) - self.assertFalse(result) + self.assertIsNone(result) mock_insert.assert_called_once() + @patch("backend.services.tenant_config_service.get_selected_knowledge_list") @patch("backend.services.tenant_config_service.delete_config_by_tenant_config_id") @patch("backend.services.tenant_config_service.insert_config") @patch("backend.services.tenant_config_service.get_tenant_config_info") @patch("backend.services.tenant_config_service.get_knowledge_ids_by_index_names") def test_update_selected_knowledge_delete_failure( - self, mock_get_ids, mock_get_config, mock_insert, mock_delete + self, mock_get_ids, mock_get_config, mock_insert, mock_delete, mock_get_list ): mock_get_ids.return_value = [] mock_get_config.return_value = [ @@ -140,8 +153,9 @@ def test_update_selected_knowledge_delete_failure( ] mock_delete.return_value = False - result = update_selected_knowledge(self.tenant_id, self.user_id, []) - self.assertFalse(result) + request = UpdateKnowledgeListRequest() + result = update_selected_knowledge(self.tenant_id, self.user_id, request) + self.assertIsNone(result) mock_delete.assert_called_once_with(self.tenant_config_id) @patch("backend.services.tenant_config_service.delete_config_by_tenant_config_id") diff --git a/test/backend/services/test_tool_configuration_service.py b/test/backend/services/test_tool_configuration_service.py index b63474d21..550bae479 100644 --- a/test/backend/services/test_tool_configuration_service.py +++ b/test/backend/services/test_tool_configuration_service.py @@ -195,6 +195,19 @@ def __init__(self, *args, **kwargs): pass +# Provide a mock DataMateCore to satisfy imports in vectordatabase_service +vector_database_datamate_module = types.ModuleType('nexent.vector_database.datamate_core') + + +class MockDataMateCore(MockVectorDatabaseCore): + def __init__(self, *args, **kwargs): + pass + +vector_database_datamate_module.DataMateCore = MockDataMateCore +sys.modules['nexent.vector_database.datamate_core'] = vector_database_datamate_module +setattr(sys.modules['nexent.vector_database'], 'datamate_core', vector_database_datamate_module) +setattr(sys.modules['nexent.vector_database'], 'DataMateCore', MockDataMateCore) + vector_database_base_module.VectorDatabaseCore = MockVectorDatabaseCore vector_database_elasticsearch_module.ElasticSearchCore = MockElasticSearchCore sys.modules['nexent.vector_database.base'] = vector_database_base_module diff --git a/test/backend/services/test_vectordatabase_service.py b/test/backend/services/test_vectordatabase_service.py index e7b9c0fcb..658d2570d 100644 --- a/test/backend/services/test_vectordatabase_service.py +++ b/test/backend/services/test_vectordatabase_service.py @@ -61,6 +61,31 @@ class _VectorDatabaseCore: vector_db_base_module.VectorDatabaseCore = _VectorDatabaseCore sys.modules['nexent.vector_database.base'] = vector_db_base_module sys.modules['nexent.vector_database.elasticsearch_core'] = MagicMock() 
+sys.modules['nexent.vector_database.datamate_core'] = MagicMock() +# Provide a lightweight models module with the IndexStatsSummary class used in the service. +vector_db_models_module = ModuleType('nexent.vector_database.models') + + +class _IndexStatsSummary: + def __init__(self, base_info=None, search_performance=None, error=None): + self.base_info = base_info + self.search_performance = search_performance + self.error = error + + def to_dict(self): + payload = {} + if self.base_info is not None: + payload["base_info"] = self.base_info + if self.search_performance is not None: + payload["search_performance"] = self.search_performance + if self.error is not None: + payload["error"] = self.error + return payload + + +vector_db_models_module.IndexStatsSummary = _IndexStatsSummary +sys.modules['nexent.vector_database.models'] = vector_db_models_module +IndexStatsSummary = _IndexStatsSummary # Mock nexent.storage module and its submodules before any imports sys.modules['nexent.storage'] = _create_package_mock('nexent.storage') storage_factory_module = MagicMock() @@ -416,8 +441,9 @@ def test_list_indices_without_stats(self, mock_get_knowledge): self.mock_vdb_core.get_user_indices.assert_called_once_with("*") mock_get_knowledge.assert_called_once_with(tenant_id="test_tenant") + @patch('backend.services.vectordatabase_service.update_model_name_by_index_name') @patch('backend.services.vectordatabase_service.get_knowledge_info_by_tenant_id') - def test_list_indices_with_stats(self, mock_get_knowledge): + def test_list_indices_with_stats(self, mock_get_knowledge, mock_update_model): """ Test listing indices with statistics included. @@ -428,9 +454,16 @@ def test_list_indices_with_stats(self, mock_get_knowledge): """ # Setup self.mock_vdb_core.get_user_indices.return_value = ["index1", "index2"] + # get_indices_detail returns Dict[str, Dict[str, Dict[str, Any]]], not IndexStatsSummary objects self.mock_vdb_core.get_indices_detail.return_value = { - "index1": {"base_info": {"doc_count": 10, "embedding_model": "test-model"}}, - "index2": {"base_info": {"doc_count": 20, "embedding_model": "test-model"}} + "index1": { + "base_info": {"doc_count": 10, "embedding_model": "test-model"}, + "search_performance": {} + }, + "index2": { + "base_info": {"doc_count": 20, "embedding_model": "test-model"}, + "search_performance": {} + }, } mock_get_knowledge.return_value = [ {"index_name": "index1", "embedding_model_name": "test-model"}, @@ -480,8 +513,9 @@ def test_list_indices_removes_stale_pg_records(self, mock_delete_knowledge, mock self.assertEqual(result["indices"], []) self.assertEqual(result["count"], 0) + @patch('backend.services.vectordatabase_service.update_model_name_by_index_name') @patch('backend.services.vectordatabase_service.get_knowledge_info_by_tenant_id') - def test_list_indices_stats_defaults_when_missing(self, mock_get_info): + def test_list_indices_stats_defaults_when_missing(self, mock_get_info, mock_update_model): """ Test list_indices include_stats path when Elasticsearch returns no stats for an index. 
""" @@ -555,8 +589,9 @@ def test_list_indices_stats_surfaces_elasticsearch_errors(self, mock_get_info): self.assertIn("503 Service Unavailable", str(context.exception)) + @patch('backend.services.vectordatabase_service.update_model_name_by_index_name') @patch('backend.services.vectordatabase_service.get_knowledge_info_by_tenant_id') - def test_list_indices_stats_keeps_non_stat_fields(self, mock_get_info): + def test_list_indices_stats_keeps_non_stat_fields(self, mock_get_info, mock_update_model): """ Test that list_indices preserves all stats fields returned by ElasticSearchCore. """ @@ -585,6 +620,7 @@ def test_list_indices_stats_keeps_non_stat_fields(self, mock_get_info): ) self.assertEqual(len(result["indices_info"]), 1) + # `detailed_stats` is already a dict; compare directly (models now return dicts). self.assertEqual(result["indices_info"][0]["stats"], detailed_stats["index1"]) def test_vectorize_documents_success(self): @@ -1531,9 +1567,11 @@ def test_summary_index_name_runtime_error_fallback(self): # Create a mock loop with run_in_executor that returns a coroutine mock_loop = MagicMock() - async def mock_run_in_executor(executor, func, *args): - # Execute the function synchronously and return its result - return func() + def mock_run_in_executor(executor, func, *args): + # run_in_executor returns a coroutine, so we need to create one + async def _execute(): + return func(*args) + return _execute() mock_loop.run_in_executor = mock_run_in_executor @@ -2442,10 +2480,10 @@ def test_get_vdb_core(self): 1. The get_vdb_core function returns the correct elastic_core instance 2. The function is properly imported and accessible """ - from backend.services.vectordatabase_service import get_vector_db_core + from backend.services.vectordatabase_service import get_vector_db_core, VectorDatabaseType - # Execute - result = get_vector_db_core() + # Execute - pass the enum value explicitly since it's a FastAPI Query parameter + result = get_vector_db_core(VectorDatabaseType.ELASTICSEARCH) # Assert self.assertIsNotNone(result) @@ -2718,6 +2756,19 @@ def test_get_vector_db_core_unsupported_type(self): self.assertIn("Unsupported vector database type", str(exc.exception)) + @patch('backend.services.vectordatabase_service.DataMateCore') + def test_get_vector_db_core_datamate(self, mock_datamate_class): + """get_vector_db_core returns DataMateCore when db_type is DATAMATE.""" + from backend.services.vectordatabase_service import get_vector_db_core, VectorDatabaseType, DATAMATE_BASE_URL + + mock_instance = MagicMock() + mock_datamate_class.return_value = mock_instance + + result = get_vector_db_core(VectorDatabaseType.DATAMATE) + + mock_datamate_class.assert_called_once_with(base_url=DATAMATE_BASE_URL) + self.assertIs(result, mock_instance) + def test_rethrow_or_plain_parses_error_code(self): """_rethrow_or_plain rethrows JSON error_code payloads unchanged.""" from backend.services.vectordatabase_service import _rethrow_or_plain diff --git a/test/sdk/core/models/test_openai_llm.py b/test/sdk/core/models/test_openai_llm.py index 6dbc6bc25..1533f5098 100644 --- a/test/sdk/core/models/test_openai_llm.py +++ b/test/sdk/core/models/test_openai_llm.py @@ -5,6 +5,58 @@ # Ensure SDK package is importable by adding sdk/ to sys.path (do not fallback to stubs) sys.path.insert(0, str(Path(__file__).resolve().parents[4] / "sdk")) +# Ensure minimal `nexent` package structure exists in sys.modules so string-based +# patch targets like "nexent.core.models.openai_llm.asyncio.to_thread" can be +# resolved by unittest.mock 
during tests that run outside the temporary patch +# contexts used below. +_sdk_root = Path(__file__).resolve().parents[4] / "sdk" / "nexent" +if "nexent" not in sys.modules: + _top_pkg = types.ModuleType("nexent") + _top_pkg.__path__ = [str(_sdk_root)] + sys.modules["nexent"] = _top_pkg +if "nexent.core" not in sys.modules: + _core_pkg = types.ModuleType("nexent.core") + _core_pkg.__path__ = [str(_sdk_root / "core")] + sys.modules["nexent.core"] = _core_pkg +if "nexent.core.models" not in sys.modules: + _models_pkg = types.ModuleType("nexent.core.models") + _models_pkg.__path__ = [str(_sdk_root / "core" / "models")] + sys.modules["nexent.core.models"] = _models_pkg + +# Ensure the package attributes exist on the top-level `nexent` module so that +# string-based patch targets (e.g. "nexent.core.models.openai_llm.asyncio.to_thread") +# resolve via getattr during unittest.mock's import lookup. +try: + top_mod = sys.modules.get("nexent") + core_mod = sys.modules.get("nexent.core") + models_mod = sys.modules.get("nexent.core.models") + if top_mod and core_mod and not hasattr(top_mod, "core"): + setattr(top_mod, "core", core_mod) + if core_mod and models_mod and not hasattr(core_mod, "models"): + setattr(core_mod, "models", models_mod) +except Exception: + # If anything goes wrong, do not fail test import phase; the test will create + # the necessary entries later within its patch context. + pass + +# Ensure the concrete openai_llm submodule is available in sys.modules so that +# string-based patch targets resolve outside of temporary patch contexts. +try: + _openai_name = "nexent.core.models.openai_llm" + _openai_path = Path(__file__).resolve().parents[4] / "sdk" / "nexent" / "core" / "models" / "openai_llm.py" + if _openai_path.exists() and _openai_name not in sys.modules: + _spec = importlib.util.spec_from_file_location(_openai_name, _openai_path) + _mod = importlib.util.module_from_spec(_spec) + sys.modules[_openai_name] = _mod + assert _spec and _spec.loader + _spec.loader.exec_module(_mod) + pkg = sys.modules.get("nexent.core.models") + if pkg is not None and not hasattr(pkg, "openai_llm"): + setattr(pkg, "openai_llm", _mod) +except Exception: + # Best-effort only; if this fails tests will still attempt to load/open the module later. + pass + # Dynamically load the openai_llm module to avoid importing full sdk package MODULE_NAME = "nexent.core.models.openai_llm" MODULE_PATH = ( @@ -275,6 +327,15 @@ class MockProcessType: sys.modules[MODULE_NAME] = openai_llm_module assert spec and spec.loader spec.loader.exec_module(openai_llm_module) + # Expose the loaded submodule as an attribute on the package object so that + # string-based patch targets like "nexent.core.models.openai_llm.asyncio.to_thread" + # resolve via getattr during unittest.mock's import lookup. 
+ try: + models_pkg = sys.modules.get("nexent.core.models") + if models_pkg is not None: + setattr(models_pkg, "openai_llm", openai_llm_module) + except Exception: + pass ImportedOpenAIModel = openai_llm_module.OpenAIModel # ----------------------------------------------------------------------- diff --git a/test/sdk/core/tools/test_analyze_text_file_tool.py b/test/sdk/core/tools/test_analyze_text_file_tool.py index 7eab52d89..c0a91e355 100644 --- a/test/sdk/core/tools/test_analyze_text_file_tool.py +++ b/test/sdk/core/tools/test_analyze_text_file_tool.py @@ -1,4 +1,3 @@ -import json from unittest.mock import MagicMock, patch import pytest diff --git a/test/sdk/core/tools/test_datamate_search_tool.py b/test/sdk/core/tools/test_datamate_search_tool.py index ebfdb3bba..83b3b1681 100644 --- a/test/sdk/core/tools/test_datamate_search_tool.py +++ b/test/sdk/core/tools/test_datamate_search_tool.py @@ -2,12 +2,12 @@ from typing import List from unittest.mock import ANY, MagicMock -import httpx import pytest from pytest_mock import MockFixture from sdk.nexent.core.tools.datamate_search_tool import DataMateSearchTool from sdk.nexent.core.utils.observer import MessageObserver, ProcessType +from sdk.nexent.datamate.datamate_client import DataMateClient @pytest.fixture @@ -18,46 +18,47 @@ def mock_observer() -> MessageObserver: @pytest.fixture -def datamate_tool(mock_observer: MessageObserver) -> DataMateSearchTool: - return DataMateSearchTool( +def mock_datamate_client(mocker: MockFixture) -> DataMateClient: + return mocker.MagicMock(spec=DataMateClient) + + +@pytest.fixture +def datamate_tool(mock_observer: MessageObserver, mock_datamate_client: DataMateClient) -> DataMateSearchTool: + tool = DataMateSearchTool( server_ip="127.0.0.1", server_port=8080, observer=mock_observer, ) - - -def _build_kb_list_response(ids: List[str]): - return { - "data": { - "content": [ - {"id": kb_id, "chunkCount": 1} - for kb_id in ids - ] - } - } - - -def _build_search_response(kb_id: str, count: int = 2): - return { - "data": [ - { - "entity": { - "id": f"file-{i}", - "text": f"content-{i}", - "createTime": "2024-01-01T00:00:00Z", - "score": 0.9 - i * 0.1, - "metadata": json.dumps( - { - "file_name": f"file-{i}.txt", - "absolute_directory_path": f"/data/{kb_id}", - } - ), - "scoreDetails": {"raw": 0.8}, - } + # DataMateSearchTool stores a DataMateCore instance which exposes a `client` attribute. + # Set the client's mock on the tool's datamate_core to reflect current implementation. 
+ tool.datamate_core.client = mock_datamate_client + return tool + + +def _build_kb_list(ids: List[str]): + return [{"id": kb_id, "chunkCount": 1} for kb_id in ids] + + +def _build_search_results(kb_id: str, count: int = 2): + return [ + { + "entity": { + "id": f"file-{i}", + "text": f"content-{i}", + "createTime": "2024-01-01T00:00:00Z", + "score": 0.9 - i * 0.1, + "metadata": json.dumps( + { + "file_name": f"file-{i}.txt", + "absolute_directory_path": f"/data/{kb_id}", + "original_file_id": f"orig-{i}", + } + ), + "scoreDetails": {"raw": 0.8}, } - for i in range(count) - ] - } + } + for i in range(count) + ] class TestDataMateSearchToolInit: @@ -74,6 +75,8 @@ def test_init_success(self, mock_observer: MessageObserver): assert tool.kb_page == 0 assert tool.kb_page_size == 20 assert tool.observer is mock_observer + # The tool exposes the DataMate client via datamate_core.client + assert isinstance(tool.datamate_core.client, DataMateClient) @pytest.mark.parametrize("server_ip", ["", None]) def test_init_invalid_server_ip(self, server_ip): @@ -114,256 +117,78 @@ def test_parse_metadata(self, datamate_tool: DataMateSearchTool, metadata_raw, e def test_extract_dataset_id(self, datamate_tool: DataMateSearchTool, path, expected): assert datamate_tool._extract_dataset_id(path) == expected - @pytest.mark.parametrize( - "dataset_id, file_id, expected", - [ - ("ds1", "f1", "http://127.0.0.1:8080/api/data-management/datasets/ds1/files/f1/download"), - ("", "f1", ""), - ("ds1", "", ""), - ], - ) - def test_build_file_download_url(self, datamate_tool: DataMateSearchTool, dataset_id, file_id, expected): - assert datamate_tool._build_file_download_url(dataset_id, file_id) == expected - - -class TestKnowledgeBaseList: - def test_get_knowledge_base_list_success(self, mocker: MockFixture, datamate_tool: DataMateSearchTool): - client_cls = mocker.patch("sdk.nexent.core.tools.datamate_search_tool.httpx.Client") - client = client_cls.return_value.__enter__.return_value - - response = MagicMock() - response.status_code = 200 - response.json.return_value = _build_kb_list_response(["kb1", "kb2"]) - client.post.return_value = response - - kb_ids = datamate_tool._get_knowledge_base_list() - - assert kb_ids == ["kb1", "kb2"] - client.post.assert_called_once_with( - f"{datamate_tool.server_base_url}/api/knowledge-base/list", - json={"page": datamate_tool.kb_page, "size": datamate_tool.kb_page_size}, - ) - - def test_get_knowledge_base_list_http_error_json_detail(self, mocker: MockFixture, datamate_tool: DataMateSearchTool): - client_cls = mocker.patch("sdk.nexent.core.tools.datamate_search_tool.httpx.Client") - client = client_cls.return_value.__enter__.return_value - - response = MagicMock() - response.status_code = 500 - response.headers = {"content-type": "application/json"} - response.json.return_value = {"detail": "server error"} - client.post.return_value = response - - with pytest.raises(Exception) as excinfo: - datamate_tool._get_knowledge_base_list() - - assert "Failed to get knowledge base list" in str(excinfo.value) - - def test_get_knowledge_base_list_http_error_text_detail(self, mocker: MockFixture, datamate_tool: DataMateSearchTool): - client_cls = mocker.patch("sdk.nexent.core.tools.datamate_search_tool.httpx.Client") - client = client_cls.return_value.__enter__.return_value - - response = MagicMock() - response.status_code = 400 - response.headers = {"content-type": "text/plain"} - response.text = "bad request" - client.post.return_value = response - - with pytest.raises(Exception) as excinfo: - 
datamate_tool._get_knowledge_base_list() - - assert "bad request" in str(excinfo.value) - - def test_get_knowledge_base_list_timeout(self, mocker: MockFixture, datamate_tool: DataMateSearchTool): - client_cls = mocker.patch("sdk.nexent.core.tools.datamate_search_tool.httpx.Client") - client = client_cls.return_value.__enter__.return_value - client.post.side_effect = httpx.TimeoutException("timeout") - - with pytest.raises(Exception) as excinfo: - datamate_tool._get_knowledge_base_list() - - assert "Timeout while getting knowledge base list" in str(excinfo.value) - - def test_get_knowledge_base_list_request_error(self, mocker: MockFixture, datamate_tool: DataMateSearchTool): - client_cls = mocker.patch("sdk.nexent.core.tools.datamate_search_tool.httpx.Client") - client = client_cls.return_value.__enter__.return_value - client.post.side_effect = httpx.RequestError("network", request=MagicMock()) - - with pytest.raises(Exception) as excinfo: - datamate_tool._get_knowledge_base_list() - - assert "Request error while getting knowledge base list" in str(excinfo.value) - - -class TestRetrieveKnowledgeBaseContent: - def test_retrieve_content_success(self, mocker: MockFixture, datamate_tool: DataMateSearchTool): - client_cls = mocker.patch("sdk.nexent.core.tools.datamate_search_tool.httpx.Client") - client = client_cls.return_value.__enter__.return_value - - response = MagicMock() - response.status_code = 200 - response.json.return_value = _build_search_response("kb1", count=2) - client.post.return_value = response - - results = datamate_tool._retrieve_knowledge_base_content( - "query", - ["kb1"], - top_k=3, - threshold=0.2, - ) - - assert len(results) == 2 - client.post.assert_called_once() - - def test_retrieve_content_http_error(self, mocker: MockFixture, datamate_tool: DataMateSearchTool): - client_cls = mocker.patch("sdk.nexent.core.tools.datamate_search_tool.httpx.Client") - client = client_cls.return_value.__enter__.return_value - - response = MagicMock() - response.status_code = 500 - response.headers = {"content-type": "application/json"} - response.json.return_value = {"detail": "server error"} - client.post.return_value = response - - with pytest.raises(Exception) as excinfo: - datamate_tool._retrieve_knowledge_base_content( - "query", - ["kb1"], - top_k=3, - threshold=0.2, - ) - - assert "Failed to retrieve knowledge base content" in str(excinfo.value) - - def test_retrieve_content_timeout(self, mocker: MockFixture, datamate_tool: DataMateSearchTool): - client_cls = mocker.patch("sdk.nexent.core.tools.datamate_search_tool.httpx.Client") - client = client_cls.return_value.__enter__.return_value - client.post.side_effect = httpx.TimeoutException("timeout") - - with pytest.raises(Exception) as excinfo: - datamate_tool._retrieve_knowledge_base_content( - "query", - ["kb1"], - top_k=3, - threshold=0.2, - ) - - assert "Timeout while retrieving knowledge base content" in str(excinfo.value) - - def test_retrieve_content_request_error(self, mocker: MockFixture, datamate_tool: DataMateSearchTool): - client_cls = mocker.patch("sdk.nexent.core.tools.datamate_search_tool.httpx.Client") - client = client_cls.return_value.__enter__.return_value - client.post.side_effect = httpx.RequestError("network", request=MagicMock()) - - with pytest.raises(Exception) as excinfo: - datamate_tool._retrieve_knowledge_base_content( - "query", - ["kb1"], - top_k=3, - threshold=0.2, - ) - - assert "Request error while retrieving knowledge base content" in str(excinfo.value) - class TestForward: - def 
_setup_success_flow(self, mocker: MockFixture, tool: DataMateSearchTool): - # Mock knowledge base list - client_cls = mocker.patch("sdk.nexent.core.tools.datamate_search_tool.httpx.Client") - client = client_cls.return_value.__enter__.return_value - - kb_response = MagicMock() - kb_response.status_code = 200 - kb_response.json.return_value = _build_kb_list_response(["kb1"]) - - search_response = MagicMock() - search_response.status_code = 200 - search_response.json.return_value = _build_search_response("kb1", count=2) - - # First call for list, second for retrieve - client.post.side_effect = [kb_response, search_response] - return client - - def test_forward_success_with_observer_en(self, mocker: MockFixture, datamate_tool: DataMateSearchTool): - client = self._setup_success_flow(mocker, datamate_tool) + def test_forward_success_with_observer_en(self, datamate_tool: DataMateSearchTool, mock_datamate_client: DataMateClient): + mock_datamate_client.list_knowledge_bases.return_value = _build_kb_list(["kb1"]) + mock_datamate_client.retrieve_knowledge_base.return_value = _build_search_results("kb1", count=2) + mock_datamate_client.build_file_download_url.side_effect = lambda ds, fid: f"http://dl/{ds}/{fid}" result_json = datamate_tool.forward("test query", top_k=2, threshold=0.5) results = json.loads(result_json) assert len(results) == 2 - # Check that observer received running prompt and card - datamate_tool.observer.add_message.assert_any_call( - "", ProcessType.TOOL, datamate_tool.running_prompt_en - ) + datamate_tool.observer.add_message.assert_any_call("", ProcessType.TOOL, datamate_tool.running_prompt_en) datamate_tool.observer.add_message.assert_any_call( "", ProcessType.CARD, json.dumps([{"icon": "search", "text": "test query"}], ensure_ascii=False) ) - # Check that search content message is added (payload content is not strictly validated here) - datamate_tool.observer.add_message.assert_any_call( - "", ProcessType.SEARCH_CONTENT, ANY - ) + datamate_tool.observer.add_message.assert_any_call("", ProcessType.SEARCH_CONTENT, ANY) assert datamate_tool.record_ops == 1 + len(results) - assert all(isinstance(item["index"], str) for item in results) - - # Ensure both list and retrieve endpoints were called - assert client.post.call_count == 2 - def test_forward_success_with_observer_zh(self, mocker: MockFixture, datamate_tool: DataMateSearchTool): + mock_datamate_client.list_knowledge_bases.assert_called_once_with(page=0, size=20) + # Support both positional and keyword invocation styles from DataMate client wrapper. 
+ mock_datamate_client.retrieve_knowledge_base.assert_called_once() + _args, _kwargs = mock_datamate_client.retrieve_knowledge_base.call_args + if _args: + assert _args[0] == "test query" + assert _args[1] == ["kb1"] + assert _args[2] == 2 + assert _args[3] == 0.5 + else: + assert _kwargs["query"] == "test query" + assert _kwargs["knowledge_base_ids"] == ["kb1"] + assert _kwargs["top_k"] == 2 + assert _kwargs["threshold"] == 0.5 + mock_datamate_client.build_file_download_url.assert_any_call("kb1", "orig-0") + + def test_forward_success_with_observer_zh(self, datamate_tool: DataMateSearchTool, mock_datamate_client: DataMateClient): datamate_tool.observer.lang = "zh" - self._setup_success_flow(mocker, datamate_tool) + mock_datamate_client.list_knowledge_bases.return_value = _build_kb_list(["kb1"]) + mock_datamate_client.retrieve_knowledge_base.return_value = _build_search_results("kb1", count=1) + mock_datamate_client.build_file_download_url.return_value = "http://dl/kb1/file-1" datamate_tool.forward("测试查询") - datamate_tool.observer.add_message.assert_any_call( - "", ProcessType.TOOL, datamate_tool.running_prompt_zh - ) + datamate_tool.observer.add_message.assert_any_call("", ProcessType.TOOL, datamate_tool.running_prompt_zh) - def test_forward_no_observer(self, mocker: MockFixture): + def test_forward_no_observer(self, mock_datamate_client: DataMateClient): tool = DataMateSearchTool(server_ip="127.0.0.1", server_port=8080, observer=None) - self._setup_success_flow(mocker, tool) + tool.datamate_core.client = mock_datamate_client + mock_datamate_client.list_knowledge_bases.return_value = _build_kb_list(["kb1"]) + mock_datamate_client.retrieve_knowledge_base.return_value = _build_search_results("kb1", count=1) + mock_datamate_client.build_file_download_url.return_value = "http://dl/kb1/file-1" - # Should not raise and should not call observer result_json = tool.forward("query") - assert len(json.loads(result_json)) == 2 - - def test_forward_no_knowledge_bases(self, mocker: MockFixture, datamate_tool: DataMateSearchTool): - client_cls = mocker.patch("sdk.nexent.core.tools.datamate_search_tool.httpx.Client") - client = client_cls.return_value.__enter__.return_value + assert len(json.loads(result_json)) == 1 - kb_response = MagicMock() - kb_response.status_code = 200 - kb_response.json.return_value = _build_kb_list_response([]) - client.post.return_value = kb_response + def test_forward_no_knowledge_bases(self, datamate_tool: DataMateSearchTool, mock_datamate_client: DataMateClient): + mock_datamate_client.list_knowledge_bases.return_value = [] result = datamate_tool.forward("query") assert result == json.dumps("No knowledge base found. 
No relevant information found.", ensure_ascii=False) + mock_datamate_client.retrieve_knowledge_base.assert_not_called() - def test_forward_no_results(self, mocker: MockFixture, datamate_tool: DataMateSearchTool): - client_cls = mocker.patch("sdk.nexent.core.tools.datamate_search_tool.httpx.Client") - client = client_cls.return_value.__enter__.return_value - - kb_response = MagicMock() - kb_response.status_code = 200 - kb_response.json.return_value = _build_kb_list_response(["kb1"]) - - search_response = MagicMock() - search_response.status_code = 200 - search_response.json.return_value = {"data": []} - - client.post.side_effect = [kb_response, search_response] + def test_forward_no_results(self, datamate_tool: DataMateSearchTool, mock_datamate_client: DataMateClient): + mock_datamate_client.list_knowledge_bases.return_value = _build_kb_list(["kb1"]) + mock_datamate_client.retrieve_knowledge_base.return_value = [] with pytest.raises(Exception) as excinfo: datamate_tool.forward("query") assert "No results found!" in str(excinfo.value) - def test_forward_wrapped_error(self, mocker: MockFixture, datamate_tool: DataMateSearchTool): - # Simulate error in underlying method to verify top-level error wrapping - mocker.patch.object( - datamate_tool, - "_get_knowledge_base_list", - side_effect=Exception("low level error"), - ) + def test_forward_wrapped_error(self, datamate_tool: DataMateSearchTool, mock_datamate_client: DataMateClient): + mock_datamate_client.list_knowledge_bases.side_effect = RuntimeError("low level error") with pytest.raises(Exception) as excinfo: datamate_tool.forward("query") @@ -372,4 +197,45 @@ def test_forward_wrapped_error(self, mocker: MockFixture, datamate_tool: DataMat assert "Error during DataMate knowledge base search" in msg assert "low level error" in msg + def test_forward_with_index_name_provided(self, datamate_tool: DataMateSearchTool, mock_datamate_client: DataMateClient): + # Mock the hybrid_search method on datamate_core + mock_hybrid_search = MagicMock(return_value=_build_search_results("custom_kb", count=2)) + datamate_tool.datamate_core.hybrid_search = mock_hybrid_search + mock_datamate_client.build_file_download_url.side_effect = lambda ds, fid: f"http://dl/{ds}/{fid}" + + result_json = datamate_tool.forward("test query", index_name=["custom_kb1", "custom_kb2"]) + results = json.loads(result_json) + + assert len(results) == 4 # 2 results per kb + # Should not call list_knowledge_bases when index_name is provided + mock_datamate_client.list_knowledge_bases.assert_not_called() + # Should call hybrid_search twice, once for each index + assert mock_hybrid_search.call_count == 2 + mock_hybrid_search.assert_any_call( + query_text="test query", + index_names=["custom_kb1"], + top_k=10, + weight_accurate=0.2 + ) + mock_hybrid_search.assert_any_call( + query_text="test query", + index_names=["custom_kb2"], + top_k=10, + weight_accurate=0.2 + ) + + def test_forward_with_empty_index_name_list(self, datamate_tool: DataMateSearchTool, mock_datamate_client: DataMateClient): + mock_datamate_client.list_knowledge_bases.return_value = _build_kb_list(["kb1"]) + # Mock the hybrid_search method on datamate_core + mock_hybrid_search = MagicMock(return_value=_build_search_results("kb1", count=1)) + datamate_tool.datamate_core.hybrid_search = mock_hybrid_search + mock_datamate_client.build_file_download_url.return_value = "http://dl/kb1/file-1" + + result_json = datamate_tool.forward("test query", index_name=[]) + results = json.loads(result_json) + assert len(results) == 1 + # 
Should not call list_knowledge_bases when empty index_name is provided + mock_datamate_client.list_knowledge_bases.assert_not_called() + # Should not call hybrid_search since index_name list is empty + mock_hybrid_search.assert_not_called() diff --git a/test/sdk/datamate/test_datamate_client.py b/test/sdk/datamate/test_datamate_client.py new file mode 100644 index 000000000..78972bf7e --- /dev/null +++ b/test/sdk/datamate/test_datamate_client.py @@ -0,0 +1,615 @@ +import pytest +from unittest.mock import MagicMock + +import httpx +from pytest_mock import MockFixture + +from sdk.nexent.datamate.datamate_client import DataMateClient + + +@pytest.fixture +def client() -> DataMateClient: + return DataMateClient(base_url="http://datamate.local:30000", timeout=1.0) + + +def _mock_response(mocker: MockFixture, status: int, json_data=None, text: str = ""): + response = MagicMock() + response.status_code = status + response.headers = {"content-type": "application/json"} if json_data is not None else {"content-type": "text/plain"} + response.json.return_value = json_data + response.text = text + return response + + +class TestListKnowledgeBases: + def test_success(self, mocker: MockFixture, client: DataMateClient): + client_cls = mocker.patch("sdk.nexent.datamate.datamate_client.httpx.Client") + http_client = client_cls.return_value.__enter__.return_value + http_client.post.return_value = _mock_response( + mocker, + 200, + {"data": {"content": [{"id": "kb1"}, {"id": "kb2"}]}}, + ) + + kbs = client.list_knowledge_bases(page=1, size=10, authorization="token") + + assert len(kbs) == 2 + http_client.post.assert_called_once_with( + "http://datamate.local:30000/api/knowledge-base/list", + json={"page": 1, "size": 10}, + headers={"Authorization": "token"}, + ) + + def test_non_200_json_error(self, mocker: MockFixture, client: DataMateClient): + client_cls = mocker.patch("sdk.nexent.datamate.datamate_client.httpx.Client") + http_client = client_cls.return_value.__enter__.return_value + http_client.post.return_value = _mock_response( + mocker, + 500, + {"detail": "boom"}, + ) + + with pytest.raises(RuntimeError) as excinfo: + client.list_knowledge_bases() + assert "Failed to fetch DataMate knowledge bases" in str(excinfo.value) + + def test_http_error(self, mocker: MockFixture, client: DataMateClient): + client_cls = mocker.patch("sdk.nexent.datamate.datamate_client.httpx.Client") + http_client = client_cls.return_value.__enter__.return_value + http_client.post.side_effect = httpx.HTTPError("network") + + with pytest.raises(RuntimeError): + client.list_knowledge_bases() + + +class TestGetKnowledgeBaseFiles: + def test_success(self, mocker: MockFixture, client: DataMateClient): + client_cls = mocker.patch("sdk.nexent.datamate.datamate_client.httpx.Client") + http_client = client_cls.return_value.__enter__.return_value + http_client.get.return_value = _mock_response( + mocker, + 200, + {"data": {"content": [{"id": "f1"}, {"id": "f2"}]}}, + ) + + files = client.get_knowledge_base_files("kb1") + + assert len(files) == 2 + http_client.get.assert_called_once_with( + "http://datamate.local:30000/api/knowledge-base/kb1/files", + headers={}, + ) + + def test_non_200(self, mocker: MockFixture, client: DataMateClient): + client_cls = mocker.patch("sdk.nexent.datamate.datamate_client.httpx.Client") + http_client = client_cls.return_value.__enter__.return_value + http_client.get.return_value = _mock_response( + mocker, + 404, + {"detail": "not found"}, + ) + + with pytest.raises(RuntimeError): + 
client.get_knowledge_base_files("kb1") + + def test_http_error(self, mocker: MockFixture, client: DataMateClient): + client_cls = mocker.patch("sdk.nexent.datamate.datamate_client.httpx.Client") + http_client = client_cls.return_value.__enter__.return_value + http_client.get.side_effect = httpx.HTTPError("network") + + with pytest.raises(RuntimeError): + client.get_knowledge_base_files("kb1") + + +class TestRetrieveKnowledgeBase: + def test_success(self, mocker: MockFixture, client: DataMateClient): + client_cls = mocker.patch("sdk.nexent.datamate.datamate_client.httpx.Client") + http_client = client_cls.return_value.__enter__.return_value + http_client.post.return_value = _mock_response( + mocker, + 200, + {"data": [{"entity": {"id": "1"}}, {"entity": {"id": "2"}}]}, + ) + + results = client.retrieve_knowledge_base("q", ["kb1"], top_k=5, threshold=0.1, authorization="auth") + + assert len(results) == 2 + http_client.post.assert_called_once_with( + "http://datamate.local:30000/api/knowledge-base/retrieve", + json={ + "query": "q", + "topK": 5, + "threshold": 0.1, + "knowledgeBaseIds": ["kb1"], + }, + headers={"Authorization": "auth"}, + ) + + def test_non_200(self, mocker: MockFixture, client: DataMateClient): + client_cls = mocker.patch("sdk.nexent.datamate.datamate_client.httpx.Client") + http_client = client_cls.return_value.__enter__.return_value + http_client.post.return_value = _mock_response( + mocker, + 500, + {"detail": "error"}, + ) + + with pytest.raises(RuntimeError): + client.retrieve_knowledge_base("q", ["kb1"]) + + def test_http_error(self, mocker: MockFixture, client: DataMateClient): + client_cls = mocker.patch("sdk.nexent.datamate.datamate_client.httpx.Client") + http_client = client_cls.return_value.__enter__.return_value + http_client.post.side_effect = httpx.HTTPError("network") + + with pytest.raises(RuntimeError): + client.retrieve_knowledge_base("q", ["kb1"]) + + +class TestBuildFileDownloadUrl: + def test_build_url(self, client: DataMateClient): + assert client.build_file_download_url("ds1", "f1") == \ + "http://datamate.local:30000/api/data-management/datasets/ds1/files/f1/download" + + def test_missing_parts(self, client: DataMateClient): + assert client.build_file_download_url("", "f1") == "" + assert client.build_file_download_url("ds1", "") == "" + + +class TestSyncAllKnowledgeBases: + def test_success_and_partial_error(self, mocker: MockFixture, client: DataMateClient): + mocker.patch.object(client, "list_knowledge_bases", return_value=[{"id": "kb1"}, {"id": "kb2"}]) + mocker.patch.object(client, "get_knowledge_base_files", side_effect=[["f1"], RuntimeError("oops")]) + + result = client.sync_all_knowledge_bases() + + assert result["success"] is True + assert result["total_count"] == 2 + assert result["knowledge_bases"][0]["files"] == ["f1"] + assert result["knowledge_bases"][1]["files"] == [] + assert "oops" in result["knowledge_bases"][1]["error"] + + def test_sync_failure(self, mocker: MockFixture, client: DataMateClient): + mocker.patch.object(client, "list_knowledge_bases", side_effect=RuntimeError("boom")) + + result = client.sync_all_knowledge_bases() + + assert result["success"] is False + assert result["total_count"] == 0 + assert "boom" in result["error"] + + +class TestGetKnowledgeBaseInfo: + def test_success(self, mocker: MockFixture, client: DataMateClient): + client_cls = mocker.patch("sdk.nexent.datamate.datamate_client.httpx.Client") + http_client = client_cls.return_value.__enter__.return_value + http_client.get.return_value = _mock_response( 
+ mocker, + 200, + {"data": {"id": "kb1", "name": "KB1"}}, + ) + + kb = client.get_knowledge_base_info("kb1") + + assert isinstance(kb, dict) + assert kb["id"] == "kb1" + http_client.get.assert_called_once_with( + "http://datamate.local:30000/api/knowledge-base/kb1", + headers={}, + ) + + def test_success_with_authorization(self, mocker: MockFixture, client: DataMateClient): + client_cls = mocker.patch("sdk.nexent.datamate.datamate_client.httpx.Client") + http_client = client_cls.return_value.__enter__.return_value + http_client.get.return_value = _mock_response( + mocker, + 200, + {"data": {"id": "kb1", "name": "KB1"}}, + ) + + kb = client.get_knowledge_base_info("kb1", authorization="Bearer token123") + + assert isinstance(kb, dict) + assert kb["id"] == "kb1" + http_client.get.assert_called_once_with( + "http://datamate.local:30000/api/knowledge-base/kb1", + headers={"Authorization": "Bearer token123"}, + ) + + def test_empty_data(self, mocker: MockFixture, client: DataMateClient): + client_cls = mocker.patch("sdk.nexent.datamate.datamate_client.httpx.Client") + http_client = client_cls.return_value.__enter__.return_value + http_client.get.return_value = _mock_response( + mocker, + 200, + {"data": {}}, + ) + + kb = client.get_knowledge_base_info("kb1") + assert kb == {} + + def test_non_200_json_error(self, mocker: MockFixture, client: DataMateClient): + client_cls = mocker.patch("sdk.nexent.datamate.datamate_client.httpx.Client") + http_client = client_cls.return_value.__enter__.return_value + http_client.get.return_value = _mock_response( + mocker, + 500, + {"detail": "boom"}, + text="", + ) + + with pytest.raises(RuntimeError) as excinfo: + client.get_knowledge_base_info("kb1") + + assert "Failed to fetch details for datamate knowledge base kb1" in str(excinfo.value) + assert "Failed to get knowledge base details" in str(excinfo.value) + + def test_non_200_text_error(self, mocker: MockFixture, client: DataMateClient): + client_cls = mocker.patch("sdk.nexent.datamate.datamate_client.httpx.Client") + http_client = client_cls.return_value.__enter__.return_value + # simulate plain text error response + resp = _mock_response(mocker, 404, None, text="not found") + # override headers to be text/plain + resp.headers = {"content-type": "text/plain"} + http_client.get.return_value = resp + + with pytest.raises(RuntimeError) as excinfo: + client.get_knowledge_base_info("kb1") + + assert "Failed to fetch details for datamate knowledge base kb1" in str(excinfo.value) + assert "not found" in str(excinfo.value) + + def test_http_error_raised(self, mocker: MockFixture, client: DataMateClient): + client_cls = mocker.patch("sdk.nexent.datamate.datamate_client.httpx.Client") + http_client = client_cls.return_value.__enter__.return_value + http_client.get.side_effect = httpx.HTTPError("network") + + with pytest.raises(RuntimeError) as excinfo: + client.get_knowledge_base_info("kb1") + + assert "Failed to fetch details for datamate knowledge base kb1" in str(excinfo.value) + assert "network" in str(excinfo.value) + + +class TestBuildHeaders: + """Test the internal _build_headers method.""" + + def test_with_authorization(self, client: DataMateClient): + headers = client._build_headers("Bearer token123") + assert headers == {"Authorization": "Bearer token123"} + + def test_without_authorization(self, client: DataMateClient): + headers = client._build_headers() + assert headers == {} + + def test_with_none_authorization(self, client: DataMateClient): + headers = client._build_headers(None) + assert headers 
== {} + + +class TestBuildUrl: + """Test the internal _build_url method.""" + + def test_path_with_leading_slash(self, client: DataMateClient): + url = client._build_url("/api/test") + assert url == "http://datamate.local:30000/api/test" + + def test_path_without_leading_slash(self, client: DataMateClient): + url = client._build_url("api/test") + assert url == "http://datamate.local:30000/api/test" + + def test_base_url_without_trailing_slash(self, client: DataMateClient): + # base_url is already stripped of trailing slash in __init__ + url = client._build_url("/api/test") + assert url == "http://datamate.local:30000/api/test" + + +class TestMakeRequest: + """Test the internal _make_request method.""" + + def test_get_request_success(self, mocker: MockFixture, client: DataMateClient): + client_cls = mocker.patch("sdk.nexent.datamate.datamate_client.httpx.Client") + http_client = client_cls.return_value.__enter__.return_value + http_client.get.return_value = _mock_response(mocker, 200, {"result": "ok"}) + + response = client._make_request("GET", "http://test.com/api", {"X-Header": "value"}) + + assert response.status_code == 200 + http_client.get.assert_called_once_with("http://test.com/api", headers={"X-Header": "value"}) + + def test_post_request_success(self, mocker: MockFixture, client: DataMateClient): + client_cls = mocker.patch("sdk.nexent.datamate.datamate_client.httpx.Client") + http_client = client_cls.return_value.__enter__.return_value + http_client.post.return_value = _mock_response(mocker, 200, {"result": "ok"}) + + response = client._make_request( + "POST", "http://test.com/api", {"X-Header": "value"}, json={"key": "value"} + ) + + assert response.status_code == 200 + http_client.post.assert_called_once_with( + "http://test.com/api", json={"key": "value"}, headers={"X-Header": "value"} + ) + + def test_custom_timeout(self, mocker: MockFixture, client: DataMateClient): + client_cls = mocker.patch("sdk.nexent.datamate.datamate_client.httpx.Client") + http_client = client_cls.return_value.__enter__.return_value + http_client.get.return_value = _mock_response(mocker, 200, {"result": "ok"}) + + client._make_request("GET", "http://test.com/api", {}, timeout=5.0) + + # Verify timeout was passed to Client + client_cls.assert_called_once() + call_kwargs = client_cls.call_args[1] + assert call_kwargs["timeout"] == 5.0 + + def test_default_timeout(self, mocker: MockFixture, client: DataMateClient): + client_cls = mocker.patch("sdk.nexent.datamate.datamate_client.httpx.Client") + http_client = client_cls.return_value.__enter__.return_value + http_client.get.return_value = _mock_response(mocker, 200, {"result": "ok"}) + + client._make_request("GET", "http://test.com/api", {}) + + # Verify default timeout (1.0) was used + client_cls.assert_called_once() + call_kwargs = client_cls.call_args[1] + assert call_kwargs["timeout"] == 1.0 + + def test_non_200_status_code(self, mocker: MockFixture, client: DataMateClient): + client_cls = mocker.patch("sdk.nexent.datamate.datamate_client.httpx.Client") + http_client = client_cls.return_value.__enter__.return_value + http_client.get.return_value = _mock_response(mocker, 404, {"detail": "not found"}) + + with pytest.raises(Exception) as excinfo: + client._make_request("GET", "http://test.com/api", {}, error_message="Custom error") + + assert "Custom error" in str(excinfo.value) + assert "404" in str(excinfo.value) + + def test_unsupported_method(self, client: DataMateClient): + with pytest.raises(ValueError) as excinfo: + client._make_request("PUT", 
"http://test.com/api", {}) + + assert "Unsupported HTTP method: PUT" in str(excinfo.value) + + +class TestHandleErrorResponse: + """Test the internal _handle_error_response method.""" + + def test_json_error_response(self, client: DataMateClient): + response = MagicMock() + response.status_code = 500 + response.headers = {"content-type": "application/json"} + response.json.return_value = {"detail": "Internal server error"} + + with pytest.raises(Exception) as excinfo: + client._handle_error_response(response, "Test error") + + assert "Test error" in str(excinfo.value) + assert "500" in str(excinfo.value) + assert "Internal server error" in str(excinfo.value) + + def test_text_error_response(self, client: DataMateClient): + response = MagicMock() + response.status_code = 404 + response.headers = {"content-type": "text/plain"} + response.text = "Resource not found" + + with pytest.raises(Exception) as excinfo: + client._handle_error_response(response, "Test error") + + assert "Test error" in str(excinfo.value) + assert "404" in str(excinfo.value) + assert "Resource not found" in str(excinfo.value) + + def test_json_error_without_detail(self, client: DataMateClient): + response = MagicMock() + response.status_code = 500 + response.headers = {"content-type": "application/json"} + response.json.return_value = {} + + with pytest.raises(Exception) as excinfo: + client._handle_error_response(response, "Test error") + + assert "Test error" in str(excinfo.value) + assert "unknown error" in str(excinfo.value) + + +class TestListKnowledgeBasesEdgeCases: + """Test edge cases for list_knowledge_bases.""" + + def test_empty_list(self, mocker: MockFixture, client: DataMateClient): + client_cls = mocker.patch("sdk.nexent.datamate.datamate_client.httpx.Client") + http_client = client_cls.return_value.__enter__.return_value + http_client.post.return_value = _mock_response(mocker, 200, {"data": {"content": []}}) + + kbs = client.list_knowledge_bases() + assert kbs == [] + + def test_no_data_field(self, mocker: MockFixture, client: DataMateClient): + client_cls = mocker.patch("sdk.nexent.datamate.datamate_client.httpx.Client") + http_client = client_cls.return_value.__enter__.return_value + http_client.post.return_value = _mock_response(mocker, 200, {}) + + kbs = client.list_knowledge_bases() + assert kbs == [] + + def test_default_parameters(self, mocker: MockFixture, client: DataMateClient): + client_cls = mocker.patch("sdk.nexent.datamate.datamate_client.httpx.Client") + http_client = client_cls.return_value.__enter__.return_value + http_client.post.return_value = _mock_response( + mocker, 200, {"data": {"content": [{"id": "kb1"}]}} + ) + + client.list_knowledge_bases() + + http_client.post.assert_called_once_with( + "http://datamate.local:30000/api/knowledge-base/list", + json={"page": 0, "size": 20}, + headers={}, + ) + + +class TestGetKnowledgeBaseFilesEdgeCases: + """Test edge cases for get_knowledge_base_files.""" + + def test_empty_file_list(self, mocker: MockFixture, client: DataMateClient): + client_cls = mocker.patch("sdk.nexent.datamate.datamate_client.httpx.Client") + http_client = client_cls.return_value.__enter__.return_value + http_client.get.return_value = _mock_response(mocker, 200, {"data": {"content": []}}) + + files = client.get_knowledge_base_files("kb1") + assert files == [] + + def test_no_data_field(self, mocker: MockFixture, client: DataMateClient): + client_cls = mocker.patch("sdk.nexent.datamate.datamate_client.httpx.Client") + http_client = 
client_cls.return_value.__enter__.return_value + http_client.get.return_value = _mock_response(mocker, 200, {}) + + files = client.get_knowledge_base_files("kb1") + assert files == [] + + def test_with_authorization(self, mocker: MockFixture, client: DataMateClient): + client_cls = mocker.patch("sdk.nexent.datamate.datamate_client.httpx.Client") + http_client = client_cls.return_value.__enter__.return_value + http_client.get.return_value = _mock_response( + mocker, 200, {"data": {"content": [{"id": "f1"}]}} + ) + + client.get_knowledge_base_files("kb1", authorization="Bearer token") + + http_client.get.assert_called_once_with( + "http://datamate.local:30000/api/knowledge-base/kb1/files", + headers={"Authorization": "Bearer token"}, + ) + + +class TestRetrieveKnowledgeBaseEdgeCases: + """Test edge cases for retrieve_knowledge_base.""" + + def test_empty_results(self, mocker: MockFixture, client: DataMateClient): + client_cls = mocker.patch("sdk.nexent.datamate.datamate_client.httpx.Client") + http_client = client_cls.return_value.__enter__.return_value + http_client.post.return_value = _mock_response(mocker, 200, {"data": []}) + + results = client.retrieve_knowledge_base("query", ["kb1"]) + assert results == [] + + def test_no_data_field(self, mocker: MockFixture, client: DataMateClient): + client_cls = mocker.patch("sdk.nexent.datamate.datamate_client.httpx.Client") + http_client = client_cls.return_value.__enter__.return_value + http_client.post.return_value = _mock_response(mocker, 200, {}) + + results = client.retrieve_knowledge_base("query", ["kb1"]) + assert results == [] + + def test_default_parameters(self, mocker: MockFixture, client: DataMateClient): + client_cls = mocker.patch("sdk.nexent.datamate.datamate_client.httpx.Client") + http_client = client_cls.return_value.__enter__.return_value + http_client.post.return_value = _mock_response(mocker, 200, {"data": []}) + + client.retrieve_knowledge_base("query", ["kb1"]) + + http_client.post.assert_called_once_with( + "http://datamate.local:30000/api/knowledge-base/retrieve", + json={ + "query": "query", + "topK": 10, + "threshold": 0.2, + "knowledgeBaseIds": ["kb1"], + }, + headers={}, + ) + + def test_custom_timeout(self, mocker: MockFixture, client: DataMateClient): + client_cls = mocker.patch("sdk.nexent.datamate.datamate_client.httpx.Client") + http_client = client_cls.return_value.__enter__.return_value + http_client.post.return_value = _mock_response(mocker, 200, {"data": []}) + + client.retrieve_knowledge_base("query", ["kb1"]) + + # Verify timeout is doubled for retrieve (1.0 * 2 = 2.0) + client_cls.assert_called_once() + call_kwargs = client_cls.call_args[1] + assert call_kwargs["timeout"] == 2.0 + + def test_multiple_knowledge_base_ids(self, mocker: MockFixture, client: DataMateClient): + client_cls = mocker.patch("sdk.nexent.datamate.datamate_client.httpx.Client") + http_client = client_cls.return_value.__enter__.return_value + http_client.post.return_value = _mock_response(mocker, 200, {"data": []}) + + client.retrieve_knowledge_base("query", ["kb1", "kb2", "kb3"], top_k=5, threshold=0.3) + + http_client.post.assert_called_once_with( + "http://datamate.local:30000/api/knowledge-base/retrieve", + json={ + "query": "query", + "topK": 5, + "threshold": 0.3, + "knowledgeBaseIds": ["kb1", "kb2", "kb3"], + }, + headers={}, + ) + + +class TestSyncAllKnowledgeBasesEdgeCases: + """Test edge cases for sync_all_knowledge_bases.""" + + def test_empty_knowledge_bases_list(self, mocker: MockFixture, client: DataMateClient): + 
mocker.patch.object(client, "list_knowledge_bases", return_value=[]) + + result = client.sync_all_knowledge_bases() + + assert result["success"] is True + assert result["total_count"] == 0 + assert result["knowledge_bases"] == [] + + def test_all_success(self, mocker: MockFixture, client: DataMateClient): + mocker.patch.object( + client, "list_knowledge_bases", return_value=[{"id": "kb1"}, {"id": "kb2"}] + ) + mocker.patch.object( + client, "get_knowledge_base_files", side_effect=[[{"id": "f1"}], [{"id": "f2"}]] + ) + + result = client.sync_all_knowledge_bases() + + assert result["success"] is True + assert result["total_count"] == 2 + assert len(result["knowledge_bases"][0]["files"]) == 1 + assert len(result["knowledge_bases"][1]["files"]) == 1 + assert "error" not in result["knowledge_bases"][0] + assert "error" not in result["knowledge_bases"][1] + + def test_with_authorization(self, mocker: MockFixture, client: DataMateClient): + list_mock = mocker.patch.object( + client, "list_knowledge_bases", return_value=[{"id": "kb1"}] + ) + files_mock = mocker.patch.object( + client, "get_knowledge_base_files", return_value=[{"id": "f1"}] + ) + + client.sync_all_knowledge_bases(authorization="Bearer token") + + list_mock.assert_called_once_with(authorization="Bearer token") + files_mock.assert_called_once_with("kb1", authorization="Bearer token") + + +class TestClientInitialization: + """Test DataMateClient initialization.""" + + def test_default_timeout(self): + client = DataMateClient(base_url="http://test.com") + assert client.timeout == 30.0 + + def test_custom_timeout(self): + client = DataMateClient(base_url="http://test.com", timeout=5.0) + assert client.timeout == 5.0 + + def test_base_url_stripping(self): + client = DataMateClient(base_url="http://test.com/", timeout=1.0) + assert client.base_url == "http://test.com" + # Verify _build_url works correctly + assert client._build_url("/api/test") == "http://test.com/api/test" + + diff --git a/test/sdk/vector_database/__init__.py b/test/sdk/vector_database/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/test/sdk/vector_database/test_datamate_core.py b/test/sdk/vector_database/test_datamate_core.py new file mode 100644 index 000000000..70c79dc73 --- /dev/null +++ b/test/sdk/vector_database/test_datamate_core.py @@ -0,0 +1,157 @@ +import pytest +from unittest.mock import MagicMock, patch +from datetime import datetime + +from sdk.nexent.vector_database import datamate_core + + +def test_parse_timestamp_variants(): + # None -> default + assert datamate_core._parse_timestamp(None, default=7) == 7 + + # Integer already in milliseconds + ms = 1600000000000 + assert datamate_core._parse_timestamp(ms) == ms + + # Integer in seconds (less than 1e10) should be converted to ms + seconds = 1600000000 + assert datamate_core._parse_timestamp(seconds) == seconds * 1000 + + # ISO8601 string with Z + iso = "2020-09-13T12:00:00Z" + expected = int(datetime.fromisoformat(iso.replace("Z", "+00:00")).timestamp() * 1000) + assert datamate_core._parse_timestamp(iso) == expected + + # Numeric string representing seconds + assert datamate_core._parse_timestamp("123456") == 123456 * 1000 + + # Invalid string -> default + assert datamate_core._parse_timestamp("not-a-ts", default=11) == 11 + + +@patch("sdk.nexent.vector_database.datamate_core.DataMateClient") +def test_user_indices_and_count(mock_client_cls): + mock_client = MagicMock() + mock_client.list_knowledge_bases.return_value = [{"id": 1}, {"no_id": True}, {"id": "2"}] + 
mock_client.get_knowledge_base_files.return_value = [{"fileName": "a"}, {"fileName": "b"}] + mock_client_cls.return_value = mock_client + + core = datamate_core.DataMateCore(base_url="http://example") + + # get_user_indices filters out entries without id and returns string ids + assert core.get_user_indices() == ["1", "2"] + + # check_index_exists uses get_user_indices + assert core.check_index_exists("1") is True + assert core.check_index_exists("missing") is False + + # get_index_chunks and count_documents rely on get_knowledge_base_files + chunks = core.get_index_chunks("1") + assert isinstance(chunks, dict) + assert chunks["total"] == 2 + assert core.count_documents("1") == 2 + + +@patch("sdk.nexent.vector_database.datamate_core.DataMateClient") +def test_hybrid_search_and_retrieve(mock_client_cls): + mock_client = MagicMock() + mock_client.retrieve_knowledge_base.return_value = [{"id": "res1"}] + mock_client_cls.return_value = mock_client + + core = datamate_core.DataMateCore(base_url="http://example") + res = core.hybrid_search(["kb1"], "query", embedding_model=None, top_k=2, weight_accurate=0.1) + assert res == [{"id": "res1"}] + mock_client.retrieve_knowledge_base.assert_called_once_with("query", ["kb1"], 2, 0.1) + + +@patch("sdk.nexent.vector_database.datamate_core.DataMateClient") +def test_get_documents_detail_parsing(mock_client_cls): + mock_client = MagicMock() + mock_client.get_knowledge_base_files.return_value = [ + { + "path_or_url": "s3://bucket/file.txt", + "fileName": "file.txt", + "fileSize": 12345, + "createdAt": "2021-01-01T00:00:00Z", + "chunkCount": 3, + "errMsg": "no error", + } + ] + mock_client_cls.return_value = mock_client + + core = datamate_core.DataMateCore(base_url="http://example") + details = core.get_documents_detail("kb1") + assert isinstance(details, list) and len(details) == 1 + d = details[0] + assert d["file"] == "file.txt" + assert d["file_size"] == 12345 + assert d["chunk_count"] == 3 + assert isinstance(d["create_time"], int) and d["create_time"] > 0 + assert d["error_reason"] == "no error" + + +@patch("sdk.nexent.vector_database.datamate_core.DataMateClient") +def test_get_indices_detail_success_and_error(mock_client_cls): + mock_client = MagicMock() + + def side_effect_get_info(kb_id): + if kb_id == "bad": + raise RuntimeError("boom") + return { + "fileCount": 10, + "name": "KnowledgeBaseName", + "chunkCount": 20, + "storeSize": 999, + "processSource": "Unstructured", + "embedding": {"modelName": "embed-v1"}, + "createdAt": "2022-01-01T00:00:00Z", + "updatedAt": "2022-02-01T00:00:00Z", + } + + mock_client.get_knowledge_base_info.side_effect = side_effect_get_info + mock_client_cls.return_value = mock_client + + core = datamate_core.DataMateCore(base_url="http://example") + details, names = core.get_indices_detail(["good", "bad"], embedding_dim=512) + + # success case + assert "good" in details + assert details["good"]["base_info"]["embedding_model"] == "embed-v1" + assert details["good"]["base_info"]["embedding_dim"] == 512 + assert "KnowledgeBaseName" in names + + # error case + assert "bad" in details + assert "error" in details["bad"] + + +@patch("sdk.nexent.vector_database.datamate_core.DataMateClient") +def test_not_implemented_methods_raise(mock_client_cls): + mock_client_cls.return_value = MagicMock() + core = datamate_core.DataMateCore(base_url="http://example") + + # Methods that are intentionally not implemented should raise NotImplementedError + with pytest.raises(NotImplementedError): + core.create_index("i") + with 
pytest.raises(NotImplementedError): + core.delete_index("i") + with pytest.raises(NotImplementedError): + core.vectorize_documents("i", None, []) + with pytest.raises(NotImplementedError): + core.delete_documents("i", "path") + with pytest.raises(NotImplementedError): + core.create_chunk("i", {}) + with pytest.raises(NotImplementedError): + core.update_chunk("i", "cid", {}) + with pytest.raises(NotImplementedError): + core.delete_chunk("i", "cid") + with pytest.raises(NotImplementedError): + core.search("i", {}) + with pytest.raises(NotImplementedError): + core.multi_search([], "i") + with pytest.raises(NotImplementedError): + core.accurate_search(["i"], "q") + with pytest.raises(NotImplementedError): + core.semantic_search(["i"], "q", None) + + diff --git a/test/sdk/vector_database/test_elasticsearch_core.py b/test/sdk/vector_database/test_elasticsearch_core.py index f9f878852..40b29853a 100644 --- a/test/sdk/vector_database/test_elasticsearch_core.py +++ b/test/sdk/vector_database/test_elasticsearch_core.py @@ -7,7 +7,6 @@ # Import the class under test from sdk.nexent.vector_database.elasticsearch_core import ElasticSearchCore - # ---------------------------------------------------------------------------- # Fixtures # ---------------------------------------------------------------------------- @@ -56,12 +55,12 @@ def test_preprocess_documents_with_complete_document(elasticsearch_core_instance # Use the second document which has all fields complete_doc = [sample_documents[1]] content_field = "content" - + result = elasticsearch_core_instance._preprocess_documents(complete_doc, content_field) - + assert len(result) == 1 doc = result[0] - + # Should preserve existing values assert doc["content"] == "This is test content 2" assert doc["title"] == "Test Document 2" @@ -79,33 +78,33 @@ def test_preprocess_documents_with_incomplete_document(elasticsearch_core_instan # Use the first document which is missing several fields incomplete_doc = [sample_documents[0]] content_field = "content" - + with patch('time.strftime') as mock_strftime, \ patch('time.time') as mock_time, \ patch('time.gmtime') as mock_gmtime: - + # Mock time functions mock_strftime.side_effect = lambda fmt, t: "2025-01-15T10:30:00" if "T" in fmt else "2025-01-15" mock_time.return_value = 1642234567 mock_gmtime.return_value = None - + result = elasticsearch_core_instance._preprocess_documents(incomplete_doc, content_field) - + assert len(result) == 1 doc = result[0] - + # Should preserve existing values assert doc["content"] == "This is test content 1" assert doc["title"] == "Test Document 1" assert doc["filename"] == "test1.pdf" assert doc["path_or_url"] == "/path/to/test1.pdf" - + # Should add missing fields with default values assert doc["create_time"] == "2025-01-15T10:30:00" assert doc["date"] == "2025-01-15" assert doc["file_size"] == 0 assert doc["process_source"] == "Unstructured" - + # Should generate an ID assert "id" in doc assert doc["id"].startswith("1642234567_") @@ -115,20 +114,20 @@ def test_preprocess_documents_with_incomplete_document(elasticsearch_core_instan def test_preprocess_documents_with_multiple_documents(elasticsearch_core_instance, sample_documents): """Test preprocessing multiple documents.""" content_field = "content" - + with patch('time.strftime') as mock_strftime, \ patch('time.time') as mock_time, \ patch('time.gmtime') as mock_gmtime: - + # Mock time functions mock_strftime.side_effect = lambda fmt, t: "2025-01-15T10:30:00" if "T" in fmt else "2025-01-15" mock_time.return_value = 1642234567 
mock_gmtime.return_value = None - + result = elasticsearch_core_instance._preprocess_documents(sample_documents, content_field) - + assert len(result) == 2 - + # First document should have defaults added doc1 = result[0] assert doc1["create_time"] == "2025-01-15T10:30:00" @@ -136,7 +135,7 @@ def test_preprocess_documents_with_multiple_documents(elasticsearch_core_instanc assert doc1["file_size"] == 0 assert doc1["process_source"] == "Unstructured" assert "id" in doc1 - + # Second document should preserve existing values doc2 = result[1] assert doc2["create_time"] == "2025-01-15T10:30:00" @@ -155,20 +154,20 @@ def test_preprocess_documents_preserves_original_data(elasticsearch_core_instanc } ] content_field = "content" - + with patch('time.strftime') as mock_strftime, \ patch('time.time') as mock_time, \ patch('time.gmtime') as mock_gmtime: - + mock_strftime.side_effect = lambda fmt, t: "2025-01-15T10:30:00" if "T" in fmt else "2025-01-15" mock_time.return_value = 1642234567 mock_gmtime.return_value = None - + result = elasticsearch_core_instance._preprocess_documents(original_docs, content_field) - + # Original document should remain unchanged assert original_docs[0] == {"content": "Original content", "title": "Original title"} - + # Result should be a new document with added fields assert result[0]["content"] == "Original content" assert result[0]["title"] == "Original title" @@ -182,9 +181,9 @@ def test_preprocess_documents_preserves_original_data(elasticsearch_core_instanc def test_preprocess_documents_with_empty_list(elasticsearch_core_instance): """Test preprocessing an empty list of documents.""" content_field = "content" - + result = elasticsearch_core_instance._preprocess_documents([], content_field) - + assert result == [] @@ -196,27 +195,27 @@ def test_preprocess_documents_id_generation(elasticsearch_core_instance): {"content": "Content 1"} # Same content as first ] content_field = "content" - + with patch('time.strftime') as mock_strftime, \ patch('time.time') as mock_time, \ patch('time.gmtime') as mock_gmtime: - + mock_strftime.side_effect = lambda fmt, t: "2025-01-15T10:30:00" if "T" in fmt else "2025-01-15" mock_time.return_value = 1642234567 mock_gmtime.return_value = None - + result = elasticsearch_core_instance._preprocess_documents(docs, content_field) - + assert len(result) == 3 - + # All documents should have IDs assert "id" in result[0] assert "id" in result[1] assert "id" in result[2] - + # IDs should be different for different content assert result[0]["id"] != result[1]["id"] - + # Same content should generate same hash part (but might be different due to time) id1_parts = result[0]["id"].split("_") id3_parts = result[2]["id"].split("_") @@ -237,19 +236,19 @@ def test_preprocess_documents_with_none_values(elasticsearch_core_instance): } ] content_field = "content" - + with patch('time.strftime') as mock_strftime, \ patch('time.time') as mock_time, \ patch('time.gmtime') as mock_gmtime: - + mock_strftime.side_effect = lambda fmt, t: "2025-01-15T10:30:00" if "T" in fmt else "2025-01-15" mock_time.return_value = 1642234567 mock_gmtime.return_value = None - + result = elasticsearch_core_instance._preprocess_documents(docs, content_field) - + doc = result[0] - + # None values should be replaced with defaults assert doc["file_size"] == 0 assert doc["create_time"] == "2025-01-15T10:30:00" @@ -270,19 +269,19 @@ def test_preprocess_documents_with_zero_values(elasticsearch_core_instance): } ] content_field = "content" - + with patch('time.strftime') as mock_strftime, \ 
patch('time.time') as mock_time, \ patch('time.gmtime') as mock_gmtime: - + mock_strftime.side_effect = lambda fmt, t: "2025-01-15T10:30:00" if "T" in fmt else "2025-01-15" mock_time.return_value = 1642234567 mock_gmtime.return_value = None - + result = elasticsearch_core_instance._preprocess_documents(docs, content_field) - + doc = result[0] - + # Zero values should be preserved assert doc["file_size"] == 0 assert doc["create_time"] == "2025-01-15T10:30:00" @@ -760,12 +759,12 @@ def test_create_chunk_exception(elasticsearch_core_instance): """Test create_chunk raises exception when client.index fails.""" elasticsearch_core_instance.client = MagicMock() elasticsearch_core_instance.client.index.side_effect = Exception("Index operation failed") - + payload = {"id": "chunk-1", "content": "A"} - + with pytest.raises(Exception) as exc_info: elasticsearch_core_instance.create_chunk("kb-index", payload) - + assert "Index operation failed" in str(exc_info.value) elasticsearch_core_instance.client.index.assert_called_once() @@ -779,10 +778,10 @@ def test_update_chunk_exception_from_resolve(elasticsearch_core_instance): side_effect=Exception("Resolve failed"), ): updates = {"content": "updated"} - + with pytest.raises(Exception) as exc_info: elasticsearch_core_instance.update_chunk("kb-index", "chunk-1", updates) - + assert "Resolve failed" in str(exc_info.value) elasticsearch_core_instance.client.update.assert_not_called() @@ -796,12 +795,12 @@ def test_update_chunk_exception_from_update(elasticsearch_core_instance): return_value="es-id-1", ): elasticsearch_core_instance.client.update.side_effect = Exception("Update operation failed") - + updates = {"content": "updated"} - + with pytest.raises(Exception) as exc_info: elasticsearch_core_instance.update_chunk("kb-index", "chunk-1", updates) - + assert "Update operation failed" in str(exc_info.value) elasticsearch_core_instance.client.update.assert_called_once() @@ -816,7 +815,7 @@ def test_delete_chunk_exception_from_resolve(elasticsearch_core_instance): ): with pytest.raises(Exception) as exc_info: elasticsearch_core_instance.delete_chunk("kb-index", "chunk-1") - + assert "Resolve failed" in str(exc_info.value) elasticsearch_core_instance.client.delete.assert_not_called() @@ -830,10 +829,10 @@ def test_delete_chunk_exception_from_delete(elasticsearch_core_instance): return_value="es-id-1", ): elasticsearch_core_instance.client.delete.side_effect = Exception("Delete operation failed") - + with pytest.raises(Exception) as exc_info: elasticsearch_core_instance.delete_chunk("kb-index", "chunk-1") - + assert "Delete operation failed" in str(exc_info.value) elasticsearch_core_instance.client.delete.assert_called_once() diff --git a/test/sdk/vector_database/test_elasticsearch_core_coverage.py b/test/sdk/vector_database/test_elasticsearch_core_coverage.py deleted file mode 100644 index 757bbc566..000000000 --- a/test/sdk/vector_database/test_elasticsearch_core_coverage.py +++ /dev/null @@ -1,731 +0,0 @@ -""" -Supplementary test module for elasticsearch_core to improve code coverage - -Tests for functions not fully covered in the main test file. 
-""" -import pytest -from unittest.mock import MagicMock, patch, mock_open -import time -import os -import sys -from typing import List, Dict, Any -from datetime import datetime, timedelta - -# Add the project root to the path -current_dir = os.path.dirname(os.path.abspath(__file__)) -project_root = os.path.abspath(os.path.join(current_dir, "../../..")) -sys.path.insert(0, project_root) - -# Import the class under test -from sdk.nexent.vector_database.elasticsearch_core import ElasticSearchCore, BulkOperation -from elasticsearch import exceptions - - -class TestElasticSearchCoreCoverage: - """Test class for improving elasticsearch_core coverage""" - - @pytest.fixture - def vdb_core(self): - """Create an ElasticSearchCore instance for testing.""" - return ElasticSearchCore( - host="http://localhost:9200", - api_key="test_api_key", - verify_certs=False, - ssl_show_warn=False - ) - - def test_force_refresh_with_retry_success(self, vdb_core): - """Test _force_refresh_with_retry successful refresh""" - vdb_core.client = MagicMock() - vdb_core.client.indices.refresh.return_value = {"_shards": {"total": 1, "successful": 1}} - - result = vdb_core._force_refresh_with_retry("test_index") - assert result is True - vdb_core.client.indices.refresh.assert_called_once_with(index="test_index") - - def test_force_refresh_with_retry_failure_retry(self, vdb_core): - """Test _force_refresh_with_retry with retries""" - vdb_core.client = MagicMock() - vdb_core.client.indices.refresh.side_effect = [ - Exception("Connection error"), - Exception("Still failing"), - {"_shards": {"total": 1, "successful": 1}} - ] - - with patch('time.sleep'): # Mock sleep to speed up test - result = vdb_core._force_refresh_with_retry("test_index", max_retries=3) - assert result is True - assert vdb_core.client.indices.refresh.call_count == 3 - - def test_force_refresh_with_retry_max_retries_exceeded(self, vdb_core): - """Test _force_refresh_with_retry when max retries exceeded""" - vdb_core.client = MagicMock() - vdb_core.client.indices.refresh.side_effect = Exception("Persistent error") - - with patch('time.sleep'): # Mock sleep to speed up test - result = vdb_core._force_refresh_with_retry("test_index", max_retries=2) - assert result is False - assert vdb_core.client.indices.refresh.call_count == 2 - - def test_ensure_index_ready_success(self, vdb_core): - """Test _ensure_index_ready successful case""" - vdb_core.client = MagicMock() - vdb_core.client.cluster.health.return_value = {"status": "green"} - vdb_core.client.search.return_value = {"hits": {"total": {"value": 0}}} - - result = vdb_core._ensure_index_ready("test_index") - assert result is True - - def test_ensure_index_ready_yellow_status(self, vdb_core): - """Test _ensure_index_ready with yellow status""" - vdb_core.client = MagicMock() - vdb_core.client.cluster.health.return_value = {"status": "yellow"} - vdb_core.client.search.return_value = {"hits": {"total": {"value": 0}}} - - result = vdb_core._ensure_index_ready("test_index") - assert result is True - - def test_ensure_index_ready_timeout(self, vdb_core): - """Test _ensure_index_ready timeout scenario""" - vdb_core.client = MagicMock() - vdb_core.client.cluster.health.return_value = {"status": "red"} - - with patch('time.sleep'): # Mock sleep to speed up test - result = vdb_core._ensure_index_ready("test_index", timeout=1) - assert result is False - - def test_ensure_index_ready_exception(self, vdb_core): - """Test _ensure_index_ready with exception""" - vdb_core.client = MagicMock() - 
vdb_core.client.cluster.health.side_effect = Exception("Connection error") - - with patch('time.sleep'): # Mock sleep to speed up test - result = vdb_core._ensure_index_ready("test_index", timeout=1) - assert result is False - - def test_apply_bulk_settings_success(self, vdb_core): - """Test _apply_bulk_settings successful case""" - vdb_core.client = MagicMock() - vdb_core.client.indices.put_settings.return_value = {"acknowledged": True} - - vdb_core._apply_bulk_settings("test_index") - vdb_core.client.indices.put_settings.assert_called_once() - - def test_apply_bulk_settings_failure(self, vdb_core): - """Test _apply_bulk_settings with exception""" - vdb_core.client = MagicMock() - vdb_core.client.indices.put_settings.side_effect = Exception("Settings error") - - # Should not raise exception, just log warning - vdb_core._apply_bulk_settings("test_index") - vdb_core.client.indices.put_settings.assert_called_once() - - def test_restore_normal_settings_success(self, vdb_core): - """Test _restore_normal_settings successful case""" - vdb_core.client = MagicMock() - vdb_core.client.indices.put_settings.return_value = {"acknowledged": True} - vdb_core._force_refresh_with_retry = MagicMock(return_value=True) - - vdb_core._restore_normal_settings("test_index") - vdb_core.client.indices.put_settings.assert_called_once() - vdb_core._force_refresh_with_retry.assert_called_once_with("test_index") - - def test_restore_normal_settings_failure(self, vdb_core): - """Test _restore_normal_settings with exception""" - vdb_core.client = MagicMock() - vdb_core.client.indices.put_settings.side_effect = Exception("Settings error") - - # Should not raise exception, just log warning - vdb_core._restore_normal_settings("test_index") - vdb_core.client.indices.put_settings.assert_called_once() - - def test_delete_index_success(self, vdb_core): - """Test delete_index successful case""" - vdb_core.client = MagicMock() - vdb_core.client.indices.delete.return_value = {"acknowledged": True} - - result = vdb_core.delete_index("test_index") - assert result is True - vdb_core.client.indices.delete.assert_called_once_with(index="test_index") - - def test_delete_index_not_found(self, vdb_core): - """Test delete_index when index not found""" - vdb_core.client = MagicMock() - # Create a proper NotFoundError with required parameters - not_found_error = exceptions.NotFoundError(404, "Index not found", {"error": {"type": "index_not_found_exception"}}) - vdb_core.client.indices.delete.side_effect = not_found_error - - result = vdb_core.delete_index("test_index") - assert result is False - vdb_core.client.indices.delete.assert_called_once_with(index="test_index") - - def test_delete_index_general_exception(self, vdb_core): - """Test delete_index with general exception""" - vdb_core.client = MagicMock() - vdb_core.client.indices.delete.side_effect = Exception("General error") - - result = vdb_core.delete_index("test_index") - assert result is False - vdb_core.client.indices.delete.assert_called_once_with(index="test_index") - - def test_handle_bulk_errors_no_errors(self, vdb_core): - """Test _handle_bulk_errors when no errors in response""" - response = {"errors": False, "items": []} - vdb_core._handle_bulk_errors(response) - # Should not raise any exceptions - - def test_handle_bulk_errors_with_version_conflict(self, vdb_core): - """Test _handle_bulk_errors with version conflict (should be ignored)""" - response = { - "errors": True, - "items": [ - { - "index": { - "error": { - "type": "version_conflict_engine_exception", - "reason": 
"Document already exists", - "caused_by": { - "type": "version_conflict", - "reason": "Document version conflict" - } - } - } - } - ] - } - vdb_core._handle_bulk_errors(response) - # Should not raise any exceptions for version conflicts - - def test_handle_bulk_errors_with_fatal_error(self, vdb_core): - """Test _handle_bulk_errors with fatal error""" - response = { - "errors": True, - "items": [ - { - "index": { - "error": { - "type": "mapper_parsing_exception", - "reason": "Failed to parse field", - "caused_by": { - "type": "json_parse_exception", - "reason": "Unexpected character" - } - } - } - } - ] - } - with pytest.raises(Exception) as exc_info: - vdb_core._handle_bulk_errors(response) - assert "Bulk indexing failed" in str(exc_info.value) - - def test_handle_bulk_errors_with_caused_by(self, vdb_core): - """Test _handle_bulk_errors with caused_by information""" - response = { - "errors": True, - "items": [ - { - "index": { - "error": { - "type": "illegal_argument_exception", - "reason": "Invalid argument", - "caused_by": { - "type": "json_parse_exception", - "reason": "JSON parsing failed" - } - } - } - } - ] - } - with pytest.raises(Exception) as exc_info: - vdb_core._handle_bulk_errors(response) - assert "Invalid argument" in str(exc_info.value) - assert "JSON parsing failed" in str(exc_info.value) - - def test_delete_documents_success(self, vdb_core): - """Test delete_documents successful case""" - vdb_core.client = MagicMock() - vdb_core.client.delete_by_query.return_value = {"deleted": 5} - - result = vdb_core.delete_documents("test_index", "/path/to/file.pdf") - assert result == 5 - vdb_core.client.delete_by_query.assert_called_once() - - def test_delete_documents_exception(self, vdb_core): - """Test delete_documents with exception""" - vdb_core.client = MagicMock() - vdb_core.client.delete_by_query.side_effect = Exception("Delete error") - - result = vdb_core.delete_documents("test_index", "/path/to/file.pdf") - assert result == 0 - vdb_core.client.delete_by_query.assert_called_once() - - def test_get_index_chunks_not_found(self, vdb_core): - """Ensure get_index_chunks handles missing index gracefully.""" - vdb_core.client = MagicMock() - vdb_core.client.count.side_effect = exceptions.NotFoundError( - 404, "missing", {}) - - result = vdb_core.get_index_chunks("missing-index") - - assert result == {"chunks": [], "total": 0, - "page": None, "page_size": None} - vdb_core.client.clear_scroll.assert_not_called() - - def test_get_index_chunks_cleanup_warning(self, vdb_core): - """Ensure clear_scroll errors are swallowed.""" - vdb_core.client = MagicMock() - vdb_core.client.count.return_value = {"count": 1} - vdb_core.client.search.return_value = { - "_scroll_id": "scroll123", - "hits": {"hits": [{"_id": "doc-1", "_source": {"content": "A"}}]} - } - vdb_core.client.scroll.return_value = { - "_scroll_id": "scroll123", - "hits": {"hits": []} - } - vdb_core.client.clear_scroll.side_effect = Exception("cleanup-failed") - - result = vdb_core.get_index_chunks("kb-index") - - assert len(result["chunks"]) == 1 - assert result["chunks"][0]["id"] == "doc-1" - vdb_core.client.clear_scroll.assert_called_once_with( - scroll_id="scroll123") - - def test_create_index_request_error_existing(self, vdb_core): - """Ensure RequestError with resource already exists still succeeds.""" - vdb_core.client = MagicMock() - vdb_core.client.indices.exists.return_value = False - meta = MagicMock(status=400) - vdb_core.client.indices.create.side_effect = exceptions.RequestError( - 
"resource_already_exists_exception", meta, {"error": {"reason": "exists"}} - ) - vdb_core._ensure_index_ready = MagicMock(return_value=True) - - assert vdb_core.create_index("test_index") is True - vdb_core._ensure_index_ready.assert_called_once_with("test_index") - - def test_create_index_request_error_failure(self, vdb_core): - """Ensure create_index returns False for non recoverable RequestError.""" - vdb_core.client = MagicMock() - vdb_core.client.indices.exists.return_value = False - meta = MagicMock(status=400) - vdb_core.client.indices.create.side_effect = exceptions.RequestError( - "validation_exception", meta, {"error": {"reason": "bad"}} - ) - - assert vdb_core.create_index("test_index") is False - - def test_create_index_general_exception(self, vdb_core): - """Ensure unexpected exception from create_index returns False.""" - vdb_core.client = MagicMock() - vdb_core.client.indices.exists.return_value = False - vdb_core.client.indices.create.side_effect = Exception("boom") - - assert vdb_core.create_index("test_index") is False - - def test_force_refresh_with_retry_zero_attempts(self, vdb_core): - """Ensure guard clause without attempts returns False.""" - vdb_core.client = MagicMock() - result = vdb_core._force_refresh_with_retry("idx", max_retries=0) - assert result is False - - def test_bulk_operation_context_preexisting_operation(self, vdb_core): - """Ensure context skips apply/restore when operations remain.""" - existing = BulkOperation( - index_name="test_index", - operation_id="existing", - start_time=datetime.utcnow(), - expected_duration=timedelta(seconds=30), - ) - vdb_core._bulk_operations = {"test_index": [existing]} - - with patch.object(vdb_core, "_apply_bulk_settings") as mock_apply, \ - patch.object(vdb_core, "_restore_normal_settings") as mock_restore: - - with vdb_core.bulk_operation_context("test_index") as op_id: - assert op_id != existing.operation_id - - mock_apply.assert_not_called() - mock_restore.assert_not_called() - assert vdb_core._bulk_operations["test_index"] == [existing] - - def test_get_user_indices_exception(self, vdb_core): - """Ensure get_user_indices returns empty list on failure.""" - vdb_core.client = MagicMock() - vdb_core.client.indices.get_alias.side_effect = Exception("failure") - - assert vdb_core.get_user_indices() == [] - - def test_check_index_exists(self, vdb_core): - """Ensure check_index_exists delegates to client.""" - vdb_core.client = MagicMock() - vdb_core.client.indices.exists.return_value = True - - assert vdb_core.check_index_exists("idx") is True - vdb_core.client.indices.exists.assert_called_once_with(index="idx") - - def test_small_batch_insert_sets_embedding_model_name(self, vdb_core): - """_small_batch_insert should attach embedding model name.""" - vdb_core.client = MagicMock() - vdb_core.client.bulk.return_value = {"errors": False, "items": []} - vdb_core._preprocess_documents = MagicMock(return_value=[{"content": "body"}]) - vdb_core._handle_bulk_errors = MagicMock() - - mock_embedding_model = MagicMock() - mock_embedding_model.get_embeddings.return_value = [[0.1, 0.2]] - mock_embedding_model.embedding_model_name = "demo-model" - - vdb_core._small_batch_insert("idx", [{"content": "body"}], "content", mock_embedding_model) - operations = vdb_core.client.bulk.call_args.kwargs["operations"] - inserted_doc = operations[1] - assert inserted_doc["embedding_model_name"] == "demo-model" - - def test_large_batch_insert_sets_default_embedding_model_name(self, vdb_core): - """_large_batch_insert should fall back to 'unknown' 
when attr missing.""" - vdb_core.client = MagicMock() - vdb_core.client.bulk.return_value = {"errors": False, "items": []} - vdb_core._preprocess_documents = MagicMock(return_value=[{"content": "body"}]) - vdb_core._handle_bulk_errors = MagicMock() - - class SimpleEmbedding: - def get_embeddings(self, texts): - return [[0.1 for _ in texts]] - - embedding_model = SimpleEmbedding() - - vdb_core._large_batch_insert("idx", [{"content": "body"}], 10, "content", embedding_model) - operations = vdb_core.client.bulk.call_args.kwargs["operations"] - inserted_doc = operations[1] - assert inserted_doc["embedding_model_name"] == "unknown" - - def test_large_batch_insert_bulk_exception(self, vdb_core): - """Ensure bulk exceptions are handled and indexing continues.""" - vdb_core.client = MagicMock() - vdb_core.client.bulk.side_effect = Exception("bulk error") - vdb_core._preprocess_documents = MagicMock(return_value=[{"content": "body"}]) - - mock_embedding_model = MagicMock() - mock_embedding_model.get_embeddings.return_value = [[0.1]] - - with pytest.raises(Exception) as exc_info: - vdb_core._large_batch_insert("idx", [{"content": "body"}], 1, "content", mock_embedding_model) - assert "bulk error" in str(exc_info.value) - - def test_large_batch_insert_preprocess_exception(self, vdb_core): - """Ensure outer exception handler returns zero on preprocess failure.""" - vdb_core._preprocess_documents = MagicMock(side_effect=Exception("fail")) - - mock_embedding_model = MagicMock() - with pytest.raises(Exception) as exc_info: - vdb_core._large_batch_insert("idx", [{"content": "body"}], 10, "content", mock_embedding_model) - assert "fail" in str(exc_info.value) - - def test_count_documents_success(self, vdb_core): - """Ensure count_documents returns ES count.""" - vdb_core.client = MagicMock() - vdb_core.client.count.return_value = {"count": 42} - - assert vdb_core.count_documents("idx") == 42 - - def test_count_documents_exception(self, vdb_core): - """Ensure count_documents returns zero on error.""" - vdb_core.client = MagicMock() - vdb_core.client.count.side_effect = Exception("fail") - - assert vdb_core.count_documents("idx") == 0 - - def test_search_and_multi_search_passthrough(self, vdb_core): - """Ensure search helpers delegate to the client.""" - vdb_core.client = MagicMock() - vdb_core.client.search.return_value = {"hits": {}} - vdb_core.client.msearch.return_value = {"responses": []} - - assert vdb_core.search("idx", {"query": {"match_all": {}}}) == {"hits": {}} - assert vdb_core.multi_search([{"query": {"match_all": {}}}], "idx") == {"responses": []} - - def test_exec_query_formats_results(self, vdb_core): - """Ensure exec_query strips metadata and exposes scores.""" - vdb_core.client = MagicMock() - vdb_core.client.search.return_value = { - "hits": { - "hits": [ - { - "_score": 1.23, - "_index": "idx", - "_source": {"id": "doc1", "content": "body"}, - } - ] - } - } - - results = vdb_core.exec_query("idx", {"query": {}}) - assert results == [ - {"score": 1.23, "document": {"id": "doc1", "content": "body"}, "index": "idx"} - ] - - def test_hybrid_search_missing_fields_logged_for_accurate(self, vdb_core): - """Ensure hybrid_search tolerates missing accurate fields.""" - mock_embedding_model = MagicMock() - with patch.object(vdb_core, "accurate_search", return_value=[{"score": 1.0}]), \ - patch.object(vdb_core, "semantic_search", return_value=[]): - assert vdb_core.hybrid_search(["idx"], "query", mock_embedding_model) == [] - - def test_hybrid_search_missing_fields_logged_for_semantic(self, 
vdb_core): - """Ensure hybrid_search tolerates missing semantic fields.""" - mock_embedding_model = MagicMock() - with patch.object(vdb_core, "accurate_search", return_value=[]), \ - patch.object(vdb_core, "semantic_search", return_value=[{"score": 0.5}]): - assert vdb_core.hybrid_search(["idx"], "query", mock_embedding_model) == [] - - def test_hybrid_search_faulty_combined_results(self, vdb_core): - """Inject faulty combined result to hit KeyError handling in final loop.""" - mock_embedding_model = MagicMock() - accurate_payload = [ - {"score": 1.0, "document": {"id": "doc1"}, "index": "idx"} - ] - - with patch.object(vdb_core, "accurate_search", return_value=accurate_payload), \ - patch.object(vdb_core, "semantic_search", return_value=[]): - - injected = {"done": False} - - def tracer(frame, event, arg): - if ( - frame.f_code.co_name == "hybrid_search" - and event == "line" - and frame.f_lineno == 788 - and not injected["done"] - ): - frame.f_locals["combined_results"]["faulty"] = { - "accurate_score": 0, - "semantic_score": 0, - } - injected["done"] = True - return tracer - - sys.settrace(tracer) - try: - results = vdb_core.hybrid_search(["idx"], "query", mock_embedding_model) - finally: - sys.settrace(None) - - assert len(results) == 1 - - def test_get_documents_detail_exception(self, vdb_core): - """Ensure get_documents_detail returns empty list on failure.""" - vdb_core.client = MagicMock() - vdb_core.client.search.side_effect = Exception("fail") - - assert vdb_core.get_documents_detail("idx") == [] - - def test_get_indices_detail_success(self, vdb_core): - """Test get_indices_detail successful case""" - vdb_core.client = MagicMock() - vdb_core.client.indices.stats.return_value = { - "indices": { - "test_index": { - "primaries": { - "docs": {"count": 100}, - "store": {"size_in_bytes": 1024}, - "search": {"query_total": 50}, - "request_cache": {"hit_count": 25} - } - } - } - } - vdb_core.client.indices.get_settings.return_value = { - "test_index": { - "settings": { - "index": { - "number_of_shards": "1", - "number_of_replicas": "0", - "creation_date": "1640995200000" - } - } - } - } - vdb_core.client.search.return_value = { - "aggregations": { - "unique_path_or_url_count": {"value": 10}, - "process_sources": {"buckets": [{"key": "test_source"}]}, - "embedding_models": {"buckets": [{"key": "test_model"}]} - } - } - - result = vdb_core.get_indices_detail(["test_index"]) - assert "test_index" in result - assert "base_info" in result["test_index"] - assert "search_performance" in result["test_index"] - - def test_get_indices_detail_exception(self, vdb_core): - """Test get_indices_detail with exception""" - vdb_core.client = MagicMock() - vdb_core.client.indices.stats.side_effect = Exception("Stats error") - - result = vdb_core.get_indices_detail(["test_index"]) - # The function returns error info for failed indices, not empty dict - assert "test_index" in result - assert "error" in result["test_index"] - - def test_get_indices_detail_with_embedding_dim(self, vdb_core): - """Test get_indices_detail with embedding dimension""" - vdb_core.client = MagicMock() - vdb_core.client.indices.stats.return_value = { - "indices": { - "test_index": { - "primaries": { - "docs": {"count": 100}, - "store": {"size_in_bytes": 1024}, - "search": {"query_total": 50}, - "request_cache": {"hit_count": 25} - } - } - } - } - vdb_core.client.indices.get_settings.return_value = { - "test_index": { - "settings": { - "index": { - "number_of_shards": "1", - "number_of_replicas": "0", - "creation_date": 
"1640995200000" - } - } - } - } - vdb_core.client.search.return_value = { - "aggregations": { - "unique_path_or_url_count": {"value": 10}, - "process_sources": {"buckets": [{"key": "test_source"}]}, - "embedding_models": {"buckets": [{"key": "test_model"}]} - } - } - - result = vdb_core.get_indices_detail(["test_index"], embedding_dim=512) - assert "test_index" in result - assert "base_info" in result["test_index"] - assert "search_performance" in result["test_index"] - assert result["test_index"]["base_info"]["embedding_dim"] == 512 - - def test_bulk_operation_context_success(self, vdb_core): - """Test bulk_operation_context successful case""" - vdb_core._bulk_operations = {} - vdb_core._operation_counter = 0 - vdb_core._settings_lock = MagicMock() - vdb_core._apply_bulk_settings = MagicMock() - vdb_core._restore_normal_settings = MagicMock() - - with vdb_core.bulk_operation_context("test_index") as operation_id: - assert operation_id is not None - assert "test_index" in vdb_core._bulk_operations - vdb_core._apply_bulk_settings.assert_called_once_with("test_index") - - # After context exit, should restore settings - vdb_core._restore_normal_settings.assert_called_once_with("test_index") - - def test_bulk_operation_context_multiple_operations(self, vdb_core): - """Test bulk_operation_context with multiple operations""" - vdb_core._bulk_operations = {} - vdb_core._operation_counter = 0 - vdb_core._settings_lock = MagicMock() - vdb_core._apply_bulk_settings = MagicMock() - vdb_core._restore_normal_settings = MagicMock() - - # First operation - with vdb_core.bulk_operation_context("test_index") as op1: - assert op1 is not None - vdb_core._apply_bulk_settings.assert_called_once() - - # After first operation exits, settings should be restored - vdb_core._restore_normal_settings.assert_called_once_with("test_index") - - # Second operation - will apply settings again since first operation is done - with vdb_core.bulk_operation_context("test_index") as op2: - assert op2 is not None - # Should call apply_bulk_settings again since first operation is done - assert vdb_core._apply_bulk_settings.call_count == 2 - - # After second operation exits, should restore settings again - assert vdb_core._restore_normal_settings.call_count == 2 - - def test_small_batch_insert_success(self, vdb_core): - """Test _small_batch_insert successful case""" - vdb_core.client = MagicMock() - vdb_core.client.bulk.return_value = {"items": [], "errors": False} - vdb_core._preprocess_documents = MagicMock(return_value=[ - {"content": "test content", "title": "test"} - ]) - vdb_core._handle_bulk_errors = MagicMock() - - mock_embedding_model = MagicMock() - mock_embedding_model.get_embeddings.return_value = [[0.1, 0.2, 0.3]] - mock_embedding_model.embedding_model_name = "test_model" - - documents = [{"content": "test content", "title": "test"}] - - result = vdb_core._small_batch_insert("test_index", documents, "content", mock_embedding_model) - assert result == 1 - vdb_core.client.bulk.assert_called_once() - - def test_small_batch_insert_exception(self, vdb_core): - """Test _small_batch_insert with exception""" - vdb_core._preprocess_documents = MagicMock(side_effect=Exception("Preprocess error")) - - mock_embedding_model = MagicMock() - documents = [{"content": "test content", "title": "test"}] - - with pytest.raises(Exception) as exc_info: - vdb_core._small_batch_insert("test_index", documents, "content", mock_embedding_model) - assert "Preprocess error" in str(exc_info.value) - - def test_large_batch_insert_success(self, 
vdb_core): - """Test _large_batch_insert successful case""" - vdb_core.client = MagicMock() - vdb_core.client.bulk.return_value = {"items": [], "errors": False} - vdb_core._preprocess_documents = MagicMock(return_value=[ - {"content": "test content", "title": "test"} - ]) - vdb_core._handle_bulk_errors = MagicMock() - - mock_embedding_model = MagicMock() - mock_embedding_model.get_embeddings.return_value = [[0.1, 0.2, 0.3]] - mock_embedding_model.embedding_model_name = "test_model" - - documents = [{"content": "test content", "title": "test"}] - - result = vdb_core._large_batch_insert("test_index", documents, 10, "content", mock_embedding_model) - assert result == 1 - vdb_core.client.bulk.assert_called_once() - - def test_large_batch_insert_embedding_error(self, vdb_core): - """Test _large_batch_insert with embedding API error""" - vdb_core.client = MagicMock() - vdb_core._preprocess_documents = MagicMock(return_value=[ - {"content": "test content", "title": "test"} - ]) - - mock_embedding_model = MagicMock() - mock_embedding_model.get_embeddings.side_effect = Exception("Embedding API error") - - documents = [{"content": "test content", "title": "test"}] - - result = vdb_core._large_batch_insert("test_index", documents, 10, "content", mock_embedding_model) - assert result == 0 # No documents indexed due to embedding error - - def test_large_batch_insert_no_embeddings(self, vdb_core): - """Test _large_batch_insert with no successful embeddings""" - vdb_core.client = MagicMock() - vdb_core._preprocess_documents = MagicMock(return_value=[ - {"content": "test content", "title": "test"} - ]) - - mock_embedding_model = MagicMock() - mock_embedding_model.get_embeddings.side_effect = Exception("Embedding API error") - - documents = [{"content": "test content", "title": "test"}] - - result = vdb_core._large_batch_insert("test_index", documents, 10, "content", mock_embedding_model) - assert result == 0 # No documents indexed