Skip to content

Commit b1024a6

Browse files
authored
✨ Add tool category feature. #1362
2 parents d5761ed + 616553e commit b1024a6

40 files changed

+2900
-72
lines changed

backend/apps/tool_config_app.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from fastapi import APIRouter, Header, HTTPException
66
from fastapi.responses import JSONResponse
77

8-
from consts.exceptions import MCPConnectionError, TimeoutException, NotFoundException
8+
from consts.exceptions import MCPConnectionError, NotFoundException
99
from consts.model import ToolInstanceInfoRequest, ToolInstanceSearchRequest, ToolValidateRequest
1010
from services.tool_configuration_service import (
1111
search_tool_info_impl,
@@ -109,8 +109,8 @@ async def validate_tool(
109109
):
110110
"""Validate specific tool based on source type"""
111111
try:
112-
_, tenant_id = get_current_user_id(authorization)
113-
result = await validate_tool_impl(request, tenant_id)
112+
user_id, tenant_id = get_current_user_id(authorization)
113+
result = await validate_tool_impl(request, tenant_id, user_id)
114114

115115
return JSONResponse(
116116
status_code=HTTPStatus.OK,
@@ -132,5 +132,5 @@ async def validate_tool(
132132
logger.error(f"Failed to validate tool: {e}")
133133
raise HTTPException(
134134
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
135-
detail=f"Failed to validate tool: {str(e)}"
135+
detail=str(e)
136136
)

backend/consts/model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,7 @@ class ToolInfo(BaseModel):
245245
class_name: str
246246
usage: Optional[str]
247247
origin_name: Optional[str] = None
248+
category: Optional[str] = None
248249

249250

250251
# used in Knowledge Summary request

backend/database/db_models.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,7 @@ class ToolInfo(TableBase):
177177
params = Column(JSON, doc="Tool parameter information (json)")
178178
inputs = Column(String(2048), doc="Prompt tool inputs description")
179179
output_type = Column(String(100), doc="Prompt tool output description")
180+
category = Column(String(100), doc="Tool category description")
180181
is_available = Column(
181182
Boolean, doc="Whether the tool can be used under the current main service")
182183

backend/services/tool_configuration_service.py

Lines changed: 57 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import importlib
33
import inspect
44
import json
5-
import keyword
65
import logging
76
from typing import Any, List, Optional, Dict
87
from urllib.parse import urljoin
@@ -24,6 +23,8 @@
2423
search_last_tool_instance_by_tool_id
2524
)
2625
from database.user_tenant_db import get_all_tenant_ids
26+
from services.elasticsearch_service import get_embedding_model, elastic_core
27+
from services.tenant_config_service import get_selected_knowledge_list
2728

2829
logger = logging.getLogger("tool_configuration_service")
2930

@@ -102,6 +103,7 @@ def get_local_tools() -> List[ToolInfo]:
102103
inputs=json.dumps(getattr(tool_class, 'inputs'),
103104
ensure_ascii=False),
104105
output_type=getattr(tool_class, 'output_type'),
106+
category=getattr(tool_class, 'category'),
105107
class_name=tool_class.__name__,
106108
usage=None,
107109
origin_name=getattr(tool_class, 'name')
@@ -162,7 +164,8 @@ def _build_tool_info_from_langchain(obj) -> ToolInfo:
162164
output_type=output_type,
163165
class_name=tool_name,
164166
usage=None,
165-
origin_name=tool_name
167+
origin_name=tool_name,
168+
category=None
166169
)
167170
return tool_info
168171

@@ -298,7 +301,8 @@ async def get_tool_from_remote_mcp_server(mcp_server_name: str, remote_mcp_serve
298301
output_type="string",
299302
class_name=sanitized_tool_name,
300303
usage=mcp_server_name,
301-
origin_name=tool.name)
304+
origin_name=tool.name,
305+
category=None)
302306
tools_info.append(tool_info)
303307
return tools_info
304308
except Exception as e:
@@ -351,7 +355,8 @@ async def list_all_tools(tenant_id: str):
351355
"create_time": tool.get("create_time"),
352356
"usage": tool.get("usage"),
353357
"params": tool.get("params", []),
354-
"inputs": tool.get("inputs", {})
358+
"inputs": tool.get("inputs", {}),
359+
"category": tool.get("category")
355360
}
356361
formatted_tools.append(formatted_tool)
357362

