|
4 | 4 | import os |
5 | 5 | import uuid |
6 | 6 | from collections import deque |
| 7 | +from typing import Optional |
7 | 8 |
|
8 | 9 | from fastapi import Header, Request |
9 | 10 | from fastapi.responses import JSONResponse, StreamingResponse |
|
35 | 36 | search_agent_id_by_agent_name, |
36 | 37 | search_agent_info_by_agent_id, |
37 | 38 | search_blank_sub_agent_by_main_agent_id, |
38 | | - update_agent |
| 39 | + update_agent, |
| 40 | + update_related_agents |
39 | 41 | ) |
40 | 42 | from database.model_management_db import get_model_by_model_id, get_model_id_by_display_name |
41 | 43 | from database.remote_mcp_db import check_mcp_name_exists, get_mcp_server_by_name_and_tenant |
|
45 | 47 | delete_tools_by_agent_id, |
46 | 48 | query_all_enabled_tool_instances, |
47 | 49 | query_all_tools, |
| 50 | + query_tool_instances_by_id, |
48 | 51 | search_tools_for_sub_agent |
49 | 52 | ) |
50 | 53 | from services.conversation_management_service import save_conversation_assistant, save_conversation_user |
@@ -340,12 +343,100 @@ async def get_creating_sub_agent_info_impl(authorization: str = Header(None)): |
340 | 343 | async def update_agent_info_impl(request: AgentInfoRequest, authorization: str = Header(None)): |
341 | 344 | user_id, tenant_id, _ = get_current_user_info(authorization) |
342 | 345 |
|
| 346 | + # If agent_id is None, create a new agent; otherwise, update existing |
| 347 | + agent_id: Optional[int] = request.agent_id |
343 | 348 | try: |
344 | | - update_agent(request.agent_id, request, tenant_id, user_id) |
| 349 | + if agent_id is None: |
| 350 | + # Create agent |
| 351 | + created = create_agent(agent_info={ |
| 352 | + "name": request.name, |
| 353 | + "display_name": request.display_name, |
| 354 | + "description": request.description, |
| 355 | + "business_description": request.business_description, |
| 356 | + "model_id": request.model_id, |
| 357 | + "model_name": request.model_name, |
| 358 | + "business_logic_model_id": request.business_logic_model_id, |
| 359 | + "business_logic_model_name": request.business_logic_model_name, |
| 360 | + "max_steps": request.max_steps, |
| 361 | + "provide_run_summary": request.provide_run_summary, |
| 362 | + "duty_prompt": request.duty_prompt, |
| 363 | + "constraint_prompt": request.constraint_prompt, |
| 364 | + "few_shots_prompt": request.few_shots_prompt, |
| 365 | + "enabled": request.enabled if request.enabled is not None else True |
| 366 | + }, tenant_id=tenant_id, user_id=user_id) |
| 367 | + agent_id = created["agent_id"] |
| 368 | + else: |
| 369 | + # Update agent |
| 370 | + update_agent(agent_id, request, tenant_id, user_id) |
345 | 371 | except Exception as e: |
346 | 372 | logger.error(f"Failed to update agent info: {str(e)}") |
347 | 373 | raise ValueError(f"Failed to update agent info: {str(e)}") |
348 | 374 |
|
| 375 | + # Handle enabled tools saving when provided |
| 376 | + try: |
| 377 | + if request.enabled_tool_ids is not None and agent_id is not None: |
| 378 | + enabled_set = set(request.enabled_tool_ids) |
| 379 | + # Get all tools for current tenant |
| 380 | + all_tools = query_all_tools(tenant_id=tenant_id) |
| 381 | + for tool in all_tools: |
| 382 | + tool_id = tool.get("tool_id") |
| 383 | + if tool_id is None: |
| 384 | + continue |
| 385 | + # Keep existing params if any |
| 386 | + existing_instance = query_tool_instances_by_id( |
| 387 | + agent_id, tool_id, tenant_id) |
| 388 | + params = (existing_instance or {}).get( |
| 389 | + "params", {}) if existing_instance else {} |
| 390 | + create_or_update_tool_by_tool_info( |
| 391 | + tool_info=ToolInstanceInfoRequest( |
| 392 | + tool_id=tool_id, |
| 393 | + agent_id=agent_id, |
| 394 | + params=params, |
| 395 | + enabled=(tool_id in enabled_set) |
| 396 | + ), |
| 397 | + tenant_id=tenant_id, |
| 398 | + user_id=user_id |
| 399 | + ) |
| 400 | + except Exception as e: |
| 401 | + logger.error(f"Failed to update agent tools: {str(e)}") |
| 402 | + raise ValueError(f"Failed to update agent tools: {str(e)}") |
| 403 | + |
| 404 | + # Handle related agents saving when provided |
| 405 | + try: |
| 406 | + if request.related_agent_ids is not None and agent_id is not None: |
| 407 | + related_agent_ids = request.related_agent_ids |
| 408 | + # Check for circular dependencies using BFS |
| 409 | + search_list = deque(related_agent_ids) |
| 410 | + agent_id_set = set() |
| 411 | + |
| 412 | + while len(search_list): |
| 413 | + left_ele = search_list.popleft() |
| 414 | + if left_ele == agent_id: |
| 415 | + raise ValueError("Circular dependency detected: Agent cannot be related to itself or create circular calls") |
| 416 | + if left_ele in agent_id_set: |
| 417 | + continue |
| 418 | + else: |
| 419 | + agent_id_set.add(left_ele) |
| 420 | + sub_ids = query_sub_agents_id_list( |
| 421 | + main_agent_id=left_ele, tenant_id=tenant_id) |
| 422 | + search_list.extend(sub_ids) |
| 423 | + |
| 424 | + # Update related agents |
| 425 | + update_related_agents( |
| 426 | + parent_agent_id=agent_id, |
| 427 | + related_agent_ids=related_agent_ids, |
| 428 | + tenant_id=tenant_id, |
| 429 | + user_id=user_id |
| 430 | + ) |
| 431 | + except ValueError as e: |
| 432 | + # Re-raise ValueError (circular dependency) as-is |
| 433 | + raise |
| 434 | + except Exception as e: |
| 435 | + logger.error(f"Failed to update related agents: {str(e)}") |
| 436 | + raise ValueError(f"Failed to update related agents: {str(e)}") |
| 437 | + |
| 438 | + return {"agent_id": agent_id} |
| 439 | + |
349 | 440 |
|
350 | 441 | async def delete_agent_impl(agent_id: int, authorization: str = Header(None)): |
351 | 442 | user_id, tenant_id, _ = get_current_user_info(authorization) |
|
0 commit comments