diff --git a/surfsense_backend/alembic/versions/5263aa4e7f94_allow_multiple_connectors_with_unique_.py b/surfsense_backend/alembic/versions/5263aa4e7f94_allow_multiple_connectors_with_unique_.py new file mode 100644 index 000000000..954a864ab --- /dev/null +++ b/surfsense_backend/alembic/versions/5263aa4e7f94_allow_multiple_connectors_with_unique_.py @@ -0,0 +1,52 @@ +"""allow_multiple_connectors_with_unique_names + +Revision ID: 5263aa4e7f94 +Revises: ffd7445eb90a +Create Date: 2026-01-13 12:23:31.481643 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '5263aa4e7f94' +down_revision: Union[str, None] = 'ffd7445eb90a' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + # Drop the old unique constraint + op.drop_constraint( + 'uq_searchspace_user_connector_type', + 'search_source_connectors', + type_='unique' + ) + + # Create new unique constraint that includes name + op.create_unique_constraint( + 'uq_searchspace_user_connector_type_name', + 'search_source_connectors', + ['search_space_id', 'user_id', 'connector_type', 'name'] + ) + + +def downgrade() -> None: + """Downgrade schema.""" + # Drop the new constraint + op.drop_constraint( + 'uq_searchspace_user_connector_type_name', + 'search_source_connectors', + type_='unique' + ) + + # Restore the old constraint + op.create_unique_constraint( + 'uq_searchspace_user_connector_type', + 'search_source_connectors', + ['search_space_id', 'user_id', 'connector_type'] + ) diff --git a/surfsense_backend/alembic/versions/60_add_mcp_connector_type.py b/surfsense_backend/alembic/versions/60_add_mcp_connector_type.py new file mode 100644 index 000000000..f3ec2611c --- /dev/null +++ b/surfsense_backend/alembic/versions/60_add_mcp_connector_type.py @@ -0,0 +1,37 @@ +"""Add MCP connector type + +Revision ID: 60 
+Revises: 59 +Create Date: 2026-01-09 15:19:51.827647 + +""" +from collections.abc import Sequence + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = '60' +down_revision: str | None = '59' +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + """Add MCP_CONNECTOR to SearchSourceConnectorType enum.""" + # Add new enum value using raw SQL + op.execute( + """ + ALTER TYPE searchsourceconnectortype ADD VALUE IF NOT EXISTS 'MCP_CONNECTOR'; + """ + ) + + +def downgrade() -> None: + """Remove MCP_CONNECTOR from SearchSourceConnectorType enum.""" + # Note: PostgreSQL does not support removing enum values directly. + # To downgrade, you would need to: + # 1. Create a new enum without MCP_CONNECTOR + # 2. Alter the column to use the new enum + # 3. Drop the old enum + # This is left as a manual operation if needed. + pass diff --git a/surfsense_backend/alembic/versions/ffd7445eb90a_add_is_active_to_search_source_.py b/surfsense_backend/alembic/versions/ffd7445eb90a_add_is_active_to_search_source_.py new file mode 100644 index 000000000..f1a88658b --- /dev/null +++ b/surfsense_backend/alembic/versions/ffd7445eb90a_add_is_active_to_search_source_.py @@ -0,0 +1,33 @@ +"""add_is_active_to_search_source_connectors + +Revision ID: ffd7445eb90a +Revises: 60 +Create Date: 2026-01-12 22:11:26.132654 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. 
+revision: str = 'ffd7445eb90a' +down_revision: Union[str, None] = '60' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + # Add is_active column to search_source_connectors table + op.add_column( + 'search_source_connectors', + sa.Column('is_active', sa.Boolean(), nullable=False, server_default=sa.true()) + ) + + +def downgrade() -> None: + """Downgrade schema.""" + # Remove is_active column from search_source_connectors table + op.drop_column('search_source_connectors', 'is_active') diff --git a/surfsense_backend/app/agents/new_chat/chat_deepagent.py b/surfsense_backend/app/agents/new_chat/chat_deepagent.py index 6c8deb409..9675521f5 100644 --- a/surfsense_backend/app/agents/new_chat/chat_deepagent.py +++ b/surfsense_backend/app/agents/new_chat/chat_deepagent.py @@ -20,7 +20,7 @@ build_configurable_system_prompt, build_surfsense_system_prompt, ) -from app.agents.new_chat.tools import build_tools +from app.agents.new_chat.tools.registry import build_tools_async from app.services.connector_service import ConnectorService # ============================================================================= @@ -28,7 +28,7 @@ # ============================================================================= -def create_surfsense_deep_agent( +async def create_surfsense_deep_agent( llm: ChatLiteLLM, search_space_id: int, db_session: AsyncSession, @@ -120,8 +120,8 @@ def create_surfsense_deep_agent( "firecrawl_api_key": firecrawl_api_key, } - # Build tools using the registry - tools = build_tools( + # Build tools using the async registry (includes MCP tools) + tools = await build_tools_async( dependencies=dependencies, enabled_tools=enabled_tools, disabled_tools=disabled_tools, diff --git a/surfsense_backend/app/agents/new_chat/tools/mcp_client.py b/surfsense_backend/app/agents/new_chat/tools/mcp_client.py new file mode 100644 index 000000000..62ad258b9 --- /dev/null +++ 
b/surfsense_backend/app/agents/new_chat/tools/mcp_client.py @@ -0,0 +1,185 @@ +"""MCP Client Wrapper. + +This module provides a client for communicating with MCP servers via stdio transport. +It handles server lifecycle management, tool discovery, and tool execution. +""" + +import asyncio +import logging +import os +from contextlib import asynccontextmanager +from typing import Any + +from mcp import ClientSession +from mcp.client.stdio import StdioServerParameters, stdio_client + +logger = logging.getLogger(__name__) + + +class MCPClient: + """Client for communicating with an MCP server.""" + + def __init__(self, command: str, args: list[str], env: dict[str, str] | None = None): + """Initialize MCP client. + + Args: + command: Command to spawn the MCP server (e.g., "uvx", "node") + args: Arguments for the command (e.g., ["mcp-server-git"]) + env: Optional environment variables for the server process + + """ + self.command = command + self.args = args + self.env = env or {} + self.session: ClientSession | None = None + + @asynccontextmanager + async def connect(self): + """Connect to the MCP server and manage its lifecycle. 
+ + Yields: + ClientSession: Active MCP session for making requests + + """ + try: + # Merge env vars with current environment + server_env = os.environ.copy() + server_env.update(self.env) + + # Create server parameters with env + server_params = StdioServerParameters( + command=self.command, + args=self.args, + env=server_env + ) + + # Spawn server process and create session + async with stdio_client(server=server_params) as (read, write): + async with ClientSession(read, write) as session: + # Initialize the connection + await session.initialize() + self.session = session + logger.info( + f"Connected to MCP server: {self.command} {' '.join(self.args)}" + ) + yield session + + except Exception as e: + logger.error(f"Failed to connect to MCP server: {e!s}", exc_info=True) + raise + finally: + self.session = None + logger.info(f"Disconnected from MCP server: {self.command}") + + async def list_tools(self) -> list[dict[str, Any]]: + """List all tools available from the MCP server. + + Returns: + List of tool definitions with name, description, and input schema + + Raises: + RuntimeError: If not connected to server + + """ + if not self.session: + raise RuntimeError("Not connected to MCP server. Use 'async with client.connect():'") + + try: + # Call tools/list RPC method + response = await self.session.list_tools() + + tools = [] + for tool in response.tools: + tools.append({ + "name": tool.name, + "description": tool.description or "", + "input_schema": tool.inputSchema if hasattr(tool, "inputSchema") else {}, + }) + + logger.info(f"Listed {len(tools)} tools from MCP server") + return tools + + except Exception as e: + logger.error(f"Failed to list tools from MCP server: {e!s}", exc_info=True) + raise + + async def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any: + """Call a tool on the MCP server. 
+ + Args: + tool_name: Name of the tool to call + arguments: Arguments to pass to the tool + + Returns: + Tool execution result + + Raises: + RuntimeError: If not connected to server + + """ + if not self.session: + raise RuntimeError("Not connected to MCP server. Use 'async with client.connect():'") + + try: + logger.info(f"Calling MCP tool '{tool_name}' with arguments: {arguments}") + + # Call tools/call RPC method + response = await self.session.call_tool(tool_name, arguments=arguments) + + # Extract content from response + result = [] + for content in response.content: + if hasattr(content, "text"): + result.append(content.text) + elif hasattr(content, "data"): + result.append(str(content.data)) + else: + result.append(str(content)) + + result_str = "\n".join(result) if result else "" + logger.info(f"MCP tool '{tool_name}' succeeded: {result_str[:200]}") + return result_str + + except RuntimeError as e: + # Handle validation errors from MCP server responses + # Some MCP servers (like server-memory) return extra fields not in their schema + if "Invalid structured content" in str(e): + logger.warning(f"MCP server returned data not matching its schema, but continuing: {e}") + # Try to extract result from error message or return a success message + return "Operation completed (server returned unexpected format)" + raise + except Exception as e: + logger.error(f"Failed to call MCP tool '{tool_name}': {e!s}", exc_info=True) + return f"Error calling tool: {e!s}" + + +async def test_mcp_connection( + command: str, args: list[str], env: dict[str, str] | None = None +) -> dict[str, Any]: + """Test connection to an MCP server and fetch available tools. 
+ + Args: + command: Command to spawn the MCP server + args: Arguments for the command + env: Optional environment variables + + Returns: + Dict with connection status and available tools + + """ + client = MCPClient(command, args, env) + + try: + async with client.connect(): + tools = await client.list_tools() + return { + "status": "success", + "message": f"Connected successfully. Found {len(tools)} tools.", + "tools": tools, + } + except Exception as e: + return { + "status": "error", + "message": f"Failed to connect: {e!s}", + "tools": [], + } diff --git a/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py b/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py new file mode 100644 index 000000000..84b5f003c --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py @@ -0,0 +1,250 @@ +"""MCP Tool Factory. + +This module creates LangChain tools from MCP servers using the Model Context Protocol. +Tools are dynamically discovered from MCP servers - no manual configuration needed. + +This implements real MCP protocol support similar to Cursor's implementation. +""" + +import logging +from typing import Any + +from langchain_core.tools import StructuredTool +from pydantic import BaseModel, create_model +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.agents.new_chat.tools.mcp_client import MCPClient +from app.db import SearchSourceConnector, SearchSourceConnectorType + +logger = logging.getLogger(__name__) + + +def _normalize_gemini_params(params: dict[str, Any], mcp_schema: dict[str, Any]) -> dict[str, Any]: + """Normalize Gemini-transformed parameter names back to MCP schema format. + + Gemini tends to transform field names like: + - entityType -> type + - from/to -> fromEntity/toEntity + - relationType -> relation + + This function maps them back to the original MCP schema field names. 
+ """ + schema_properties = mcp_schema.get("properties", {}) + normalized = {} + + for param_key, param_value in params.items(): + # Handle array parameters (need to normalize nested objects) + if isinstance(param_value, list) and len(param_value) > 0: + if isinstance(param_value[0], dict): + # Get the items schema to know what fields should be present + items_schema = schema_properties.get(param_key, {}).get("items", {}) + items_properties = items_schema.get("properties", {}) + + normalized_array = [] + for item in param_value: + normalized_item = {} + for item_key, item_value in item.items(): + # Map common Gemini transformations back to MCP names + if item_key == "type" and "entityType" in items_properties: + normalized_item["entityType"] = item_value + elif item_key == "fromEntity" and "from" in items_properties: + normalized_item["from"] = item_value + elif item_key == "toEntity" and "to" in items_properties: + normalized_item["to"] = item_value + elif item_key == "relation" and "relationType" in items_properties: + normalized_item["relationType"] = item_value + else: + # Use the original key if it exists in schema + normalized_item[item_key] = item_value + + # Add missing required fields with empty defaults if needed + for required_field in items_properties.keys(): + if required_field not in normalized_item: + # For arrays like observations, default to empty array + if items_properties[required_field].get("type") == "array": + normalized_item[required_field] = [] + else: + normalized_item[required_field] = "" + + normalized_array.append(normalized_item) + normalized[param_key] = normalized_array + else: + normalized[param_key] = param_value + else: + normalized[param_key] = param_value + + return normalized + + +def _create_dynamic_input_model_from_schema( + tool_name: str, input_schema: dict[str, Any], +) -> type[BaseModel]: + """Create a Pydantic model from MCP tool's JSON schema. 
+ + Args: + tool_name: Name of the tool (used for model class name) + input_schema: JSON schema from MCP server + + Returns: + Pydantic model class for tool input validation + + """ + properties = input_schema.get("properties", {}) + required_fields = input_schema.get("required", []) + + # Build Pydantic field definitions + field_definitions = {} + for param_name, param_schema in properties.items(): + param_description = param_schema.get("description", "") + is_required = param_name in required_fields + + # Use Any type for complex schemas to preserve structure + # This allows the MCP server to do its own validation + from typing import Any as AnyType + from pydantic import Field + + if is_required: + field_definitions[param_name] = (AnyType, Field(..., description=param_description)) + else: + field_definitions[param_name] = ( + AnyType | None, + Field(None, description=param_description), + ) + + # Create dynamic model + model_name = f"{tool_name.replace(' ', '').replace('-', '_')}Input" + return create_model(model_name, **field_definitions) + + +async def _create_mcp_tool_from_definition( + tool_def: dict[str, Any], + mcp_client: MCPClient, +) -> StructuredTool: + """Create a LangChain tool from an MCP tool definition. 
+ + Args: + tool_def: Tool definition from MCP server with name, description, input_schema + mcp_client: MCP client instance for calling the tool + + Returns: + LangChain StructuredTool instance + + """ + tool_name = tool_def.get("name", "unnamed_tool") + tool_description = tool_def.get("description", "No description provided") + input_schema = tool_def.get("input_schema", {"type": "object", "properties": {}}) + + # Log the actual schema for debugging + logger.info(f"MCP tool '{tool_name}' input schema: {input_schema}") + + # Create dynamic input model from schema + input_model = _create_dynamic_input_model_from_schema(tool_name, input_schema) + + async def mcp_tool_call(**kwargs) -> str: + """Execute the MCP tool call via the client.""" + logger.info(f"MCP tool '{tool_name}' called with params: {kwargs}") + + # Normalize Gemini-transformed field names back to MCP schema + # Gemini transforms: entityType->type, from/to->fromEntity/toEntity, relationType->relation + normalized_kwargs = _normalize_gemini_params(kwargs, input_schema) + + try: + # Connect to server and call tool + async with mcp_client.connect(): + result = await mcp_client.call_tool(tool_name, normalized_kwargs) + return str(result) + except Exception as e: + error_msg = f"MCP tool '{tool_name}' failed: {e!s}" + logger.exception(error_msg) + return f"Error: {error_msg}" + + # Create StructuredTool with response_format to preserve exact schema + tool = StructuredTool( + name=tool_name, + description=tool_description, + coroutine=mcp_tool_call, + args_schema=input_model, + # Store the original MCP schema as metadata so we can access it later + metadata={"mcp_input_schema": input_schema}, + ) + + logger.info(f"Created MCP tool: '{tool_name}'") + return tool + + +async def load_mcp_tools( + session: AsyncSession, search_space_id: int, +) -> list[StructuredTool]: + """Load all MCP tools from user's active MCP server connectors. + + This discovers tools dynamically from MCP servers using the protocol. 
+ + Args: + session: Database session + search_space_id: User's search space ID + + Returns: + List of LangChain StructuredTool instances + + """ + try: + # Fetch all ACTIVE MCP connectors for this search space + result = await session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.connector_type + == SearchSourceConnectorType.MCP_CONNECTOR, + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.is_active == True, # Only load active connectors + ), + ) + + tools: list[StructuredTool] = [] + for connector in result.scalars(): + try: + # Extract server config + config = connector.config or {} + server_config = config.get("server_config", {}) + + command = server_config.get("command") + args = server_config.get("args", []) + env = server_config.get("env", {}) + + if not command: + logger.warning(f"MCP connector {connector.id} missing command, skipping") + continue + + # Create MCP client + mcp_client = MCPClient(command, args, env) + + # Connect and discover tools + async with mcp_client.connect(): + tool_definitions = await mcp_client.list_tools() + + logger.info( + f"Discovered {len(tool_definitions)} tools from MCP server " + f"'{command}' (connector {connector.id})" + ) + + # Create LangChain tools from definitions + for tool_def in tool_definitions: + try: + tool = await _create_mcp_tool_from_definition(tool_def, mcp_client) + tools.append(tool) + except Exception as e: + logger.exception( + f"Failed to create tool '{tool_def.get('name')}' " + f"from connector {connector.id}: {e!s}", + ) + + except Exception as e: + logger.exception( + f"Failed to load tools from MCP connector {connector.id}: {e!s}", + ) + + logger.info(f"Loaded {len(tools)} MCP tools for search space {search_space_id}") + return tools + + except Exception as e: + logger.exception(f"Failed to load MCP tools: {e!s}") + return [] diff --git a/surfsense_backend/app/agents/new_chat/tools/registry.py 
b/surfsense_backend/app/agents/new_chat/tools/registry.py index c7439bf8f..bb8708b2b 100644 --- a/surfsense_backend/app/agents/new_chat/tools/registry.py +++ b/surfsense_backend/app/agents/new_chat/tools/registry.py @@ -1,5 +1,4 @@ -""" -Tools registry for SurfSense deep agent. +"""Tools registry for SurfSense deep agent. This module provides a registry pattern for managing tools in the SurfSense agent. It makes it easy for OSS contributors to add new tools by: @@ -37,6 +36,7 @@ async def my_tool(param: str) -> dict: ), """ +import logging from collections.abc import Callable from dataclasses import dataclass, field from typing import Any @@ -46,6 +46,7 @@ async def my_tool(param: str) -> dict: from .display_image import create_display_image_tool from .knowledge_base import create_search_knowledge_base_tool from .link_preview import create_link_preview_tool +from .mcp_tool import load_mcp_tools from .podcast import create_generate_podcast_tool from .scrape_webpage import create_scrape_webpage_tool from .search_surfsense_docs import create_search_surfsense_docs_tool @@ -57,8 +58,7 @@ async def my_tool(param: str) -> dict: @dataclass class ToolDefinition: - """ - Definition of a tool that can be added to the agent. + """Definition of a tool that can be added to the agent. Attributes: name: Unique identifier for the tool @@ -66,6 +66,7 @@ class ToolDefinition: factory: Callable that creates the tool. Receives a dict of dependencies. requires: List of dependency names this tool needs (e.g., "search_space_id", "db_session") enabled_by_default: Whether the tool is enabled when no explicit config is provided + """ name: str @@ -178,8 +179,7 @@ def build_tools( disabled_tools: list[str] | None = None, additional_tools: list[BaseTool] | None = None, ) -> list[BaseTool]: - """ - Build the list of tools for the agent. + """Build the list of tools for the agent. 
Args: dependencies: Dict containing all possible dependencies: @@ -206,6 +206,7 @@ def build_tools( # Add custom tools tools = build_tools(deps, additional_tools=[my_custom_tool]) + """ # Determine which tools to enable if enabled_tools is not None: @@ -226,8 +227,9 @@ def build_tools( # Check that all required dependencies are provided missing_deps = [dep for dep in tool_def.requires if dep not in dependencies] if missing_deps: + msg = f"Tool '{tool_def.name}' requires dependencies: {missing_deps}" raise ValueError( - f"Tool '{tool_def.name}' requires dependencies: {missing_deps}" + msg, ) # Create the tool @@ -239,3 +241,61 @@ def build_tools( tools.extend(additional_tools) return tools + + +async def build_tools_async( + dependencies: dict[str, Any], + enabled_tools: list[str] | None = None, + disabled_tools: list[str] | None = None, + additional_tools: list[BaseTool] | None = None, + include_mcp_tools: bool = True, +) -> list[BaseTool]: + """Async version of build_tools that also loads MCP tools from database. + + Design Note: + This function exists because MCP tools require database queries to load user configs, + while built-in tools are created synchronously from static code. + + Alternative: We could make build_tools() itself async and always query the database, + but that would force async everywhere even when only using built-in tools. The current + design keeps the simple case (static tools only) synchronous while supporting dynamic + database-loaded tools through this async wrapper. + + Args: + dependencies: Dict containing all possible dependencies + enabled_tools: Explicit list of tool names to enable. If None, uses defaults. + disabled_tools: List of tool names to disable (applied after enabled_tools). + additional_tools: Extra tools to add (e.g., custom tools not in registry). + include_mcp_tools: Whether to load user's MCP tools from database. + + Returns: + List of configured tool instances ready for the agent, including MCP tools. 
+ + """ + # Build standard tools + tools = build_tools(dependencies, enabled_tools, disabled_tools, additional_tools) + + # Load MCP tools if requested and dependencies are available + if ( + include_mcp_tools + and "db_session" in dependencies + and "search_space_id" in dependencies + ): + try: + mcp_tools = await load_mcp_tools( + dependencies["db_session"], dependencies["search_space_id"], + ) + tools.extend(mcp_tools) + logging.info( + f"Registered {len(mcp_tools)} MCP tools: {[t.name for t in mcp_tools]}", + ) + except Exception as e: + # Log error but don't fail - just continue without MCP tools + logging.exception(f"Failed to load MCP tools: {e!s}") + + # Log all tools being returned to agent + logging.info( + f"Total tools for agent: {len(tools)} - {[t.name for t in tools]}", + ) + + return tools diff --git a/surfsense_backend/app/db.py b/surfsense_backend/app/db.py index 73727a9ef..1b184a24a 100644 --- a/surfsense_backend/app/db.py +++ b/surfsense_backend/app/db.py @@ -80,6 +80,7 @@ class SearchSourceConnectorType(str, Enum): WEBCRAWLER_CONNECTOR = "WEBCRAWLER_CONNECTOR" BOOKSTACK_CONNECTOR = "BOOKSTACK_CONNECTOR" CIRCLEBACK_CONNECTOR = "CIRCLEBACK_CONNECTOR" + MCP_CONNECTOR = "MCP_CONNECTOR" # Model Context Protocol - User-defined API tools class LiteLLMProvider(str, Enum): @@ -605,13 +606,15 @@ class SearchSourceConnector(BaseModel, TimestampMixin): "search_space_id", "user_id", "connector_type", - name="uq_searchspace_user_connector_type", + "name", + name="uq_searchspace_user_connector_type_name", ), ) name = Column(String(100), nullable=False, index=True) connector_type = Column(SQLAlchemyEnum(SearchSourceConnectorType), nullable=False) is_indexable = Column(Boolean, nullable=False, default=False) + is_active = Column(Boolean, nullable=False, default=True) # Enable/disable connector last_indexed_at = Column(TIMESTAMP(timezone=True), nullable=True) config = Column(JSON, nullable=False) diff --git 
a/surfsense_backend/app/routes/search_source_connectors_routes.py b/surfsense_backend/app/routes/search_source_connectors_routes.py index 8e8ebb72d..a7c577bba 100644 --- a/surfsense_backend/app/routes/search_source_connectors_routes.py +++ b/surfsense_backend/app/routes/search_source_connectors_routes.py @@ -7,6 +7,13 @@ DELETE /search-source-connectors/{connector_id} - Delete a specific connector POST /search-source-connectors/{connector_id}/index - Index content from a connector to a search space +MCP (Model Context Protocol) Connector routes: +POST /connectors/mcp - Create a new MCP connector with custom API tools +GET /connectors/mcp - List all MCP connectors for the current user's search space +GET /connectors/mcp/{connector_id} - Get a specific MCP connector with tools config +PUT /connectors/mcp/{connector_id} - Update an MCP connector's tools config +DELETE /connectors/mcp/{connector_id} - Delete an MCP connector + Note: OAuth connectors (Gmail, Drive, Slack, etc.) support multiple accounts per search space. Non-OAuth connectors (BookStack, GitHub, etc.) are limited to one per search space. 
""" @@ -32,6 +39,9 @@ ) from app.schemas import ( GoogleDriveIndexRequest, + MCPConnectorCreate, + MCPConnectorRead, + MCPConnectorUpdate, SearchSourceConnectorBase, SearchSourceConnectorCreate, SearchSourceConnectorRead, @@ -127,18 +137,20 @@ async def create_search_source_connector( # Check if a connector with the same type already exists for this search space # (for non-OAuth connectors that don't support multiple accounts) - result = await session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.connector_type == connector.connector_type, - ) - ) - existing_connector = result.scalars().first() - if existing_connector: - raise HTTPException( - status_code=409, - detail=f"A connector with type {connector.connector_type} already exists in this search space.", + # Exception: MCP_CONNECTOR can have multiple instances with different names + if connector.connector_type != SearchSourceConnectorType.MCP_CONNECTOR: + result = await session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.connector_type == connector.connector_type, + ) ) + existing_connector = result.scalars().first() + if existing_connector: + raise HTTPException( + status_code=409, + detail=f"A connector with type {connector.connector_type} already exists in this search space.", + ) # Prepare connector data connector_data = connector.model_dump() @@ -1964,3 +1976,348 @@ async def run_bookstack_indexing( f"Critical error in run_bookstack_indexing for connector {connector_id}: {e}", exc_info=True, ) + + +# ============================================================================= +# MCP Connector Routes +# ============================================================================= + + +@router.post("/connectors/mcp", response_model=MCPConnectorRead, status_code=201) +async def create_mcp_connector( + connector_data: MCPConnectorCreate, 
+ search_space_id: int = Query(..., description="Search space ID"), + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """ + Create a new MCP (Model Context Protocol) connector. + + MCP connectors allow users to connect to MCP servers (like in Cursor). + Tools are auto-discovered from the server - no manual configuration needed. + + Args: + connector_data: MCP server configuration (command, args, env) + search_space_id: ID of the search space to attach the connector to + session: Database session + user: Current authenticated user + + Returns: + Created MCP connector with server configuration + + Raises: + HTTPException: If search space not found or permission denied + """ + try: + # Check user has permission to create connectors + await check_permission( + session, + user, + search_space_id, + Permission.CONNECTORS_CREATE.value, + "You don't have permission to create connectors in this search space", + ) + + # Create the connector with server config + db_connector = SearchSourceConnector( + name=connector_data.name, + connector_type=SearchSourceConnectorType.MCP_CONNECTOR, + is_indexable=False, # MCP connectors are not indexable + config={"server_config": connector_data.server_config.model_dump()}, + periodic_indexing_enabled=False, + indexing_frequency_minutes=None, + search_space_id=search_space_id, + user_id=user.id, + ) + + session.add(db_connector) + await session.commit() + await session.refresh(db_connector) + + logger.info( + f"Created MCP connector {db_connector.id} for server '{connector_data.server_config.command}' " + f"for user {user.id} in search space {search_space_id}" + ) + + # Convert to read schema + connector_read = SearchSourceConnectorRead.model_validate(db_connector) + return MCPConnectorRead.from_connector(connector_read) + + except HTTPException: + raise + except Exception as e: + logger.error(f"Failed to create MCP connector: {e!s}", exc_info=True) + await session.rollback() + raise 
HTTPException( + status_code=500, detail=f"Failed to create MCP connector: {e!s}" + ) from e + + +@router.get("/connectors/mcp", response_model=list[MCPConnectorRead]) +async def list_mcp_connectors( + search_space_id: int = Query(..., description="Search space ID"), + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """ + List all MCP connectors for a search space. + + Args: + search_space_id: ID of the search space + session: Database session + user: Current authenticated user + + Returns: + List of MCP connectors with their tool configurations + """ + try: + # Check user has permission to read connectors + await check_permission( + session, + user, + search_space_id, + Permission.CONNECTORS_READ.value, + "You don't have permission to view connectors in this search space", + ) + + # Fetch MCP connectors + result = await session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.connector_type + == SearchSourceConnectorType.MCP_CONNECTOR, + SearchSourceConnector.search_space_id == search_space_id, + ) + ) + + connectors = result.scalars().all() + return [ + MCPConnectorRead.from_connector(SearchSourceConnectorRead.model_validate(c)) + for c in connectors + ] + + except HTTPException: + raise + except Exception as e: + logger.error(f"Failed to list MCP connectors: {e!s}", exc_info=True) + raise HTTPException( + status_code=500, detail=f"Failed to list MCP connectors: {e!s}" + ) from e + + +@router.get("/connectors/mcp/{connector_id}", response_model=MCPConnectorRead) +async def get_mcp_connector( + connector_id: int, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """ + Get a specific MCP connector by ID. 
+ + Args: + connector_id: ID of the connector + session: Database session + user: Current authenticated user + + Returns: + MCP connector with tool configurations + """ + try: + # Fetch connector + result = await session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == connector_id, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.MCP_CONNECTOR, + ) + ) + connector = result.scalars().first() + + if not connector: + raise HTTPException(status_code=404, detail="MCP connector not found") + + # Check user has permission to read connectors + await check_permission( + session, + user, + connector.search_space_id, + Permission.CONNECTORS_READ.value, + "You don't have permission to view this connector", + ) + + connector_read = SearchSourceConnectorRead.model_validate(connector) + return MCPConnectorRead.from_connector(connector_read) + + except HTTPException: + raise + except Exception as e: + logger.error(f"Failed to get MCP connector: {e!s}", exc_info=True) + raise HTTPException( + status_code=500, detail=f"Failed to get MCP connector: {e!s}" + ) from e + + +@router.put("/connectors/mcp/{connector_id}", response_model=MCPConnectorRead) +async def update_mcp_connector( + connector_id: int, + connector_update: MCPConnectorUpdate, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """ + Update an MCP connector. 
@router.put("/connectors/mcp/{connector_id}", response_model=MCPConnectorRead)
async def update_mcp_connector(
    connector_id: int,
    connector_update: MCPConnectorUpdate,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """
    Partially update an MCP connector's name and/or server configuration.

    Args:
        connector_id: Primary key of the connector to update.
        connector_update: Fields to change; ``None`` fields are left as-is.
        session: Async database session (injected).
        user: Authenticated user (injected).

    Returns:
        The updated MCP connector.

    Raises:
        HTTPException: 404 when not found, 403 when the user lacks update
            permission, 500 on unexpected failures (after rollback).
    """
    try:
        query = select(SearchSourceConnector).filter(
            SearchSourceConnector.id == connector_id,
            SearchSourceConnector.connector_type
            == SearchSourceConnectorType.MCP_CONNECTOR,
        )
        connector = (await session.execute(query)).scalars().first()
        if connector is None:
            raise HTTPException(status_code=404, detail="MCP connector not found")

        await check_permission(
            session,
            user,
            connector.search_space_id,
            Permission.CONNECTORS_UPDATE.value,
            "You don't have permission to update this connector",
        )

        # Apply only the fields the client actually sent.
        if connector_update.name is not None:
            connector.name = connector_update.name
        if connector_update.server_config is not None:
            # The whole config is stored under a single "server_config" key.
            connector.config = {
                "server_config": connector_update.server_config.model_dump()
            }
        connector.updated_at = datetime.now(UTC)

        await session.commit()
        await session.refresh(connector)
        logger.info(f"Updated MCP connector {connector_id}")

        return MCPConnectorRead.from_connector(
            SearchSourceConnectorRead.model_validate(connector)
        )

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Failed to update MCP connector: {e!s}", exc_info=True)
        await session.rollback()
        raise HTTPException(
            status_code=500, detail=f"Failed to update MCP connector: {e!s}"
        ) from e
@router.delete("/connectors/mcp/{connector_id}", status_code=204)
async def delete_mcp_connector(
    connector_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """
    Permanently delete an MCP connector.

    Args:
        connector_id: Primary key of the connector to delete.
        session: Async database session (injected).
        user: Authenticated user (injected).

    Raises:
        HTTPException: 404 when not found, 403 when the user lacks delete
            permission, 500 on unexpected failures (after rollback).
    """
    try:
        query = select(SearchSourceConnector).filter(
            SearchSourceConnector.id == connector_id,
            SearchSourceConnector.connector_type
            == SearchSourceConnectorType.MCP_CONNECTOR,
        )
        connector = (await session.execute(query)).scalars().first()
        if connector is None:
            raise HTTPException(status_code=404, detail="MCP connector not found")

        await check_permission(
            session,
            user,
            connector.search_space_id,
            Permission.CONNECTORS_DELETE.value,
            "You don't have permission to delete this connector",
        )

        await session.delete(connector)
        await session.commit()
        logger.info(f"Deleted MCP connector {connector_id}")

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Failed to delete MCP connector: {e!s}", exc_info=True)
        await session.rollback()
        raise HTTPException(
            status_code=500, detail=f"Failed to delete MCP connector: {e!s}"
        ) from e
@router.post("/connectors/mcp/test")
async def test_mcp_server_connection(
    server_config: dict = Body(...),
    user: User = Depends(current_active_user),
):
    """
    Test connection to an MCP server and fetch available tools.

    This endpoint allows users to test their MCP server configuration
    before saving it, similar to Cursor's flow.

    Args:
        server_config: Server configuration with command, args, env
        user: Current authenticated user

    Returns:
        Connection status and list of available tools; on unexpected errors
        a ``{"status": "error", ...}`` payload is returned instead of a 500
        so the UI can render the failure inline.

    Raises:
        HTTPException: 400 when the payload is malformed.
    """
    try:
        from app.agents.new_chat.tools.mcp_client import test_mcp_connection

        command = server_config.get("command")
        args = server_config.get("args", [])
        env = server_config.get("env", {})

        # Validate the untyped payload up front so malformed input yields a
        # clear 400 instead of an opaque failure inside the MCP client.
        if not command or not isinstance(command, str):
            raise HTTPException(status_code=400, detail="Server command is required")
        if not isinstance(args, list) or not all(isinstance(a, str) for a in args):
            raise HTTPException(
                status_code=400, detail="'args' must be a list of strings"
            )
        if not isinstance(env, dict):
            raise HTTPException(status_code=400, detail="'env' must be an object")

        # SECURITY NOTE(review): this spawns an arbitrary command supplied by
        # any authenticated user on the backend host. Consider restricting
        # `command` to an allow-list (e.g. npx/uvx/node/python) before
        # exposing this endpoint broadly.
        result = await test_mcp_connection(command, args, env)

        return result

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Failed to test MCP connection: {e!s}", exc_info=True)
        return {
            "status": "error",
            "message": f"Failed to test connection: {e!s}",
            "tools": [],
        }
@@ -14,6 +14,7 @@ class SearchSourceConnectorBase(BaseModel): name: str connector_type: SearchSourceConnectorType is_indexable: bool + is_active: bool = True last_indexed_at: datetime | None = None config: dict[str, Any] periodic_indexing_enabled: bool = False @@ -23,7 +24,7 @@ class SearchSourceConnectorBase(BaseModel): @field_validator("config") @classmethod def validate_config_for_connector_type( - cls, config: dict[str, Any], values: dict[str, Any] + cls, config: dict[str, Any], values: dict[str, Any], ) -> dict[str, Any]: connector_type = values.data.get("connector_type") return validate_connector_config(connector_type, config) @@ -38,15 +39,18 @@ def validate_periodic_indexing(self): """ if self.periodic_indexing_enabled: if not self.is_indexable: + msg = "periodic_indexing_enabled can only be True for indexable connectors" raise ValueError( - "periodic_indexing_enabled can only be True for indexable connectors" + msg, ) if self.indexing_frequency_minutes is None: + msg = "indexing_frequency_minutes is required when periodic_indexing_enabled is True" raise ValueError( - "indexing_frequency_minutes is required when periodic_indexing_enabled is True" + msg, ) if self.indexing_frequency_minutes <= 0: - raise ValueError("indexing_frequency_minutes must be greater than 0") + msg = "indexing_frequency_minutes must be greater than 0" + raise ValueError(msg) return self @@ -58,6 +62,7 @@ class SearchSourceConnectorUpdate(BaseModel): name: str | None = None connector_type: SearchSourceConnectorType | None = None is_indexable: bool | None = None + is_active: bool | None = None last_indexed_at: datetime | None = None config: dict[str, Any] | None = None periodic_indexing_enabled: bool | None = None @@ -70,3 +75,63 @@ class SearchSourceConnectorRead(SearchSourceConnectorBase, IDModel, TimestampMod user_id: uuid.UUID model_config = ConfigDict(from_attributes=True) + + +# ============================================================================= +# MCP-specific schemas 
# =============================================================================
# MCP-specific schemas
# =============================================================================


class MCPServerConfig(BaseModel):
    """How to launch/reach a single MCP server (mirrors Cursor-style config)."""

    command: str  # executable to launch, e.g. "uvx", "node", "python"
    args: list[str] = []  # CLI arguments passed to the command
    env: dict[str, str] = {}  # extra environment variables for the process
    transport: str = "stdio"  # one of "stdio" | "sse" | "http" (stdio is most common)


class MCPConnectorCreate(BaseModel):
    """Payload for creating an MCP connector."""

    name: str
    server_config: MCPServerConfig


class MCPConnectorUpdate(BaseModel):
    """Payload for partially updating an MCP connector; None fields are ignored."""

    name: str | None = None
    server_config: MCPServerConfig | None = None


class MCPConnectorRead(BaseModel):
    """Read model exposing an MCP connector together with its server config."""

    id: int
    name: str
    connector_type: SearchSourceConnectorType
    server_config: MCPServerConfig
    search_space_id: int
    user_id: uuid.UUID
    created_at: datetime
    updated_at: datetime

    model_config = ConfigDict(from_attributes=True)

    @classmethod
    def from_connector(cls, connector: SearchSourceConnectorRead) -> "MCPConnectorRead":
        """Build an MCPConnectorRead from the generic connector read model."""
        # The server configuration lives under config["server_config"];
        # missing config falls back to MCPServerConfig's defaults.
        raw = (connector.config or {}).get("server_config", {})
        return cls(
            id=connector.id,
            name=connector.name,
            connector_type=connector.connector_type,
            server_config=MCPServerConfig(**raw),
            search_space_id=connector.search_space_id,
            user_id=connector.user_id,
            created_at=connector.created_at,
            updated_at=connector.updated_at,
        )
get_checkpointer() # Create the deep agent with checkpointer and configurable prompts - agent = create_surfsense_deep_agent( + agent = await create_surfsense_deep_agent( llm=llm, search_space_id=search_space_id, db_session=session, diff --git a/surfsense_backend/app/utils/connector_naming.py b/surfsense_backend/app/utils/connector_naming.py index 731f419d6..bfc9decdd 100644 --- a/surfsense_backend/app/utils/connector_naming.py +++ b/surfsense_backend/app/utils/connector_naming.py @@ -27,6 +27,7 @@ SearchSourceConnectorType.DISCORD_CONNECTOR: "Discord", SearchSourceConnectorType.CONFLUENCE_CONNECTOR: "Confluence", SearchSourceConnectorType.AIRTABLE_CONNECTOR: "Airtable", + SearchSourceConnectorType.MCP_CONNECTOR: "Model Context Protocol (MCP)", } diff --git a/surfsense_backend/pyproject.toml b/surfsense_backend/pyproject.toml index e3e7583f8..83a00b4e4 100644 --- a/surfsense_backend/pyproject.toml +++ b/surfsense_backend/pyproject.toml @@ -57,6 +57,9 @@ dependencies = [ "chonkie[all]>=1.5.0", "langgraph-checkpoint-postgres>=3.0.2", "psycopg[binary,pool]>=3.3.2", + "mcp>=1.25.0", + "starlette>=0.40.0,<0.51.0", + "sse-starlette>=3.1.1,<3.1.2", ] [dependency-groups] diff --git a/surfsense_web/components/assistant-ui/connector-popup.tsx b/surfsense_web/components/assistant-ui/connector-popup.tsx index 1e6dd09ae..28fe5b5b0 100644 --- a/surfsense_web/components/assistant-ui/connector-popup.tsx +++ b/surfsense_web/components/assistant-ui/connector-popup.tsx @@ -68,6 +68,7 @@ export const ConnectorIndicator: FC = () => { setEndDate, setPeriodicEnabled, setFrequencyMinutes, + setOtherMCPConnectorIds, handleOpenChange, handleTabChange, handleScroll, @@ -239,6 +240,8 @@ export const ConnectorIndicator: FC = () => { isSaving={isSaving} isDisconnecting={isDisconnecting} isIndexing={indexingConnectorIds.has(editingConnector.id)} + searchSpaceId={searchSpaceId?.toString()} + onOtherMCPConnectorsLoaded={setOtherMCPConnectorIds} onStartDateChange={setStartDate} 
onEndDateChange={setEndDate} onPeriodicEnabledChange={setPeriodicEnabled} diff --git a/surfsense_web/components/assistant-ui/connector-popup/components/connector-card.tsx b/surfsense_web/components/assistant-ui/connector-popup/components/connector-card.tsx index fa4b8feb6..bb35ee8a3 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/components/connector-card.tsx +++ b/surfsense_web/components/assistant-ui/connector-popup/components/connector-card.tsx @@ -27,6 +27,16 @@ interface ConnectorCardProps { onManage?: () => void; } +/** + * Check if a connector type is indexable (has documents) + * MCP connectors are tools only and don't have indexable content + */ +function isIndexableConnector(connectorType?: string): boolean { + if (!connectorType) return true; // Default to true for unknown types + const nonIndexableTypes = ["MCP_CONNECTOR"]; + return !nonIndexableTypes.includes(connectorType); +} + /** * Extract a number from the active task message for display * Looks for patterns like "45 indexed", "Processing 123", etc. 
@@ -135,7 +145,12 @@ export const ConnectorCard: FC = ({ } if (isConnected) { - // Show last indexed date for connected connectors + // For non-indexable connectors (like MCP), show description instead of index status + if (!isIndexableConnector(connectorType)) { + return description; + } + + // Show last indexed date for connected indexable connectors if (lastIndexedAt) { return ( diff --git a/surfsense_web/components/assistant-ui/connector-popup/connect-forms/components/mcp-connect-form.tsx b/surfsense_web/components/assistant-ui/connector-popup/connect-forms/components/mcp-connect-form.tsx new file mode 100644 index 000000000..3ef43b3db --- /dev/null +++ b/surfsense_web/components/assistant-ui/connector-popup/connect-forms/components/mcp-connect-form.tsx @@ -0,0 +1,276 @@ +"use client"; + +import { CheckCircle2, Server, XCircle } from "lucide-react"; +import { type FC, useRef, useState } from "react"; +import { Alert, AlertDescription, AlertTitle } from "@/components/ui/alert"; +import { Button } from "@/components/ui/button"; +import { Label } from "@/components/ui/label"; +import { Textarea } from "@/components/ui/textarea"; +import { EnumConnectorName } from "@/contracts/enums/connector"; +import type { MCPServerConfig, MCPToolDefinition } from "@/contracts/types/mcp.types"; +import { connectorsApiService } from "@/lib/apis/connectors-api.service"; +import type { ConnectFormProps } from ".."; + +const DEFAULT_CONFIG = `[ + { + "name": "MCP Server 1", + "command": "npx", + "args": ["-y", "@modelcontextprotocol/server-filesystem", "/path/to/allowed/directory"], + "env": {}, + "transport": "stdio" + } +]`; + +interface MCPServerWithName extends MCPServerConfig { + name: string; +} + +export const MCPConnectForm: FC = ({ onSubmit, isSubmitting }) => { + const isSubmittingRef = useRef(false); + const [configJson, setConfigJson] = useState(DEFAULT_CONFIG); + const [jsonError, setJsonError] = useState(null); + const [isTesting, setIsTesting] = useState(false); + 
const [testResults, setTestResults] = useState | null>(null); + + const parseConfigs = (): { configs: MCPServerWithName[] | null; error: string | null } => { + try { + const parsed = JSON.parse(configJson); + + // Must be an array + if (!Array.isArray(parsed)) { + return { + configs: null, + error: "Configuration must be an array of MCP server objects", + }; + } + + if (parsed.length === 0) { + return { + configs: null, + error: "Array must contain at least one MCP server configuration", + }; + } + + // Validate each server config + const configs: MCPServerWithName[] = []; + for (let i = 0; i < parsed.length; i++) { + const server = parsed[i]; + + if (!server.name || typeof server.name !== "string") { + return { + configs: null, + error: `Server ${i + 1}: 'name' field is required and must be a string`, + }; + } + + if (!server.command || typeof server.command !== "string") { + return { + configs: null, + error: `Server ${i + 1} (${server.name}): 'command' field is required and must be a string`, + }; + } + + configs.push({ + name: server.name, + command: server.command, + args: Array.isArray(server.args) ? server.args : [], + env: typeof server.env === "object" && server.env !== null ? server.env : {}, + transport: server.transport || "stdio", + }); + } + + return { configs, error: null }; + } catch (error) { + return { + configs: null, + error: error instanceof Error ? 
error.message : "Invalid JSON", + }; + } + }; + + const handleConfigChange = (value: string) => { + setConfigJson(value); + if (jsonError) { + setJsonError(null); + } + }; + + const handleTestConnection = async () => { + const { configs, error } = parseConfigs(); + + if (!configs || error) { + setJsonError(error); + setTestResults([{ + name: "Parse Error", + status: "error", + message: error || "Invalid configuration", + tools: [], + }]); + return; + } + + setIsTesting(true); + setTestResults(null); + setJsonError(null); + + const results: Array<{ + name: string; + status: "success" | "error"; + message: string; + tools: MCPToolDefinition[]; + }> = []; + + for (const config of configs) { + try { + const result = await connectorsApiService.testMCPConnection(config); + results.push({ + name: config.name, + ...result, + }); + } catch (error) { + results.push({ + name: config.name, + status: "error", + message: error instanceof Error ? error.message : "Failed to connect to MCP server", + tools: [], + }); + } + } + + setTestResults(results); + setIsTesting(false); + }; + + const handleSubmit = async (e: React.FormEvent) => { + e.preventDefault(); + + // Prevent multiple submissions + if (isSubmittingRef.current || isSubmitting) { + return; + } + + const { configs, error } = parseConfigs(); + + if (!configs || error) { + setJsonError(error); + alert(error || "Invalid JSON configuration"); + return; + } + + isSubmittingRef.current = true; + try { + // Submit all servers + for (const config of configs) { + await onSubmit({ + name: config.name, + connector_type: EnumConnectorName.MCP_CONNECTOR, + config: { server_config: config }, + is_indexable: false, + is_active: true, + last_indexed_at: null, + periodic_indexing_enabled: false, + indexing_frequency_minutes: null, + next_scheduled_at: null, + }); + } + } finally { + isSubmittingRef.current = false; + } + }; + + return ( +
+ + +
+ MCP Servers + + Connect to one or more MCP (Model Context Protocol) servers. Paste a JSON array of server configurations below. + +
+
+ +
+
+
+ +