@@ -543,7 +548,9 @@ def _get_tool_class_by_name(tool_name: str) -> Optional[type]:
543548
def _validate_local_tool(
544549
tool_name: str,
545550
inputs: Optional[Dict[str, Any]] = None,
546-
params: Optional[Dict[str, Any]] = None
551+
params: Optional[Dict[str, Any]] = None,
552+
tenant_id: Optional[str] = None,
553+
user_id: Optional[str] = None
547554
) -> Dict[str, Any]:
548555
"""
549556
Validate local tool by actually instantiating and calling it.
@@ -552,6 +559,8 @@ def _validate_local_tool(
552559
tool_name: Name of the tool to test
553560
inputs: Parameters to pass to the tool's forward method
554561
params: Configuration parameters for tool initialization
562+
tenant_id: Tenant ID for knowledge base tools (optional)
563+
user_id: User ID for knowledge base tools (optional)
555564
556565
Returns:
557566
Dict[str, Any]: The actual result returned by the tool's forward method,
@@ -567,15 +576,45 @@ def _validate_local_tool(
567576
if not tool_class:
568577
raise NotFoundException(f"Tool class not found for {tool_name}")
569578

570-
# Instantiate tool with provided params or default parameters
579+
# Parse instantiation parameters first
571580
instantiation_params = params or {}
572-
# Check if the tool constructor expects an observer parameter
581+
# Get signature and extract default values for all parameters
573582
sig = inspect.signature(tool_class.__init__)
574-
if 'observer' in sig.parameters and 'observer' not in instantiation_params:
575-
instantiation_params['observer'] = None
576-
tool_instance = tool_class(**instantiation_params)
577583

578-
# Call forward method with provided parameters
584+
# Extract default values for all parameters not provided in instantiation_params
585+
for param_name, param in sig.parameters.items():
586+
if param_name == "self":
587+
continue
588+
589+
# If parameter not provided, extract default value
590+
if param_name not in instantiation_params:
591+
if param.default is PydanticUndefined:
592+
continue
593+
elif hasattr(param.default, 'default'):
594+
# This is a Field object, extract its default value
595+
if param.default.default is not PydanticUndefined:
596+
instantiation_params[param_name] = param.default.default
597+
else:
598+
instantiation_params[param_name] = param.default
599+
600+
if tool_name == "knowledge_base_search":
601+
if not tenant_id or not user_id:
602+
raise ToolExecutionException(f"Tenant ID and User ID are required for {tool_name} validation")
603+
knowledge_info_list = get_selected_knowledge_list(
604+
tenant_id=tenant_id, user_id=user_id)
605+
index_names = [knowledge_info.get("index_name")
606+
for knowledge_info in knowledge_info_list]
607+
embedding_model = get_embedding_model(tenant_id=tenant_id)
608+
params = {
609+
**instantiation_params,
610+
'index_names': index_names,
611+
'es_core': elastic_core,
612+
'embedding_model': embedding_model
613+
}
614+
tool_instance = tool_class(**params)
615+
else:
616+
tool_instance = tool_class(**instantiation_params)
617+
579618
result = tool_instance.forward(**(inputs or {}))
580619
return result
581620
except Exception as e:
@@ -630,14 +669,16 @@ def _validate_langchain_tool(
630669

631670
async def validate_tool_impl(
632671
request: ToolValidateRequest,
633-
tenant_id: Optional[str] = None
672+
tenant_id: Optional[str] = None,
673+
user_id: Optional[str] = None
634674
) -> Dict[str, Any]:
635675
"""
636676
Validate a tool from various sources (MCP, local, or LangChain).
637677
638678
Args:
639679
request: Tool validation request containing tool details
640680
tenant_id: Tenant ID for database queries (optional)
681+
user_id: User ID for database queries (optional)
641682
642683
Returns:
643684
Dict containing validation result - success returns tool result, failure returns error message
@@ -657,18 +698,18 @@ async def validate_tool_impl(
657698
else:
658699
return await _validate_mcp_tool_remote(tool_name, inputs, usage, tenant_id)
659700
elif source == ToolSourceEnum.LOCAL.value:
660-
return _validate_local_tool(tool_name, inputs, params)
701+
return _validate_local_tool(tool_name, inputs, params, tenant_id, user_id)
661702
elif source == ToolSourceEnum.LANGCHAIN.value:
662703
return _validate_langchain_tool(tool_name, inputs)
663704
else:
664705
raise Exception(f"Unsupported tool source: {source}")
665706

666707
except NotFoundException as e:
667708
logger.error(f"Tool not found: {e}")
668-
raise NotFoundException(f"Tool not found: {str(e)}")
709+
raise NotFoundException(str(e))
669710
except MCPConnectionError as e:
670711
logger.error(f"MCP connection failed: {e}")
671-
raise MCPConnectionError(f"MCP connection failed: {str(e)}")
712+
raise MCPConnectionError(str(e))
672713
except Exception as e:
673714
logger.error(f"Validate Tool failed: {e}")
674-
raise ToolExecutionException(f"Validate Tool failed: {str(e)}")
715+
raise ToolExecutionException(str(e))

docker/init.sql

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,7 @@ CREATE TABLE IF NOT EXISTS nexent.ag_tool_info_t (
237237
params JSON,
238238
inputs VARCHAR,
239239
output_type VARCHAR(100),
240+
category VARCHAR(100),
240241
is_available BOOLEAN DEFAULT FALSE,
241242
create_time TIMESTAMP WITHOUT TIME ZONE DEFAULT CURRENT_TIMESTAMP,
242243
update_time TIMESTAMP WITHOUT TIME ZONE DEFAULT CURRENT_TIMESTAMP,
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
-- Add category column to ag_tool_info_t table
2+
-- This field stores the tool category information (search, file, email, terminal)
3+
4+
ALTER TABLE nexent.ag_tool_info_t
5+
ADD COLUMN IF NOT EXISTS category VARCHAR(100);
6+
7+
-- Add comment to document the purpose of this field
8+
COMMENT ON COLUMN nexent.ag_tool_info_t.category IS 'Tool category information';

frontend/app/[locale]/setup/agents/components/tool/ToolConfigModal.tsx

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ import {
2929
extractParameterNames,
3030
} from "@/services/agentConfigService";
3131
import log from "@/lib/logger";
32+
import { useModalPosition } from "@/hooks/useModalPosition";
3233

3334
export default function ToolConfigModal({
3435
isOpen,
@@ -51,6 +52,8 @@ export default function ToolConfigModal({
5152
const [parsedInputs, setParsedInputs] = useState<Record<string, any>>({});
5253
const [paramValues, setParamValues] = useState<Record<string, string>>({});
5354
const [dynamicInputParams, setDynamicInputParams] = useState<string[]>([]);
55+
const { windowWidth, mainModalTop, mainModalRight } =
56+
useModalPosition(isOpen);
5457

5558
// load tool config
5659
useEffect(() => {
@@ -572,8 +575,11 @@ export default function ToolConfigModal({
572575
className="tool-test-panel"
573576
style={{
574577
position: "fixed",
575-
top: "10vh",
576-
right: "5vw",
578+
top: mainModalTop > 0 ? `${mainModalTop}px` : "10vh", // Align with main modal top or fallback to 10vh
579+
left:
580+
mainModalRight > 0
581+
? `${mainModalRight + windowWidth * 0.05}px`
582+
: "calc(50% + 300px + 5vw)", // Position to the right of main modal with 5% viewport width gap
577583
width: "500px",
578584
height: "auto",
579585
maxHeight: "80vh",
@@ -759,4 +765,4 @@ export default function ToolConfigModal({
759765
)}
760766
</>
761767
);
762-
}
768+
}

0 commit comments

Comments
 (0)