22import importlib
33import inspect
44import json
5- import keyword
65import logging
76from typing import Any , List , Optional , Dict
87from urllib .parse import urljoin
2423 search_last_tool_instance_by_tool_id
2524)
2625from database .user_tenant_db import get_all_tenant_ids
26+ from services .elasticsearch_service import get_embedding_model , elastic_core
27+ from services .tenant_config_service import get_selected_knowledge_list
2728
2829logger = logging .getLogger ("tool_configuration_service" )
2930
@@ -102,6 +103,7 @@ def get_local_tools() -> List[ToolInfo]:
102103 inputs = json .dumps (getattr (tool_class , 'inputs' ),
103104 ensure_ascii = False ),
104105 output_type = getattr (tool_class , 'output_type' ),
106+ category = getattr (tool_class , 'category' ),
105107 class_name = tool_class .__name__ ,
106108 usage = None ,
107109 origin_name = getattr (tool_class , 'name' )
@@ -162,7 +164,8 @@ def _build_tool_info_from_langchain(obj) -> ToolInfo:
162164 output_type = output_type ,
163165 class_name = tool_name ,
164166 usage = None ,
165- origin_name = tool_name
167+ origin_name = tool_name ,
168+ category = None
166169 )
167170 return tool_info
168171
@@ -298,7 +301,8 @@ async def get_tool_from_remote_mcp_server(mcp_server_name: str, remote_mcp_serve
298301 output_type = "string" ,
299302 class_name = sanitized_tool_name ,
300303 usage = mcp_server_name ,
301- origin_name = tool .name )
304+ origin_name = tool .name ,
305+ category = None )
302306 tools_info .append (tool_info )
303307 return tools_info
304308 except Exception as e :
@@ -351,7 +355,8 @@ async def list_all_tools(tenant_id: str):
351355 "create_time" : tool .get ("create_time" ),
352356 "usage" : tool .get ("usage" ),
353357 "params" : tool .get ("params" , []),
354- "inputs" : tool .get ("inputs" , {})
358+ "inputs" : tool .get ("inputs" , {}),
359+ "category" : tool .get ("category" )
355360 }
356361 formatted_tools .append (formatted_tool )
357362
@@ -543,7 +548,9 @@ def _get_tool_class_by_name(tool_name: str) -> Optional[type]:
543548def _validate_local_tool (
544549 tool_name : str ,
545550 inputs : Optional [Dict [str , Any ]] = None ,
546- params : Optional [Dict [str , Any ]] = None
551+ params : Optional [Dict [str , Any ]] = None ,
552+ tenant_id : Optional [str ] = None ,
553+ user_id : Optional [str ] = None
547554) -> Dict [str , Any ]:
548555 """
549556 Validate local tool by actually instantiating and calling it.
@@ -552,6 +559,8 @@ def _validate_local_tool(
552559 tool_name: Name of the tool to test
553560 inputs: Parameters to pass to the tool's forward method
554561 params: Configuration parameters for tool initialization
562+ tenant_id: Tenant ID for knowledge base tools (optional)
563+ user_id: User ID for knowledge base tools (optional)
555564
556565 Returns:
557566 Dict[str, Any]: The actual result returned by the tool's forward method,
@@ -567,15 +576,45 @@ def _validate_local_tool(
567576 if not tool_class :
568577 raise NotFoundException (f"Tool class not found for { tool_name } " )
569578
570- # Instantiate tool with provided params or default parameters
579+ # Parse instantiation parameters first
571580 instantiation_params = params or {}
572- # Check if the tool constructor expects an observer parameter
581+ # Get signature and extract default values for all parameters
573582 sig = inspect .signature (tool_class .__init__ )
574- if 'observer' in sig .parameters and 'observer' not in instantiation_params :
575- instantiation_params ['observer' ] = None
576- tool_instance = tool_class (** instantiation_params )
577583
578- # Call forward method with provided parameters
584+ # Extract default values for all parameters not provided in instantiation_params
585+ for param_name , param in sig .parameters .items ():
586+ if param_name == "self" :
587+ continue
588+
589+ # If parameter not provided, extract default value
590+ if param_name not in instantiation_params :
591+ if param .default is PydanticUndefined :
592+ continue
593+ elif hasattr (param .default , 'default' ):
594+ # This is a Field object, extract its default value
595+ if param .default .default is not PydanticUndefined :
596+ instantiation_params [param_name ] = param .default .default
597+ else :
598+ instantiation_params [param_name ] = param .default
599+
600+ if tool_name == "knowledge_base_search" :
601+ if not tenant_id or not user_id :
602+ raise ToolExecutionException (f"Tenant ID and User ID are required for { tool_name } validation" )
603+ knowledge_info_list = get_selected_knowledge_list (
604+ tenant_id = tenant_id , user_id = user_id )
605+ index_names = [knowledge_info .get ("index_name" )
606+ for knowledge_info in knowledge_info_list ]
607+ embedding_model = get_embedding_model (tenant_id = tenant_id )
608+ params = {
609+ ** instantiation_params ,
610+ 'index_names' : index_names ,
611+ 'es_core' : elastic_core ,
612+ 'embedding_model' : embedding_model
613+ }
614+ tool_instance = tool_class (** params )
615+ else :
616+ tool_instance = tool_class (** instantiation_params )
617+
579618 result = tool_instance .forward (** (inputs or {}))
580619 return result
581620 except Exception as e :
@@ -630,14 +669,16 @@ def _validate_langchain_tool(
630669
631670async def validate_tool_impl (
632671 request : ToolValidateRequest ,
633- tenant_id : Optional [str ] = None
672+ tenant_id : Optional [str ] = None ,
673+ user_id : Optional [str ] = None
634674) -> Dict [str , Any ]:
635675 """
636676 Validate a tool from various sources (MCP, local, or LangChain).
637677
638678 Args:
639679 request: Tool validation request containing tool details
640680 tenant_id: Tenant ID for database queries (optional)
681+ user_id: User ID for database queries (optional)
641682
642683 Returns:
643684 Dict containing validation result - success returns tool result, failure returns error message
@@ -657,18 +698,18 @@ async def validate_tool_impl(
657698 else :
658699 return await _validate_mcp_tool_remote (tool_name , inputs , usage , tenant_id )
659700 elif source == ToolSourceEnum .LOCAL .value :
660- return _validate_local_tool (tool_name , inputs , params )
701+ return _validate_local_tool (tool_name , inputs , params , tenant_id , user_id )
661702 elif source == ToolSourceEnum .LANGCHAIN .value :
662703 return _validate_langchain_tool (tool_name , inputs )
663704 else :
664705 raise Exception (f"Unsupported tool source: { source } " )
665706
666707 except NotFoundException as e :
667708 logger .error (f"Tool not found: { e } " )
668- raise NotFoundException (f"Tool not found: { str (e )} " )
709+ raise NotFoundException (str (e ))
669710 except MCPConnectionError as e :
670711 logger .error (f"MCP connection failed: { e } " )
671- raise MCPConnectionError (f"MCP connection failed: { str (e )} " )
712+ raise MCPConnectionError (str (e ))
672713 except Exception as e :
673714 logger .error (f"Validate Tool failed: { e } " )
674- raise ToolExecutionException (f"Validate Tool failed: { str (e )} " )
715+ raise ToolExecutionException (str (e ))
0 commit comments