diff --git a/README.md b/README.md index 5a4e732..5f52344 100644 --- a/README.md +++ b/README.md @@ -3,14 +3,16 @@ **Web search and content extraction for AI models via Model Context Protocol (MCP)** [![Version](https://img.shields.io/badge/version-2.2.0-blue.svg)](https://github.com/Kode-Rex/webcat) -[![Docker](https://img.shields.io/badge/docker-ready-brightgreen.svg)](https://hub.docker.com/r/tmfrisinger/webcat) [![License](https://img.shields.io/badge/license-MIT-green.svg)](LICENSE) ## Quick Start ```bash -# Run WebCat with Docker (30 seconds to working demo) -docker run -p 8000:8000 tmfrisinger/webcat:latest +cd docker +python -m pip install -e ".[dev]" + +# Start demo server with UI +python simple_demo.py # Open demo client open http://localhost:8000/demo @@ -30,40 +32,22 @@ Built with **FastAPI** and **FastMCP** for seamless AI integration. ## Features -- ✅ **No Authentication Required** - Simple setup +- ✅ **Optional Authentication** - Bearer token auth when needed, or run without - ✅ **Automatic Fallback** - Serper API → DuckDuckGo if needed - ✅ **Smart Content Extraction** - Trafilatura removes navigation/ads/chrome - ✅ **MCP Compliant** - Works with Claude Desktop, LiteLLM, etc. 
- ✅ **Rate Limited** - Configurable protection -- ✅ **Docker Ready** - One command deployment - ✅ **Parallel Processing** - Fast concurrent scraping ## Installation & Usage -### Docker (Recommended) - -```bash -# With Serper API (best results) -docker run -p 8000:8000 -e SERPER_API_KEY=your_key tmfrisinger/webcat:2.2.0 - -# Free tier (DuckDuckGo only) -docker run -p 8000:8000 tmfrisinger/webcat:2.2.0 - -# Custom configuration -docker run -p 9000:9000 \ - -e PORT=9000 \ - -e SERPER_API_KEY=your_key \ - -e RATE_LIMIT_WINDOW=60 \ - -e RATE_LIMIT_MAX_REQUESTS=10 \ - tmfrisinger/webcat:2.2.0 -``` - -### Local Development - ```bash cd docker python -m pip install -e ".[dev]" +# Configure environment (optional) +echo "SERPER_API_KEY=your_key" > .env + # Start MCP server python mcp_server.py @@ -88,6 +72,7 @@ python simple_demo.py | Variable | Default | Description | |----------|---------|-------------| | `SERPER_API_KEY` | *(none)* | Serper API key for premium search (optional) | +| `WEBCAT_API_KEY` | *(none)* | Bearer token for authentication (optional, if set all requests must include `Authorization: Bearer `) | | `PORT` | `8000` | Server port | | `LOG_LEVEL` | `INFO` | Logging level (DEBUG, INFO, WARNING, ERROR) | | `LOG_DIR` | `/tmp` | Log file directory | @@ -99,7 +84,17 @@ python simple_demo.py 1. Visit [serper.dev](https://serper.dev) 2. Sign up for free tier (2,500 searches/month) 3. Copy your API key -4. Pass to Docker: `-e SERPER_API_KEY=your_key` +4. Add to `.env` file: `SERPER_API_KEY=your_key` + +### Enable Authentication (Optional) + +To require bearer token authentication for all MCP tool calls: + +1. Generate a secure random token: `openssl rand -hex 32` +2. Add to `.env` file: `WEBCAT_API_KEY=your_token` +3. Include in all requests: `Authorization: Bearer your_token` + +**Note:** If `WEBCAT_API_KEY` is not set, no authentication is required. ## MCP Tools @@ -216,7 +211,6 @@ MIT License - see [LICENSE](LICENSE) file for details. 
## Links - **GitHub:** [github.com/Kode-Rex/webcat](https://github.com/Kode-Rex/webcat) -- **Docker Hub:** [hub.docker.com/r/tmfrisinger/webcat](https://hub.docker.com/r/tmfrisinger/webcat) - **MCP Spec:** [modelcontextprotocol.io](https://modelcontextprotocol.io) - **Serper API:** [serper.dev](https://serper.dev) diff --git a/docker/.env.example b/docker/.env.example new file mode 100644 index 0000000..c412dd5 --- /dev/null +++ b/docker/.env.example @@ -0,0 +1,17 @@ +# Serper API key for premium search (optional) +# If not set, DuckDuckGo fallback will be used +SERPER_API_KEY= + +# WebCat API key for bearer token authentication (optional) +# If set, all requests must include: Authorization: Bearer +# If not set, no authentication is required +WEBCAT_API_KEY= + +# Server configuration +PORT=8000 +LOG_LEVEL=INFO +LOG_DIR=/tmp + +# Rate limiting +RATE_LIMIT_WINDOW=60 +RATE_LIMIT_MAX_REQUESTS=10 diff --git a/docker/constants.py b/docker/constants.py index b12f854..54a30a2 100644 --- a/docker/constants.py +++ b/docker/constants.py @@ -6,7 +6,7 @@ """Constants for WebCat application.""" # Application version -VERSION = "2.3.0" +VERSION = "2.3.1" # Service information SERVICE_NAME = "WebCat MCP Server" @@ -19,8 +19,6 @@ "Content extraction and scraping", "Markdown conversion", "FastMCP protocol support", - "SSE streaming", - "Demo UI client", ] # Content limits diff --git a/docker/endpoints/health_endpoints.py b/docker/endpoints/health_endpoints.py index d20cdfd..b439ea1 100644 --- a/docker/endpoints/health_endpoints.py +++ b/docker/endpoints/health_endpoints.py @@ -10,7 +10,6 @@ from fastapi import FastAPI from fastapi.responses import JSONResponse -from endpoints.demo_client import serve_demo_client from models.health_responses import ( get_detailed_status, get_health_status, @@ -34,11 +33,6 @@ async def health_check(): logger.error(f"Health check failed: {str(e)}") return JSONResponse(status_code=500, content=get_unhealthy_status(str(e))) - @app.get("/demo") - async 
def sse_client(): - """Serve the WebCat SSE demo client.""" - return serve_demo_client() - @app.get("/status") async def server_status(): """Detailed server status endpoint.""" diff --git a/docker/models/responses/health_responses.py b/docker/models/responses/health_responses.py index ec62722..4543126 100644 --- a/docker/models/responses/health_responses.py +++ b/docker/models/responses/health_responses.py @@ -44,11 +44,9 @@ def get_server_configuration() -> dict: def get_server_endpoints() -> dict: """Get server endpoints dictionary.""" return { - "main_mcp": "/mcp", - "sse_demo": "/sse", + "mcp": "/mcp", "health": "/health", "status": "/status", - "demo_client": "/demo", } @@ -81,10 +79,9 @@ def get_root_info() -> dict: "version": VERSION, "description": "Web search and content extraction with MCP protocol support", "endpoints": { - "demo_client": "/demo", + "mcp": "/mcp", "health": "/health", "status": "/status", - "mcp_sse": "/mcp", }, - "documentation": "Access /demo for the demo interface", + "documentation": "MCP server - connect via SSE at /mcp/sse endpoint", } diff --git a/docker/simple_demo.py b/docker/simple_demo.py index 8133be8..cae6550 100755 --- a/docker/simple_demo.py +++ b/docker/simple_demo.py @@ -4,28 +4,19 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
-"""Simplified demo server that combines health and SSE endpoints in one FastAPI app.""" +"""Simplified demo server with FastMCP integration.""" -import asyncio import logging import os import tempfile -import time import uvicorn from dotenv import load_dotenv -from fastapi import FastAPI, Query +from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import StreamingResponse from fastmcp import FastMCP from api_tools import create_webcat_functions, setup_webcat_tools -from demo_utils import ( - format_sse_message, - get_server_info, - handle_health_operation, - handle_search_operation, -) from health import setup_health_endpoints # Load environment variables @@ -44,73 +35,14 @@ logger.setLevel(getattr(logging, LOG_LEVEL)) -async def _generate_webcat_stream( - webcat_functions, operation: str, query: str, max_results: int -): - """Generate SSE stream for WebCat operations. - - Args: - webcat_functions: Dictionary of WebCat functions - operation: Operation to perform - query: Search query - max_results: Maximum results - - Yields: - SSE formatted messages - """ - try: - # Send connection message - yield format_sse_message( - "connection", - status="connected", - message="WebCat stream started", - operation=operation, - ) - - if operation == "search" and query: - search_func = webcat_functions.get("search") - if search_func: - async for msg in handle_search_operation( - search_func, query, max_results - ): - yield msg - else: - yield format_sse_message( - "error", message="Search function not available" - ) - - elif operation == "health": - health_func = webcat_functions.get("health_check") - async for msg in handle_health_operation(health_func): - yield msg - - else: - # Just connection - send server info - yield format_sse_message("data", data=get_server_info()) - yield format_sse_message("complete", message="Connection established") - - # Keep alive with heartbeat - heartbeat_count = 0 - while True: - await 
asyncio.sleep(30) - heartbeat_count += 1 - yield format_sse_message( - "heartbeat", timestamp=time.time(), count=heartbeat_count - ) - - except Exception as e: - logger.error(f"Error in SSE stream: {str(e)}") - yield format_sse_message("error", message=str(e)) - - def create_demo_app(): """Create a single FastAPI app with all endpoints.""" # Create FastAPI app with CORS middleware app = FastAPI( - title="WebCat MCP Demo Server", - description="WebCat server with FastMCP integration and SSE streaming demo", - version="2.2.0", + title="WebCat MCP Server", + description="WebCat server with FastMCP integration", + version="2.3.1", ) app.add_middleware( @@ -131,31 +63,10 @@ def create_demo_app(): webcat_functions = create_webcat_functions() setup_webcat_tools(mcp_server, webcat_functions) - # Add custom SSE endpoint for demo - @app.get("/sse") - async def webcat_stream( - operation: str = Query( - "connect", description="Operation to perform: connect, search, health" - ), - query: str = Query("", description="Search query for search operations"), - max_results: int = Query(5, description="Maximum number of search results"), - ): - """Stream WebCat functionality via SSE""" - return StreamingResponse( - _generate_webcat_stream(webcat_functions, operation, query, max_results), - media_type="text/event-stream", - headers={ - "Cache-Control": "no-cache", - "Connection": "keep-alive", - "Access-Control-Allow-Origin": "*", - "Access-Control-Allow-Headers": "*", - }, - ) - - # Mount FastMCP server as a sub-application (like Clima project) + # Mount FastMCP server app.mount("/mcp", mcp_server.sse_app()) - logger.info("FastAPI app configured with SSE and FastMCP integration") + logger.info("FastAPI app configured with FastMCP integration") return app @@ -166,20 +77,18 @@ def run_simple_demo(host: str = "0.0.0.0", port: int = 8000): app = create_demo_app() # Log endpoints - logger.info(f"WebCat Demo Server: http://{host}:{port}") - logger.info(f"SSE Demo Endpoint: 
http://{host}:{port}/sse") + logger.info(f"WebCat MCP Server: http://{host}:{port}") logger.info(f"FastMCP Endpoint: http://{host}:{port}/mcp") logger.info(f"Health Check: http://{host}:{port}/health") logger.info(f"Demo Client: http://{host}:{port}/demo") logger.info(f"Server Status: http://{host}:{port}/status") - print("\n🐱 WebCat MCP Demo Server Starting...") + print("\n🐱 WebCat MCP Server Starting...") print(f"📡 Server: http://{host}:{port}") - print(f"🔗 SSE Demo: http://{host}:{port}/sse") - print(f"🛠️ FastMCP: http://{host}:{port}/mcp") + print(f"🛠️ MCP Endpoint: http://{host}:{port}/mcp") print(f"💗 Health: http://{host}:{port}/health") print(f"🎨 Demo UI: http://{host}:{port}/demo") - print(f"📊 Server Status: http://{host}:{port}/status") + print(f"📊 Status: http://{host}:{port}/status") print("\n✨ Ready for connections!") # Run the server diff --git a/docker/test_mcp_client.py b/docker/test_mcp_client.py new file mode 100755 index 0000000..0267944 --- /dev/null +++ b/docker/test_mcp_client.py @@ -0,0 +1,73 @@ +#!/usr/bin/env python3 +# Copyright (c) 2024 Travis Frisinger +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +"""Simple MCP client to test WebCat tools via SSE transport.""" + +import asyncio +import os + +from dotenv import load_dotenv +from mcp import ClientSession +from mcp.client.sse import sse_client + +# Load environment variables from .env file +load_dotenv() + + +async def test_search(): + """Test the search tool via SSE transport.""" + # Get optional bearer token from environment + api_key = os.environ.get("WEBCAT_API_KEY", "") + + # Set up headers for authentication if needed + headers = {} + if api_key: + headers["Authorization"] = f"Bearer {api_key}" + print("🔐 Using bearer token authentication") + else: + print("🔓 No authentication (WEBCAT_API_KEY not set)") + + # Connect to the FastMCP SSE endpoint + url = "http://localhost:8000/mcp/sse" + + print(f"📡 Connecting to {url}...") + + async with sse_client(url, headers=headers) as (read, write): + async with ClientSession(read, write) as session: + await session.initialize() + print("✅ Connected to MCP server") + + # List available tools + tools = await session.list_tools() + print("\n🛠️ Available tools:") + for tool in tools.tools: + print(f" - {tool.name}: {tool.description}") + + # Call search tool + print("\n🔍 Testing search tool with query 'test'...") + result = await session.call_tool("search", arguments={"query": "test"}) + + print("\n📊 Search result:") + print(f" Content length: {len(str(result.content))}") + print( + f" Is error: {result.isError if hasattr(result, 'isError') else 'N/A'}" + ) + if hasattr(result, "content") and result.content: + # Print first 200 chars of content + content_preview = str(result.content)[:200] + print(f" Preview: {content_preview}...") + + +if __name__ == "__main__": + print("🐱 WebCat MCP Client Test\n") + try: + asyncio.run(test_search()) + print("\n✨ Test completed successfully!") + except Exception as e: + print(f"\n❌ Test failed: {e}") + import traceback + + traceback.print_exc() diff --git a/docker/test_with_curl.sh b/docker/test_with_curl.sh new file mode 100755 
index 0000000..255b8c1 --- /dev/null +++ b/docker/test_with_curl.sh @@ -0,0 +1,26 @@ +#!/bin/bash +# Simple curl-based test for WebCat endpoints + +echo "=== Testing WebCat Server ===" +echo "" + +# Test health endpoint +echo "1. Health Check (no auth):" +curl -s http://localhost:8000/health | jq . +echo "" + +# Test status endpoint +echo "2. Status Check:" +curl -s http://localhost:8000/status | jq . +echo "" + +# Test with auth if WEBCAT_API_KEY is set +if [ -n "$WEBCAT_API_KEY" ]; then + echo "3. Health Check (with auth):" + curl -s -H "Authorization: Bearer $WEBCAT_API_KEY" http://localhost:8000/health | jq . + echo "" +fi + +echo "=== MCP Tools Testing ===" +echo "MCP tools (search, health_check) require an MCP client." +echo "Use Claude Desktop or run: python test_mcp_client.py" diff --git a/docker/tests/unit/tools/test_search_tool.py b/docker/tests/unit/tools/test_search_tool.py index ec5d90f..c02df89 100644 --- a/docker/tests/unit/tools/test_search_tool.py +++ b/docker/tests/unit/tools/test_search_tool.py @@ -18,10 +18,14 @@ class TestSearchTool: """Tests for search tool.""" @pytest.mark.asyncio + @patch("tools.search_tool.validate_bearer_token") @patch("tools.search_tool.process_search_results") @patch("tools.search_tool.fetch_with_fallback") - async def test_returns_search_results(self, mock_fetch, mock_process): + async def test_returns_search_results( + self, mock_fetch, mock_process, mock_validate + ): # Arrange + mock_validate.return_value = (True, None) api_results = [an_api_search_result().build()] processed = [ SearchResult( @@ -44,9 +48,11 @@ async def test_returns_search_results(self, mock_fetch, mock_process): assert result["results"][0]["title"] == "Test" @pytest.mark.asyncio + @patch("tools.search_tool.validate_bearer_token") @patch("tools.search_tool.fetch_with_fallback") - async def test_returns_error_when_no_results(self, mock_fetch): + async def test_returns_error_when_no_results(self, mock_fetch, mock_validate): # Arrange + 
mock_validate.return_value = (True, None) mock_fetch.return_value = ([], "DuckDuckGo (free fallback)") # Act @@ -58,10 +64,14 @@ async def test_returns_error_when_no_results(self, mock_fetch): assert len(result["results"]) == 0 @pytest.mark.asyncio + @patch("tools.search_tool.validate_bearer_token") @patch("tools.search_tool.process_search_results") @patch("tools.search_tool.fetch_with_fallback") - async def test_processes_results_correctly(self, mock_fetch, mock_process): + async def test_processes_results_correctly( + self, mock_fetch, mock_process, mock_validate + ): # Arrange + mock_validate.return_value = (True, None) api_results = [ an_api_search_result() .with_title("T") @@ -77,3 +87,31 @@ async def test_processes_results_correctly(self, mock_fetch, mock_process): # Assert mock_process.assert_called_once_with(api_results) + + @pytest.mark.asyncio + @patch("tools.search_tool.validate_bearer_token") + async def test_returns_error_when_authentication_fails(self, mock_validate): + # Arrange + mock_validate.return_value = (False, "Invalid bearer token") + + # Act + result = await search_tool("test query") + + # Assert + assert result["query"] == "test query" + assert result["error"] == "Invalid bearer token" + assert result["search_source"] == "none" + assert len(result["results"]) == 0 + + @pytest.mark.asyncio + @patch("tools.search_tool.validate_bearer_token") + async def test_passes_context_to_authentication(self, mock_validate): + # Arrange + mock_validate.return_value = (False, "Auth error") + ctx = {"headers": {"Authorization": "Bearer test"}} + + # Act + await search_tool("query", ctx=ctx) + + # Assert + mock_validate.assert_called_once_with(ctx) diff --git a/docker/tests/unit/utils/test_auth.py b/docker/tests/unit/utils/test_auth.py new file mode 100644 index 0000000..ad3086e --- /dev/null +++ b/docker/tests/unit/utils/test_auth.py @@ -0,0 +1,178 @@ +# Copyright (c) 2024 Travis Frisinger +# +# This source code is licensed under the MIT license found in 
the +# LICENSE file in the root directory of this source tree. + +"""Unit tests for authentication utilities.""" + +from unittest.mock import patch + +from utils.auth import validate_bearer_token + + +class TestValidateBearerToken: + """Tests for validate_bearer_token function.""" + + @patch.dict("os.environ", {}, clear=True) + def test_no_auth_required_when_api_key_not_set(self): + """When WEBCAT_API_KEY is not set, authentication should pass.""" + # Act + is_valid, error_msg = validate_bearer_token(ctx=None) + + # Assert + assert is_valid is True + assert error_msg is None + + @patch.dict( + "os.environ", {"WEBCAT_API_KEY": "test-token"}, clear=True + ) # pragma: allowlist secret + def test_fails_when_context_is_none(self): + """When API key is set but no context provided, should fail.""" + # Act + is_valid, error_msg = validate_bearer_token(ctx=None) + + # Assert + assert is_valid is False + assert "missing bearer token" in error_msg.lower() + + @patch.dict( + "os.environ", {"WEBCAT_API_KEY": "test-token"}, clear=True + ) # pragma: allowlist secret + def test_fails_when_no_headers_in_context(self): + """When API key is set but context has no headers, should fail.""" + # Arrange + ctx = {} + + # Act + is_valid, error_msg = validate_bearer_token(ctx=ctx) + + # Assert + assert is_valid is False + assert "missing bearer token" in error_msg.lower() + + @patch.dict( + "os.environ", {"WEBCAT_API_KEY": "test-token"}, clear=True + ) # pragma: allowlist secret + def test_fails_when_authorization_header_missing(self): + """When API key is set but Authorization header missing, should fail.""" + # Arrange + ctx = {"headers": {}} + + # Act + is_valid, error_msg = validate_bearer_token(ctx=ctx) + + # Assert + assert is_valid is False + assert "authorization" in error_msg.lower() + + @patch.dict( + "os.environ", {"WEBCAT_API_KEY": "test-token"}, clear=True + ) # pragma: allowlist secret + def test_fails_when_authorization_header_invalid_format(self): + """When Authorization 
header has invalid format, should fail.""" + # Arrange + ctx = {"headers": {"Authorization": "InvalidFormat"}} + + # Act + is_valid, error_msg = validate_bearer_token(ctx=ctx) + + # Assert + assert is_valid is False + assert "invalid authorization header format" in error_msg.lower() + + @patch.dict( + "os.environ", {"WEBCAT_API_KEY": "test-token"}, clear=True + ) # pragma: allowlist secret + def test_fails_when_bearer_keyword_missing(self): + """When Authorization header doesn't start with Bearer, should fail.""" + # Arrange + ctx = {"headers": {"Authorization": "Basic test-token"}} + + # Act + is_valid, error_msg = validate_bearer_token(ctx=ctx) + + # Assert + assert is_valid is False + assert "invalid authorization header format" in error_msg.lower() + + @patch.dict( + "os.environ", {"WEBCAT_API_KEY": "test-token"}, clear=True + ) # pragma: allowlist secret + def test_fails_when_token_is_incorrect(self): + """When token doesn't match WEBCAT_API_KEY, should fail.""" + # Arrange + ctx = {"headers": {"Authorization": "Bearer wrong-token"}} + + # Act + is_valid, error_msg = validate_bearer_token(ctx=ctx) + + # Assert + assert is_valid is False + assert "invalid bearer token" in error_msg.lower() + + @patch.dict( + "os.environ", {"WEBCAT_API_KEY": "test-token"}, clear=True + ) # pragma: allowlist secret + def test_succeeds_when_token_is_correct(self): + """When token matches WEBCAT_API_KEY, should succeed.""" + # Arrange + ctx = {"headers": {"Authorization": "Bearer test-token"}} + + # Act + is_valid, error_msg = validate_bearer_token(ctx=ctx) + + # Assert + assert is_valid is True + assert error_msg is None + + @patch.dict( + "os.environ", {"WEBCAT_API_KEY": "test-token"}, clear=True + ) # pragma: allowlist secret + def test_handles_case_insensitive_authorization_header(self): + """Should handle Authorization header with different cases.""" + # Test lowercase + ctx = {"headers": {"authorization": "Bearer test-token"}} + is_valid, error_msg = 
validate_bearer_token(ctx=ctx) + assert is_valid is True + assert error_msg is None + + # Test uppercase + ctx = {"headers": {"AUTHORIZATION": "Bearer test-token"}} + is_valid, error_msg = validate_bearer_token(ctx=ctx) + assert is_valid is True + assert error_msg is None + + @patch.dict( + "os.environ", {"WEBCAT_API_KEY": "test-token"}, clear=True + ) # pragma: allowlist secret + def test_handles_context_with_headers_attribute(self): + """Should extract headers from context.headers attribute.""" + + # Arrange + class ContextWithAttribute: + def __init__(self): + self.headers = {"Authorization": "Bearer test-token"} + + ctx = ContextWithAttribute() + + # Act + is_valid, error_msg = validate_bearer_token(ctx=ctx) + + # Assert + assert is_valid is True + assert error_msg is None + + @patch.dict( + "os.environ", {"WEBCAT_API_KEY": "test-token"}, clear=True + ) # pragma: allowlist secret + def test_bearer_keyword_is_case_insensitive(self): + """Bearer keyword should be case-insensitive.""" + # Arrange + ctx = {"headers": {"Authorization": "bearer test-token"}} + + # Act + is_valid, error_msg = validate_bearer_token(ctx=ctx) + + # Assert + assert is_valid is True + assert error_msg is None diff --git a/docker/tools/api_tools_setup.py b/docker/tools/api_tools_setup.py index b43b98b..c568642 100644 --- a/docker/tools/api_tools_setup.py +++ b/docker/tools/api_tools_setup.py @@ -10,7 +10,7 @@ import time from typing import Any, Dict -from fastmcp import FastMCP +from fastmcp import Context, FastMCP from constants import CAPABILITIES, SERVICE_NAME, VERSION from models.api_responses import ( @@ -22,6 +22,7 @@ from models.search_result import SearchResult from services.content_scraper import scrape_search_result from services.search_orchestrator import execute_search +from utils.auth import validate_bearer_token logger = logging.getLogger(__name__) @@ -33,9 +34,24 @@ def setup_search_tool(mcp: FastMCP, search_func): name="search", description="Search the web for information 
using Serper API or DuckDuckGo fallback", ) - async def search_tool(query: str, max_results: int = 5) -> dict: + async def search_tool(query: str, ctx: Context, max_results: int = 5) -> dict: """Search the web for information on a given query.""" try: + # Validate authentication if WEBCAT_API_KEY is set + is_valid, error_msg = validate_bearer_token(ctx) + if not is_valid: + logger.warning(f"Authentication failed: {error_msg}") + response = APISearchToolResponse( + success=False, + query=query, + max_results=max_results, + search_source="none", + results=[], + total_found=0, + error=error_msg, + ) + return response.model_dump() + logger.info( f"Processing search request: {query} (max {max_results} results)" ) diff --git a/docker/tools/search_tool.py b/docker/tools/search_tool.py index a13be3a..47f8679 100644 --- a/docker/tools/search_tool.py +++ b/docker/tools/search_tool.py @@ -13,6 +13,7 @@ from models.search_result import SearchResult from services.search_processor import process_search_results from services.search_service import fetch_with_fallback +from utils.auth import validate_bearer_token logger = logging.getLogger(__name__) @@ -28,13 +29,25 @@ async def search_tool(query: str, ctx=None) -> dict: Args: query: The search query string - ctx: Optional MCP context + ctx: Optional MCP context (may contain authentication headers) Returns: Dict representation of SearchResponse model (for MCP JSON serialization) """ logger.info(f"Processing search request: {query}") + # Validate authentication if WEBCAT_API_KEY is set + is_valid, error_msg = validate_bearer_token(ctx) + if not is_valid: + logger.warning(f"Authentication failed: {error_msg}") + response = SearchResponse( + query=query, + search_source="none", + results=[], + error=error_msg, + ) + return response.model_dump() + # Fetch results with automatic fallback api_results, search_source = fetch_with_fallback(query, SERPER_API_KEY) diff --git a/docker/utils/auth.py b/docker/utils/auth.py new file mode 100644 
# Copyright (c) 2024 Travis Frisinger
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Authentication utilities for WebCat server.

This module provides optional bearer token authentication. If WEBCAT_API_KEY
is set in the environment, MCP tool calls must include a valid bearer token
in the context. If not set, no authentication is required.
"""

import hmac
import logging
import os
from collections.abc import Mapping
from typing import Any, Optional

try:
    from fastmcp import Context
except ImportError:
    Context = None  # type: ignore

logger = logging.getLogger(__name__)

_MISSING_TOKEN_MSG = "Authentication required: missing bearer token"


def _extract_headers(ctx: Any) -> Optional[Any]:
    """Return the request headers carried by *ctx*, or None if there are none.

    Lookup order: the HTTP request of a FastMCP ``Context`` (when fastmcp is
    importable), then a ``headers`` attribute, then a ``"headers"`` dict key.
    """
    if Context is not None and isinstance(ctx, Context):
        try:
            request = ctx.get_http_request()
            if request and hasattr(request, "headers"):
                return dict(request.headers)
        except Exception as exc:
            # Best effort: fall through to the generic lookups below.
            logger.warning("Failed to get HTTP request from context: %s", exc)

    if hasattr(ctx, "headers"):
        return ctx.headers
    if isinstance(ctx, dict) and "headers" in ctx:
        return ctx["headers"]
    return None


def _find_authorization(headers: Any) -> Optional[Any]:
    """Return the first non-empty Authorization value in *headers*, if any.

    Header names are matched case-insensitively, and any mapping type is
    accepted (not only ``dict``), so header containers that merely implement
    the Mapping interface work as well.
    """
    if not isinstance(headers, Mapping):
        return None
    for name, value in headers.items():
        if isinstance(name, str) and name.lower() == "authorization" and value:
            return value
    return None


def validate_bearer_token(ctx: Optional[Any] = None) -> tuple[bool, Optional[str]]:
    """Validate bearer token if WEBCAT_API_KEY is set.

    Args:
        ctx: Optional context from MCP tool call (may contain request headers).
            May be a FastMCP ``Context``, an object with a ``headers``
            attribute, or a dict with a ``"headers"`` key.

    Returns:
        Tuple of (is_valid, error_message)
        - If WEBCAT_API_KEY not set: (True, None) - no auth required
        - If valid token: (True, None)
        - If invalid/missing token: (False, error_message)
    """
    api_key = os.environ.get("WEBCAT_API_KEY")

    # No API key configured - no authentication required
    if not api_key:
        return True, None

    # API key is set - authentication required
    if ctx is None:
        logger.warning("Authentication required but no context provided")
        return False, _MISSING_TOKEN_MSG

    headers = _extract_headers(ctx)
    if headers is None:
        logger.warning("Authentication required but no headers in context")
        return False, _MISSING_TOKEN_MSG

    auth_header = _find_authorization(headers)
    if not auth_header:
        logger.warning("Missing Authorization header")
        return False, "Authentication required: missing Authorization header"

    # Validate "<scheme> <token>" shape; the scheme is case-insensitive.
    parts = auth_header.split() if isinstance(auth_header, str) else []
    if len(parts) != 2 or parts[0].lower() != "bearer":
        logger.warning("Invalid Authorization header format")
        return (
            False,
            "Invalid Authorization header format. Expected: Bearer <token>",
        )

    token = parts[1]

    # Constant-time comparison so response timing does not leak how much of
    # the secret matched. Compare bytes: compare_digest rejects non-ASCII str.
    supplied = token.encode("utf-8", "surrogatepass")
    expected = api_key.encode("utf-8", "surrogatepass")
    if not hmac.compare_digest(supplied, expected):
        logger.warning("Invalid bearer token provided")
        return False, "Invalid bearer token"

    # Token is valid
    logger.debug("Bearer token validated successfully")
    return True, None