diff --git a/docs/agent_hocon_reference.md b/docs/agent_hocon_reference.md index 632c26f29..ea1409de3 100644 --- a/docs/agent_hocon_reference.md +++ b/docs/agent_hocon_reference.md @@ -54,7 +54,6 @@ Items in ***bold*** are essentials. Try to understand these first. - [***tools*** - list of other agents/tools that this agent may access](#tools-agents) - [External Agents](#external-agents) - [MCP Servers](#mcp-servers) - - [Authentication](#authentication) - [***class*** - Python class name to invoke for Coded Tools](#class-1) - [llm_config - agent-specific LLM configuration](#llm_config-1) - [command](#command) @@ -569,48 +568,7 @@ MCP servers can be configured in two formats: - `tools` key filters which specific tools from the MCP server are made available. If omitted, all tools on the server will be accessible. -##### Authentication - -MCP tools can be authenticated using the following methods: - -- `http_headers` field in `sly_data`. The required fields depend on the authentication scheme expected by each MCP -server. Users may specify different authorization credentials for different MCP URLs. - - Example: - - ```json - { - "http_headers": { - "": { - "Authorization": "Bearer " - }, - "": { - "client_id": "", - "client_secret": "" - } - } - } - ``` - -- Set the `MCP_SERVERS_INFO_FILE` environment variable to point to a HOCON file containing MCP server configurations: - - ```json - { - "mcp_server_url_1": { - "http_headers": { - "Authorization": "Bearer ", - }, - "tools": ["tool_1", "tool_2"] - }, - } - ``` - - - Server URLs must match those in the agent network HOCON file - - - If the headers exist in both `sly_data` and the configuration file for the same server, - `sly_data` takes precedence - - - Tool filtering from the configuration file is used only if no tool filtering exists in the agent network HOCON +See [this document](mcp_authentication.md) for MCP authentication. 
### llm_config diff --git a/docs/mcp_authentication.md b/docs/mcp_authentication.md new file mode 100644 index 000000000..9d1e2f621 --- /dev/null +++ b/docs/mcp_authentication.md @@ -0,0 +1,299 @@ +# MCP Authentication Guide + +Neuro-san supports machine-to-machine authentication for MCP servers that require credentials before granting +access to tools. This guide explains the available authentication methods and how to configure them. + + + +- [Authentication Methods](#authentication-methods) +- [Configuration Methods](#configuration-methods) + - [Method 1: Using sly_data](#method-1-using-sly_data) + - [http_headers](#http_headers) + - [Authorization](#authorization) + - [mcp_client_info](#mcp_client_info) + - [client_id](#client_id) + - [client_secret](#client_secret) + - [token_endpoint_auth_method](#token_endpoint_auth_method) + - [scope](#scope) + - [mcp_tokens](#mcp_tokens) + - [access_token](#access_token) + - [refresh_token](#refresh_token) + - [Method 2: Using Configuration File](#method-2-using-configuration-file) + - [http_headers](#http_headers-1) + - [mcp_client_info](#mcp_client_info-1) + - [mcp_server_info](#mcp_server_info) + - [token_endpoint](#token_endpoint) + - [auth_timeout](#auth_timeout) + - [tools](#tools) +- [Configuration Precedence Rules](#configuration-precedence-rules) + +--- + +## Authentication Methods + +Neuro-san supports three authentication methods, applied in the following priority order: + +1. **Headers** (highest priority) + - Typically uses the `Authorization` field with `Bearer <ACCESS_TOKEN>` + - Required fields depend on the authentication scheme expected by the MCP server + +2. **Refresh Token** (fallback if headers unavailable) + - Exchanges client ID and refresh token for an access token + - Used when both client credentials and refresh token are available + +3. 
**Client Credentials** (lowest priority) + - Exchanges client ID and/or client secret for an access token + - Used when only client information is provided + +--- + +## Configuration Methods + +Authentication data can be provided in two ways: through `sly_data`, or via a configuration file referenced by an environment variable. + +### Method 1: Using `sly_data` + +Pass authentication credentials directly in the `sly_data` object. You can specify different credentials +for different MCP URLs. + +**Available Fields:** +- `http_headers` - HTTP headers for authentication +- `mcp_client_info` - Client credentials for token exchange +- `mcp_tokens` - Token information (only available via sly_data) + +**Example:** + +```hocon +{ + "http_headers": { + "<MCP_SERVER_URL>": { + "Authorization": "Bearer <ACCESS_TOKEN>" + } + }, + "mcp_client_info": { + "<MCP_SERVER_URL>": { + "client_id": "<CLIENT_ID>", + "client_secret": "<CLIENT_SECRET>", + "token_endpoint_auth_method": "client_secret_post", + "scope": "<SCOPE>" + } + }, + "mcp_tokens": { + "<MCP_SERVER_URL>": { + "access_token": "<ACCESS_TOKEN>", + "refresh_token": "<REFRESH_TOKEN>" + } + } +} +``` + +#### `http_headers` + +HTTP headers to be sent with requests to the MCP server for authentication purposes. + +##### `Authorization` + +The `Authorization` header field typically contains credentials for authenticating the client with the server. +The most common format is `Bearer <ACCESS_TOKEN>` for OAuth 2.0 bearer tokens, +but other authentication schemes may be used depending on the server's requirements +(e.g., `Basic`, `Digest`, or custom schemes). +Consult your MCP server's documentation for the specific authentication scheme required. + +**Example:** + +```hocon +"Authorization": "Bearer <ACCESS_TOKEN>" +``` + +#### `mcp_client_info` + +Client credentials used for OAuth 2.0 authentication flows to obtain access tokens. + +##### `client_id` + +The OAuth 2.0 client identifier issued to the client during the registration process. +This is a required field that uniquely identifies your application to the authorization server. 
+ +**Example:** + +```hocon +"client_id": "my-application-client-id" +``` + +##### `client_secret` + +The OAuth 2.0 client secret issued to the client during registration. +This confidential credential is used to authenticate the client with the authorization server. +This field is required unless `token_endpoint_auth_method` is set to `null`. + +**Example:** + +```hocon +"client_secret": "super-secret-client-secret-value" +``` + +##### `token_endpoint_auth_method` + +Specifies the authentication method used when exchanging credentials for +an access token at the token endpoint. This follows the OAuth 2.0 specification. + +**Supported values:** +- `client_secret_basic` (default): Client credentials are sent via +HTTP Basic authentication, with the client ID and secret encoded in the `Authorization` header +- `client_secret_post`: Client credentials are sent in the request body as POST parameters +- `null`: No authentication (for public clients) + +For more details on OAuth 2.0 client authentication methods, +see [RFC 6749 Section 2.3](https://datatracker.ietf.org/doc/html/rfc6749#section-2.3). + +**Example:** + +```hocon +"token_endpoint_auth_method": "client_secret_post" +``` + +##### `scope` + +A space-separated list of OAuth 2.0 scopes being requested. +Scopes define the level of access that the application is requesting. +The available scopes depend on the MCP server's authorization server configuration. +If omitted, the server will use its default scopes. + +**Example:** + +```hocon +"scope": "read:data write:data admin:settings" +``` + +#### `mcp_tokens` + +Token information for authentication. This field is only available when using `sly_data` configuration. + +The system will attempt to authenticate using the provided `access_token`. +If authentication fails (e.g., due to token expiration), the system will automatically attempt to refresh +the access token using the `refresh_token` if one is provided. 
+ +**Note:** While standard OAuth 2.0 token responses include additional fields such as `token_type`, `expires_in`, +and `scope`, these fields are not used by the authentication system and should not be included in the configuration. + +##### `access_token` + +The OAuth 2.0 access token string used to authenticate requests to the MCP server. This is typically a JWT +(JSON Web Token) or an opaque token string issued by the authorization server. + +**Example:** + +```hocon +"access_token": "<ACCESS_TOKEN>" +``` + +##### `refresh_token` + +An optional OAuth 2.0 refresh token that can be used to obtain a new access token when the current one expires or +fails. If provided, the system will automatically use this token to refresh authentication when needed. +Not all authorization servers issue refresh tokens, particularly for short-lived sessions or public clients. + +**Example:** + +```hocon +"refresh_token": "<REFRESH_TOKEN>" +``` + +### Method 2: Using Configuration File + +Set the `AGENT_MCP_INFO_FILE` environment variable to point to a HOCON configuration file containing MCP authentication +settings. For detailed information about the structure and all available configuration options, +see the [MCP Info Configuration Documentation](mcp_info_hocon_reference.md). + +**Important Notes:** +- `MCP_SERVERS_INFO_FILE` is deprecated and will be removed in version 0.7.0 +- Server info (e.g., token endpoints) can only be configured in this configuration file, not via `sly_data` +- Server URLs must match those defined in the agent network HOCON file +- We strongly recommend **not** storing secrets directly in any source file. +Source files can easily be committed to version control, and checking in secrets is a serious security risk. +If these configuration files need to be committed, use **HOCON substitution** (e.g., environment variable references) +instead of hardcoding secret values. 
+ +**Example Configuration:** + +```hocon +{ + "mcp_server_url_1": { + "http_headers": { + "Authorization": "Bearer <ACCESS_TOKEN>" + }, + "mcp_client_info": { + "client_id": "<CLIENT_ID>", + "client_secret": "<CLIENT_SECRET>", + "token_endpoint_auth_method": "client_secret_post", + "scope": "<SCOPE>" + }, + "mcp_server_info": { + "token_endpoint": "https://example.com/token" + }, + "auth_timeout": 300.0, + "tools": ["tool_1", "tool_2"] + } +} +``` + + +#### `http_headers` + +Same as in Method 1. See the `http_headers` section under Method 1 for detailed field descriptions. + + +#### `mcp_client_info` + +Same as in Method 1. See the `mcp_client_info` section under Method 1 for detailed field descriptions. + +#### `mcp_server_info` + +Configuration for the MCP server's OAuth 2.0 endpoints and authentication behavior. + +##### `token_endpoint` + +The URL of the OAuth 2.0 token endpoint where the client exchanges credentials for access tokens. +If not provided, the system will attempt to discover the endpoint automatically through the server's +discovery mechanism, or fall back to `{base_url}/token`. + +**Example:** + +```hocon +"token_endpoint": "https://auth.example.com/oauth/token" +``` + +#### `auth_timeout` + +The maximum time in seconds to wait for authentication operations to complete. +This includes token exchange requests and any other authentication-related network operations. +The default value is 300.0 seconds (5 minutes). + +**Example:** + +```hocon +"auth_timeout": 180.0 # 3 minutes +``` + +**Alternative:** You can also set the `AGENT_MCP_TIMEOUT_SECONDS` environment variable instead of using this field. + +#### `tools` + +An optional list of tool names to filter. +When specified, only the listed tools from this MCP server will be made available. +If omitted, all tools from the server are available. 
+ +**Example:** + +```hocon +"tools": ["search_database", "update_record", "delete_record"] +``` + +--- + +## Configuration Precedence Rules + +When authentication data exists in multiple locations, the following precedence applies: + +1. **sly_data takes precedence** over configuration file for headers and client info on the same server +2. **Configuration file tool filtering** is used only if no tool filtering exists in the agent network HOCON file diff --git a/docs/mcp_info_hocon_reference.md b/docs/mcp_info_hocon_reference.md new file mode 100644 index 000000000..b2307ffa1 --- /dev/null +++ b/docs/mcp_info_hocon_reference.md @@ -0,0 +1,715 @@ +# MCP Info Configuration Documentation + +The MCP Info Configuration file is a HOCON-formatted file that contains authentication and connection settings for MCP +(Model Context Protocol) servers. This file is referenced by setting the `AGENT_MCP_INFO_FILE` environment variable +to point to its location. + +The configuration file allows you to: + +- Define authentication credentials for multiple MCP servers +- Configure OAuth 2.0 client credentials and token endpoints +- Set custom headers for authentication +- Specify timeouts for authentication operations +- Filter which tools from each server should be available + + + +- [Security Best Practices](#security-best-practices) + - [Never Hardcode Secrets](#never-hardcode-secrets) + - [Using HOCON Substitution](#using-hocon-substitution) +- [Configuration File Structure](#configuration-file-structure) +- [Configuration Fields](#configuration-fields) + - [http_headers](#http_headers) + - [Common Headers](#common-headers) + - [Authorization](#authorization) + - [X-API-Key](#x-api-key) + - [Custom Headers](#custom-headers) + - [mcp_client_info](#mcp_client_info) + - [client_id](#client_id) + - [client_secret](#client_secret) + - [token_endpoint_auth_method](#token_endpoint_auth_method) + - [client_secret_basic](#client_secret_basic) + - [client_secret_post](#client_secret_post) 
+ - [null](#null) + - [scope](#scope) + - [mcp_server_info](#mcp_server_info) + - [token_endpoint](#token_endpoint) + - [auth_timeout](#auth_timeout) + - [tools](#tools) +- [Complete Configuration Example](#complete-configuration-example) +- [Related Documentation](#related-documentation) + +--- + +## Security Best Practices + +### Never Hardcode Secrets + + +**⚠️ CRITICAL SECURITY WARNING ⚠️** + +We **strongly recommend NOT storing secrets directly in any source file**. This includes: +- API tokens and access tokens +- Client secrets +- Bearer tokens +- Passwords +- Any other sensitive credentials + +**Why this matters:** +- Source files can easily be committed to version control systems (Git, SVN, etc.) +- Checking in secrets is a serious security risk that can lead to: + - Unauthorized access to your systems + - Data breaches + - Compliance violations + - Exposure of sensitive customer information + +### Using HOCON Substitution + +If your configuration files need to be committed to version control, **always use HOCON substitution** to reference +environment variables instead of hardcoding secret values. 
+ +**Secure approach using environment variable substitution:** + +```hocon +{ + "https://api.example.com/mcp": { + "http_headers": { + "Authorization": "Bearer "${MCP_ACCESS_TOKEN} + }, + "mcp_client_info": { + "client_id": "${MCP_CLIENT_ID}", + "client_secret": "${MCP_CLIENT_SECRET}" + } + } +} +``` + +**Then set the environment variables separately:** + +```bash +export MCP_ACCESS_TOKEN="your-actual-token-here" +export MCP_CLIENT_ID="your-client-id-here" +export MCP_CLIENT_SECRET="your-client-secret-here" +export AGENT_MCP_INFO_FILE="/path/to/mcp-config.hocon" +``` + +**Additional security recommendations:** +- Use a secrets management system (e.g., HashiCorp Vault, AWS Secrets Manager, Azure Key Vault) +- Rotate credentials regularly +- Use different credentials for different environments (dev, staging, production) +- Implement the principle of least privilege - only grant necessary scopes +- Monitor and audit access to secrets + +--- + +## Configuration File Structure + +The configuration file is organized as a HOCON object where each key is an MCP server URL, +and the value is a configuration block for that server. + +**Basic structure:** + +```hocon +{ + "server_url_1": { + # Configuration for server 1 + }, + "server_url_2": { + # Configuration for server 2 + } +} +``` + +**Important notes:** +- Server URLs must be complete, valid URLs including the protocol (https://) +- Server URLs must exactly match those defined in your agent network HOCON file +- Each server can have its own independent authentication configuration +- Multiple authentication methods can be specified per server (with precedence rules applied) + +--- + +## Configuration Fields + +Each server URL maps to a configuration block that can contain the following top-level fields: + +```hocon +{ + "https://api.example.com/mcp": { + "http_headers": { ... }, # Optional: HTTP headers for authentication + "mcp_client_info": { ... }, # Optional: OAuth 2.0 client credentials + "mcp_server_info": { ... 
}, # Optional: Server endpoint configuration + "auth_timeout": 300.0, # Optional: Authentication timeout in seconds + "tools": ["tool1", "tool2"] # Optional: Tool filtering + } +} +``` + +### `http_headers` + +HTTP headers to be sent with requests to the MCP server for authentication purposes. +This is typically used for token-based authentication. + +**Structure:** + +```hocon +"http_headers": { + "Header-Name-1": "header-value-1", + "Header-Name-2": "header-value-2" +} +``` + +#### Common Headers + +##### `Authorization` + +The most commonly used authentication header. Supports various authentication schemes: + +**Bearer Token Authentication (most common):** + +```hocon +"http_headers": { + "Authorization": "Bearer ${ACCESS_TOKEN}" +} +``` + +Bearer tokens are typically used with OAuth 2.0 and are defined in +[RFC 6750](https://datatracker.ietf.org/doc/html/rfc6750). +The token should be sent exactly as received from the authorization server. + +**Basic Authentication:** + +```hocon +"http_headers": { + "Authorization": "Basic ${BASE64_CREDENTIALS}" +} +``` + +Where `BASE64_CREDENTIALS` is the base64 encoding of `username:password`. +See [RFC 7617](https://datatracker.ietf.org/doc/html/rfc7617). + +**API Key Authentication:** + +```hocon +"http_headers": { + "Authorization": "ApiKey ${API_KEY}" +} +``` + +Some services use custom authentication schemes. Always refer to your MCP server's documentation. 
+ +##### `X-API-Key` + +Some services use a custom header for API keys: + +```hocon +"http_headers": { + "X-API-Key": "${API_KEY}" +} +``` + +##### Custom Headers + +You can include any custom headers required by your MCP server: + +```hocon +"http_headers": { + "X-Client-Id": "${CLIENT_ID}", + "X-Request-ID": "unique-request-identifier", + "X-Tenant-ID": "tenant-123" +} +``` + +**Complete example:** + +```hocon +{ + "https://api.example.com/mcp": { + "http_headers": { + "Authorization": "Bearer ${MCP_TOKEN}", + "X-API-Version": "2024-01", + "X-Client-ID": "${CLIENT_ID}" + } + } +} +``` + +### `mcp_client_info` + +OAuth 2.0 client credentials used for token-based authentication flows. +This section is used when you need to exchange client credentials for an access token. + +**Structure:** + +```hocon +"mcp_client_info": { + "client_id": "string", + "client_secret": "string", + "token_endpoint_auth_method": "string", + "scope": "string" +} +``` + +#### `client_id` + +**Type:** String (required) + +The OAuth 2.0 client identifier issued to your application during the registration process with the authorization +server. This uniquely identifies your application. + +**Example:** + +```hocon +"client_id": "${MCP_CLIENT_ID}" +``` + +**Best practices:** +- Always use environment variable substitution +- Client IDs are typically safe to commit to source control (they're not secret), +but using environment variables provides flexibility +- Keep a record of which client ID corresponds to which environment + +#### `client_secret` + +**Type:** String (conditionally required) + +The OAuth 2.0 client secret issued to your application during registration. +This is a confidential credential that must be protected. 
+ +**Required when:** +- `token_endpoint_auth_method` is `client_secret_basic` or `client_secret_post` + +**Not required when:** +- `token_endpoint_auth_method` is `null` (public clients) +- Using other authentication methods that don't require a client secret + +**Example:** + +```hocon +"client_secret": "${MCP_CLIENT_SECRET}" +``` + +**Security considerations:** +- **NEVER** hardcode client secrets in configuration files +- **ALWAYS** use environment variable substitution +- Treat client secrets with the same security level as passwords +- Rotate client secrets regularly +- Use different client secrets for different environments + +#### `token_endpoint_auth_method` + +**Type:** String (optional) +**Default:** `client_secret_basic` + +Specifies how the client authenticates with the authorization server's token endpoint. +This follows the OAuth 2.0 specification defined in +[RFC 6749 Section 2.3](https://datatracker.ietf.org/doc/html/rfc6749#section-2.3). + +**Supported values:** + +##### `client_secret_basic` + +**Default method.** Client credentials are sent via HTTP Basic authentication. +The client ID and secret are combined as `client_id:client_secret`, base64-encoded, +and sent in the `Authorization` header. + +```hocon +"token_endpoint_auth_method": "client_secret_basic" +``` + +**When to use:** +- Most common and widely supported method +- Recommended for server-to-server communication +- Credentials are not exposed in request body or logs + +**Token request format:** + +```http +POST /token HTTP/1.1 +Host: auth.example.com +Authorization: Basic Base64(client_id:client_secret) +Content-Type: application/x-www-form-urlencoded + +grant_type=client_credentials&scope=read write +``` + +##### `client_secret_post` + +Client credentials are sent as POST parameters in the request body. 
+ +```hocon +"token_endpoint_auth_method": "client_secret_post" +``` + +**When to use:** +- When the authorization server doesn't support Basic authentication +- Required by some OAuth providers +- Less secure than `client_secret_basic` as credentials appear in request body + +**Token request format:** + +```http +POST /token HTTP/1.1 +Host: auth.example.com +Content-Type: application/x-www-form-urlencoded + +grant_type=client_credentials&client_id=xxx&client_secret=yyy&scope=read write +``` + +##### `null` + +No client authentication is performed. Used for public clients that don't have a client secret. + +```hocon +"token_endpoint_auth_method": null +``` + +**When to use:** +- Authorization servers that don't require client authentication + +**Token request format:** + +```http +POST /token HTTP/1.1 +Host: auth.example.com +Content-Type: application/x-www-form-urlencoded + +grant_type=client_credentials&client_id=xxx&scope=read write +``` + +#### `scope` + +**Type:** String (optional) +**Default:** None (server uses its default scopes) + +A space-separated list of OAuth 2.0 scopes being requested. +Scopes define the specific permissions your application is requesting. +The available scopes are defined by the authorization server and vary by service. 
+ +**Example:** + +```hocon +"scope": "read:data write:data admin:users" +``` + +**Understanding scopes:** +- Scopes represent permissions or access levels +- Request only the scopes your application needs (principle of least privilege) +- The authorization server may grant fewer scopes than requested +- Multiple scopes are separated by spaces (not commas) + +**Common scope patterns:** + +**Resource-based scopes:** + +```hocon +"scope": "read:projects write:projects delete:projects" +``` + +**Action-based scopes:** + +```hocon +"scope": "projects.read projects.write projects.delete" +``` + +**Role-based scopes:** + +```hocon +"scope": "user admin superadmin" +``` + +### `mcp_server_info` + +Configuration for the MCP server's OAuth 2.0 endpoints and authentication behavior. +This section provides additional server-specific settings that complement the client credentials. + +**Structure:** + +```hocon +"mcp_server_info": { + "token_endpoint": "string" +} +``` + +#### `token_endpoint` + +**Type:** String (optional) +**Default:** Auto-discovered or `{base_url}/token` + +The URL of the OAuth 2.0 token endpoint where the client exchanges credentials for access tokens. +This is where token requests are sent during authentication. + +**When to specify:** +- The authorization server's token endpoint is different from the standard location +- The server doesn't support endpoint discovery +- You want to explicitly control which endpoint is used + +**When to omit:** +- The server supports OAuth 2.0 discovery +- The token endpoint follows the standard pattern (`{base_url}/token`) + +**Example:** + +```hocon +"mcp_server_info": { + "token_endpoint": "https://auth.example.com/oauth/token" +} +``` + +**Discovery mechanism:** + +If `token_endpoint` is not provided, +the system attempts to discover it using the OAuth 2.0 Authorization Server Metadata discovery mechanism. 
+ +**Discovery order:** + +The system follows the +[MCP 2025-11-25 specification](https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization) +and [RFC 8414](https://www.rfc-editor.org/rfc/rfc8414.html) for endpoint discovery: + +1. **If no authorization server URL is specified in the MCP server's metadata:** + - Checks `{server_netloc}/.well-known/oauth-authorization-server` + - This is the legacy path defined in the MCP specification + +2. **If an authorization server URL is found in the MCP server's metadata:** + + **For authorization servers with a path component** (e.g., `https://auth.example.com/tenant1`): + - **Path-aware OAuth discovery ([RFC 8414 Section 3](https://www.rfc-editor.org/rfc/rfc8414.html#section-3)):** + Checks `{base_url}/.well-known/oauth-authorization-server{path}` + Example: `https://auth.example.com/.well-known/oauth-authorization-server/tenant1` + + - **Path-aware OIDC discovery ([RFC 8414 Section 5](https://www.rfc-editor.org/rfc/rfc8414.html#section-5)):** + Checks `{base_url}/.well-known/openid-configuration{path}` + Example: `https://auth.example.com/.well-known/openid-configuration/tenant1` + + - **OIDC 1.0 discovery + ([OpenID Connect Discovery 1.0](https://openid.net/specs/openid-connect-discovery-1_0.html)):** + Checks `{base_url}{path}/.well-known/openid-configuration` + Example: `https://auth.example.com/tenant1/.well-known/openid-configuration` + + **For authorization servers without a path component** (e.g., `https://auth.example.com`): + - **OAuth root discovery ([RFC 8414](https://www.rfc-editor.org/rfc/rfc8414.html)):** + Checks `{base_url}/.well-known/oauth-authorization-server` + Example: `https://auth.example.com/.well-known/oauth-authorization-server` + + - **OIDC 1.0 fallback + ([OpenID Connect Discovery 1.0](https://openid.net/specs/openid-connect-discovery-1_0.html)):** + Checks `{base_url}/.well-known/openid-configuration` + Example: `https://auth.example.com/.well-known/openid-configuration` + 
+**Token endpoint extraction:** + +Once a valid OAuth Authorization Server Metadata document is discovered: +- If the metadata contains a `token_endpoint` field, that URL is used +- If the metadata is found but doesn't contain a `token_endpoint` field, +the system falls back to `{auth_base_url}/token` + +**If all discovery attempts fail:** +- The system falls back to `{auth_base_url}/token` +- If this also fails, you must manually specify the `token_endpoint` in the `mcp_server_info` configuration + +**Note:** The discovery URLs are tried in the order listed above. +The first successful response will be used to extract the token endpoint. + +**Common endpoint patterns:** + +```hocon +# Standard OAuth 2.0 pattern +"token_endpoint": "https://auth.example.com/oauth/token" + +# Auth0 pattern +"token_endpoint": "https://tenant.auth0.com/oauth/token" + +# Okta pattern +"token_endpoint": "https://tenant.okta.com/oauth2/default/v1/token" + +# Custom authorization server +"token_endpoint": "https://auth.mycompany.com/v1/token" +``` + +### `auth_timeout` + +**Type:** Float (optional) +**Default:** 300.0 (5 minutes) +**Unit:** Seconds + +The maximum time to wait for authentication operations to complete. 
This includes: +- Token exchange requests +- Token refresh operations +- Any other authentication-related network operations +- Authorization server response time + +**Example:** + +```hocon +"auth_timeout": 180.0 # 3 minutes +``` + +**Choosing the right timeout:** + +**Short timeouts (30-60 seconds):** +- Fast, reliable networks +- High-performance authorization servers +- Synchronous workflows where quick failures are preferred + +**Medium timeouts (120-300 seconds):** +- Standard network conditions +- Most production environments +- Balance between responsiveness and reliability + +**Long timeouts (300+ seconds):** +- Unreliable networks +- Slow authorization servers +- Asynchronous workflows where retries are expensive + +**Example configurations:** + +```hocon +# Development environment - fail fast +"auth_timeout": 30.0 + +# Production environment - standard timeout +"auth_timeout": 300.0 + +# Unstable network - generous timeout +"auth_timeout": 600.0 +``` + +**Alternative configuration:** + +Instead of using the `auth_timeout` field, you can set the `AGENT_MCP_TIMEOUT_SECONDS` environment variable: + +```bash +export AGENT_MCP_TIMEOUT_SECONDS=180 +``` + +**Precedence:** The `auth_timeout` field in the configuration file takes precedence over +the environment variable if both are set. + +### `tools` + +**Type:** Array of strings (optional) +**Default:** None (all tools are available) + +An optional list of tool names to filter from the MCP server. +When specified, only the listed tools will be made available to the agent. 
This is useful for: +- Limiting the agent's capabilities for security reasons +- Reducing complexity by exposing only necessary tools +- Creating specialized agents with focused functionality +- Testing specific tools in isolation + +**Example:** + +```hocon +"tools": ["search_database", "update_record", "delete_record", "create_report"] +``` + +**Filtering behavior:** + +**If `tools` is specified:** +- Only tools in the list are made available +- Tools not in the list are hidden from the agent +- Invalid tool names are silently ignored +- Order doesn't matter + +**If `tools` is omitted or empty:** +- All tools from the server are available +- No filtering is applied + +**Tool filtering precedence:** + +Tool filtering can be specified in two places: +1. Agent network HOCON file (higher precedence) +2. MCP info configuration file (lower precedence) + +The configuration file's tool filtering is **only used if no tool filtering exists in the agent network HOCON file**. + +**Use cases:** + +**Security-focused filtering:** + +```hocon +# Only allow read operations +"tools": ["read_data", "search_data", "list_data"] +``` + +**Environment-specific filtering:** + +```hocon +# Development: all tools available +"tools": [] + +# Production: restricted to safe operations +"tools": ["read_data", "search_data", "generate_report"] +``` + +**Role-based filtering:** + +```hocon +# Admin role +"tools": ["create", "read", "update", "delete", "admin_panel"] + +# User role +"tools": ["read", "search"] +``` + +**Finding available tools:** + +To discover which tools are available from an MCP server: +1. Check the server's documentation +2. Query the server's tool catalog endpoint (if available) +3. 
Temporarily omit the `tools` filter and observe which tools are registered + +## Complete Configuration Example + +```hocon +{ + # Server 1: Bearer token authentication + "https://api.service1.com/mcp": { + "http_headers": { + "Authorization": "Bearer ${SERVICE1_TOKEN}", + }, + "auth_timeout": 60.0, + "tools": ["search", "list"] + }, + + # Server 2: Client credentials with custom token endpoint + "https://api.service2.com/mcp": { + "mcp_client_info": { + "client_id": "${SERVICE2_CLIENT_ID}", + "client_secret": "${SERVICE2_CLIENT_SECRET}", + "token_endpoint_auth_method": "client_secret_post", + "scope": "full_access" + }, + "mcp_server_info": { + "token_endpoint": "https://auth.service2.com/token" + }, + "auth_timeout": 180.0 + }, + + # Server 3: API key authentication + "https://api.service3.com/mcp": { + "http_headers": { + "X-API-Key": "${SERVICE3_API_KEY}", + "X-Client-ID": "my-app-123" + }, + "tools": ["query", "update", "delete"] + }, + + # Server 4: Public client (no secret) + "https://api.service4.com/mcp": { + "mcp_client_info": { + "client_id": "${SERVICE4_CLIENT_ID}", + "token_endpoint_auth_method": null + }, + "auth_timeout": 120.0 + } +} +``` + +## Related Documentation + +- [MCP Specification Guide](https://modelcontextprotocol.io/specification/2025-11-25/basic) +- OAuth 2.0 Specifications: + - [RFC 6749: OAuth 2.0 Authorization Framework](https://datatracker.ietf.org/doc/html/rfc6749) + - [RFC 6750: Bearer Token Usage](https://datatracker.ietf.org/doc/html/rfc6750) + - [RFC 7617: HTTP Basic Authentication](https://datatracker.ietf.org/doc/html/rfc7617) diff --git a/neuro_san/deploy/Dockerfile b/neuro_san/deploy/Dockerfile index bd5ab123b..25cff64af 100644 --- a/neuro_san/deploy/Dockerfile +++ b/neuro_san/deploy/Dockerfile @@ -315,10 +315,17 @@ ENV AGENT_EXTERNAL_RESERVATIONS_STORAGE="" # to use for cross-pod reservations storage. 
ENV AGENT_RESERVATIONS_S3_BUCKET="" -# A hocon file with MCP servers information to be used by LangChainMcpAdapter +# A hocon file with MCP clients and servers information to be used by LangChainMcpAdapter # for connecting to external MCP servers with authentication and tool filtering. +ENV AGENT_MCP_INFO_FILE="" + +# Fallback for AGENT_MCP_INFO_FILE. Deprecated in favor of AGENT_MCP_INFO_FILE, but still supported for now for backward compatibility. +# Will be removed in neuro-san==0.7 ENV MCP_SERVERS_INFO_FILE="" +# Timeout to wait for MCP server authentication response, in seconds. Default is 300 seconds (5 minutes). +ENV AGENT_MCP_TIMEOUT_SECONDS=300 + # When set, this parameter enables MCP service protocol for running neuro-san server. # Service endpoint is http://host:port/mcp # (neuro-san own http API is also enabled in this case) diff --git a/neuro_san/internals/run_context/langchain/core/base_tool_factory.py b/neuro_san/internals/run_context/langchain/core/base_tool_factory.py index b8d2cf0fb..5fbfb19c8 100644 --- a/neuro_san/internals/run_context/langchain/core/base_tool_factory.py +++ b/neuro_san/internals/run_context/langchain/core/base_tool_factory.py @@ -153,8 +153,8 @@ async def create_mcp_tool(self, mcp_info: Union[str, Dict[str, Any]]) -> List[Ba """ # By default, assume no allowed tools. This may get updated below or in the LangChainMcpAdadter. 
allowed_tools: List[str] = None - # Get HTTP headers from sly_data if available - http_headers: Dict[str, Any] = self.tool_caller.get_sly_data().get("http_headers", {}) + # Get sly_data for http headers, client info, and tokens that may be needed for MCP auth and tool retrieval + sly_data: Dict[str, Any] = self.tool_caller.get_sly_data() or {} if isinstance(mcp_info, str): server_url: str = mcp_info @@ -162,12 +162,11 @@ async def create_mcp_tool(self, mcp_info: Union[str, Dict[str, Any]]) -> List[Ba server_url = mcp_info.get("url") allowed_tools = mcp_info.get("tools") - # Get specific headers for the MCP server if available - headers: Dict[str, Any] = http_headers.get(server_url) - try: mcp_adapter = LangChainMcpAdapter() - mcp_tools: List[BaseTool] = await mcp_adapter.get_mcp_tools(server_url, allowed_tools, headers) + # Pass sly_data from tool caller since MCP auth provider may need to write tokens into it + mcp_tools: List[BaseTool] = await mcp_adapter.get_mcp_tools( + server_url, allowed_tools, sly_data) # MCP errors are nested exceptions. except ExceptionGroup as nested_exception: diff --git a/neuro_san/internals/run_context/langchain/mcp/client_credentials_oauth_provider.py b/neuro_san/internals/run_context/langchain/mcp/client_credentials_oauth_provider.py new file mode 100644 index 000000000..4cbbd3acc --- /dev/null +++ b/neuro_san/internals/run_context/langchain/mcp/client_credentials_oauth_provider.py @@ -0,0 +1,156 @@ + +# Copyright © 2023-2026 Cognizant Technology Solutions Corp, www.cognizant.com. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# +# END COPYRIGHT + +from typing import Any +from typing import Dict +from typing import Literal +from typing import Optional +from typing import override +import httpx + +from mcp.client.auth import OAuthClientProvider +from mcp.client.auth import TokenStorage +from mcp.shared.auth import OAuthClientInformationFull +from mcp.shared.auth import OAuthClientMetadata + + +class ClientCredentialsOauthProvider(OAuthClientProvider): + """OAuth provider for client_credentials grant with client_id + client_secret. + + This provider sets client_info directly, bypassing dynamic client registration. + Use this when you already have client credentials (client_id and client_secret) and there is no refresh token. + + The authentication flow proceeds as follows: + 1. When an HTTP request is sent to an MCP server, the async_auth_flow method is triggered. + 2. The provider attempts to load client information and tokens from storage. This logic is implemented in + the _initialize() method, which is overridden to set client_info directly to prevent invalid parameters. + 3. If the server responds with 401 Unauthorized, the provider attempts metadata discovery, including: + - Protected Resource Metadata + - OAuth Authorization Server Metadata + 4. Attempt to get tokens from the token endpoint in _perform_authorization() method. + Use the token_endpoint parameter if provided, otherwise use the endpoint from discovery. + 5. If the server responds with 403 Forbidden, the provider attempts to update the scopes from + the protected resource metadata. + 6. Retry the authentication flow with the new scopes or tokens.
+ + This is taken from the MCP SDK: + https://github.com/modelcontextprotocol/python-sdk/blob/main/src/mcp/client/auth/extensions/client_credentials.py + + However, the SDK's ClientCredentialsOauthProvider has two limitations: + - No user-provided endpoint support - Only uses endpoints from metadata discovery, + doesn't allow manual configuration + - Bug in client_secret_post method - The token exchange fails when using this authentication method + + Thus, we modify the code from the MCP SDK to add support for user-provided token endpoint and + fix the client_secret_post method. + + WARNING: This class overrides private methods from the SDK's OAuth implementation. + + FRAGILITY NOTICE: + - We override _initialize() and _perform_authorization() + - These are PRIVATE methods that may change without notice + - Any SDK update could break this implementation + + FUTURE WORK: + - The following GitHub issue/PR tracks making this properly extensible in the SDK + - https://github.com/modelcontextprotocol/python-sdk/issues/2121 + - https://github.com/modelcontextprotocol/python-sdk/issues/2128 + - https://github.com/modelcontextprotocol/python-sdk/pull/2140 + - If/when that lands, we should migrate to use official extension points + - If not accepted, we should build our own OAuth client + """ + + # pylint: disable=too-many-arguments + # pylint: disable=too-many-positional-arguments + def __init__( + self, + server_url: str, + storage: TokenStorage, + client_id: str, + client_secret: Optional[str], + token_endpoint: str = None, + token_endpoint_auth_method: Literal["client_secret_basic", "client_secret_post", None] = "client_secret_basic", + scopes: str | None = None, + timeout: float = 300.0 + ) -> None: + """Constructor""" + # Build minimal client_metadata for the base class + client_metadata = OAuthClientMetadata( + redirect_uris=None, + grant_types=["client_credentials"], + token_endpoint_auth_method=token_endpoint_auth_method, + scope=scopes, + ) + # There is no need 
for redirect or callback. + super().__init__(server_url, client_metadata, storage, None, None, timeout) + # Store client_info to be set during _initialize - no dynamic registration needed + # Note that client info is obtained through dynamic client registration in the MCP SDK, + # but since we are bypassing dynamic registration, we need to set client_info directly. + self._fixed_client_info = OAuthClientInformationFull( + redirect_uris=None, + client_id=client_id, + client_secret=client_secret, + grant_types=["client_credentials"], + token_endpoint_auth_method=token_endpoint_auth_method, + scope=scopes, + ) + self.token_endpoint: str = token_endpoint + + @override + async def _initialize(self) -> None: + """Load stored tokens and set pre-configured client_info.""" + self.context.current_tokens = await self.context.storage.get_tokens() + # Client info can also be loaded from the storage here, + # but we set it directly to prevent invalid or missing needed parameters. + self.context.client_info = self._fixed_client_info + self._initialized = True + + @override + async def _perform_authorization(self) -> httpx.Request: + """ + Perform client_credentials authorization. + + Note that this method originally performs user authorization and exchange code for token. + Thus, we override it to perform client credentials flow instead of user authorization and token exchange. + """ + return await self._exchange_token_client_credentials() + + async def _exchange_token_client_credentials(self) -> httpx.Request: + """ + Build token exchange request for client_credentials grant. + + A helper method to build the token request for client credentials flow, which is used in _perform_authorization. 
+ """ + token_data: Dict[str, Any] = { + "grant_type": "client_credentials", + "client_id": self.context.client_info.client_id, + } + + headers: Dict[str, str] = {"Content-Type": "application/x-www-form-urlencoded"} + + # Use standard auth methods (client_secret_basic, client_secret_post, none) + token_data, headers = self.context.prepare_token_auth(token_data, headers) + + if self.context.should_include_resource_param(self.context.protocol_version): + token_data["resource"] = self.context.get_resource_url() + + if self.context.client_metadata.scope: + token_data["scope"] = self.context.client_metadata.scope + + # Determine token endpoint URL: use provided token_endpoint if available, otherwise use endpoint from discovery + token_url: str = self.token_endpoint or self._get_token_endpoint() + return httpx.Request("POST", token_url, data=token_data, headers=headers) diff --git a/neuro_san/internals/run_context/langchain/mcp/langchain_mcp_adapter.py b/neuro_san/internals/run_context/langchain/mcp/langchain_mcp_adapter.py index a1a38138b..b94f6226e 100644 --- a/neuro_san/internals/run_context/langchain/mcp/langchain_mcp_adapter.py +++ b/neuro_san/internals/run_context/langchain/mcp/langchain_mcp_adapter.py @@ -15,6 +15,8 @@ # # END COPYRIGHT +import os +import threading from typing import Any from typing import Dict from typing import List @@ -23,21 +25,34 @@ import copy from logging import Logger from logging import getLogger -import threading +from httpx import Auth from langchain_core.tools import BaseTool from langchain_mcp_adapters.client import MultiServerMCPClient -from neuro_san.internals.run_context.langchain.mcp.mcp_servers_info_restorer import McpServersInfoRestorer +from mcp.client.auth.exceptions import OAuthFlowError +from mcp.client.auth.exceptions import OAuthTokenError + + +from neuro_san.internals.run_context.langchain.mcp.mcp_info_restorer import McpInfoRestorer +from neuro_san.internals.run_context.langchain.mcp.oauth_provider_factory import 
OauthProviderFactory +from neuro_san.internals.run_context.langchain.mcp.sly_data_token_storage import SlyDataTokenStorage + +MCP_AUTH_TIMEOUT = float(os.getenv("AGENT_MCP_TIMEOUT_SECONDS", "300.0")) # Default to 5 minutes if not set class LangChainMcpAdapter: """ Adapter class to fetch tools from a Model Context Protocol (MCP) server and return them as LangChain-compatible tools. This class provides static methods for interacting with MCP servers. + + Features: + - Automatic OAuth authentication with multiple flow support + - Tool filtering based on allowed lists + - MCP client and server configuration management """ _mcp_info_lock: threading.Lock = threading.Lock() - _mcp_servers_info: Dict[str, Any] = None + _mcp_info: Dict[str, Any] = None def __init__(self): """ @@ -47,77 +62,286 @@ def __init__(self): """ self.logger: Logger = getLogger(self.__class__.__name__) @staticmethod - def _load_mcp_servers_info(): + async def _load_mcp_info(): """ - Loads MCP servers information from a configuration file if not already loaded. + Loads MCP clients and servers information from a configuration file if not already loaded. """ with LangChainMcpAdapter._mcp_info_lock: - if LangChainMcpAdapter._mcp_servers_info is None: - LangChainMcpAdapter._mcp_servers_info = McpServersInfoRestorer().restore() - if LangChainMcpAdapter._mcp_servers_info is None: + if LangChainMcpAdapter._mcp_info is None: + LangChainMcpAdapter._mcp_info = McpInfoRestorer().restore() + if LangChainMcpAdapter._mcp_info is None: # Something went wrong reading the file. # Prevent further attempts to load info. - LangChainMcpAdapter._mcp_servers_info = {} + LangChainMcpAdapter._mcp_info = {} - async def get_mcp_tools( + async def _get_mcp_info(self, server_url: str) -> Dict[str, Any]: + """ + Get client and server configuration from cached info. + + :param server_url: The MCP server URL to look up configuration for. + + :return: Server configuration dictionary (empty dict if not found).
+ """ + if self._mcp_info is None: + await self._load_mcp_info() + return self._mcp_info.get(server_url, {}) + + async def _prepare_headers( + self, + server_url: str, + headers: Optional[Dict[str, Any]] + ) -> Optional[Dict[str, Any]]: + """ + Prepare headers for MCP request. + + Priority: explicitly provided headers > server config headers + + :param server_url: The MCP server URL to get configuration for. + :param headers: Explicitly provided headers from sly data (optional). + + :return: Headers dictionary or None if no headers are available or invalid. + """ + # Use provided headers, fallback to server config + mcp_info: Dict[str, Any] = await self._get_mcp_info(server_url) + final_headers: Dict[str, Any] = headers or mcp_info.get("http_headers") + + if final_headers: + if not isinstance(final_headers, dict): + self.logger.error( + "MCP client headers for server %s must be a dictionary, got %s", + server_url, + type(final_headers).__name__ + ) + return None + # Return a copy to avoid modifying the original + return copy.copy(final_headers) + + return None + + async def _prepare_client_info(self, server_url: str, client_info: Optional[Dict[str, Any]]) -> Dict[str, Any]: + """ + Prepare client info for MCP authentication. + + Priority: client info from sly data > client info from environment variable + + :param server_url: The MCP server URL to get configuration for. + :param client_info: Client info from sly data (optional). + + :return: Client info dictionary or an empty dict if no client info are available or invalid. 
+ """ + # Use sly data client info, fallback to the MCP info configuration file + mcp_info: Dict[str, Any] = await self._get_mcp_info(server_url) + final_client_info: Dict[str, Any] = client_info or mcp_info.get("mcp_client_info") + + if final_client_info: + if not isinstance(final_client_info, dict): + self.logger.error( + "MCP client info for server %s must be a dictionary, got %s", + server_url, + type(final_client_info).__name__ + ) + return {} + # Return a copy to avoid modifying the original + return copy.copy(final_client_info) + + return {} + + async def _prepare_token( self, server_url: str, - allowed_tools: Optional[List[str]] = None, - headers: Optional[Dict[str, Any]] = None + sly_data: Dict[str, Any] + ) -> Dict[str, Any]: + """ + Prepare token for MCP authentication. sly_data is needed as a reference to store the generated token. + Only allow tokens that are explicitly configured in sly_data. + + :param server_url: The MCP server URL to get configuration for. + :param sly_data: The sly data dictionary to use for token lookup and storage. + + :return: Token dictionary, or an empty dict if no token is available or the stored token is invalid. + """ + # Ensure tokens dict exists in sly_data + sly_data_tokens: Dict[str, Any] = sly_data.setdefault("mcp_tokens", {}) + + # Get the token for this server from sly_data only + sly_data_token: Dict[str, Any] = sly_data_tokens.get(server_url) + + if sly_data_token: + if not isinstance(sly_data_token, dict): + self.logger.error( + "Token for server %s must be a dictionary, got %s", + server_url, + type(sly_data_token).__name__ + ) + sly_data_tokens[server_url] = {} + + # Ensure server has an entry (either existing or empty dict) + return sly_data_tokens.setdefault(server_url, {}) + + async def _create_oauth_provider( + self, server_url: str, client_info: Dict[str, Any], token: Dict[str, Any]) -> OauthProviderFactory: + """ + Create and configure OAuth provider for the server. + + :param server_url: The MCP server URL to create OAuth provider for.
+ + :return: Configured OAuth provider factory instance. + """ + mcp_info: Dict[str, Any] = await self._get_mcp_info(server_url) + server_info: Dict[str, Any] = mcp_info.get("mcp_server_info", {}) + + # Prepare token storage + storage = SlyDataTokenStorage(client_info, token) + + # Get OAuth endpoints from server config (optional - will be discovered if not provided) + token_endpoint: str = server_info.get("token_endpoint") + + # Get timeout for OAuth flows from server config or use default + timeout: float = mcp_info.get("auth_timeout", MCP_AUTH_TIMEOUT) + + return OauthProviderFactory( + server_url=server_url, + storage=storage, + token_endpoint=token_endpoint, + timeout=timeout + ) + + async def _determine_allowed_tools( + self, + server_url: str, + allowed_tools: Optional[List[str]] + ) -> List[str]: + """ + Determine which tools are allowed. + + Priority: explicitly provided allowed_tools > server config tools > all tools + + :param server_url: The MCP server URL to get tool configuration for. + :param allowed_tools: Explicitly provided allowed tools list (optional). + + :return: List of allowed tool names (empty list means all tools allowed). + """ + if allowed_tools is not None: + return allowed_tools + + # Fallback to server config + server_info = await self._get_mcp_info(server_url) + return server_info.get("tools", []) + + async def _filter_and_tag_tools( + self, + tools: List[BaseTool], + allowed_tools: List[str] ) -> List[BaseTool]: """ - Fetches tools from the given MCP server and returns them as a list of LangChain-compatible tools. + Filter tools based on allowed list and add langchain_tool tags. + + :param tools: List of all tools from MCP server. + :param allowed_tools: List of allowed tool names (empty list = allow all). + + :return: Filtered list of tools with tags added. 
+ """ + # Filter if allowed_tools is not empty + if allowed_tools: + filtered_tools: List[BaseTool] = [] + for tool in tools: + if tool.name in allowed_tools: + filtered_tools.append(tool) + tools = filtered_tools + + # Add tags to all tools + for tool in tools: + # Add "langchain_tool" tag so journal callback can identify it + # These MCP tools are treated as LangChain tools and can be reported in the thinking file + tool.tags = ["langchain_tool"] + + return tools + + async def _prepare_auth(self, server_url: str, sly_data: Dict[str, Any]) -> Optional[Auth]: + """ + Prepare auth provider for MCP server if authentication is needed. + + :param server_url: URL of the MCP server, e.g. https://mcp.deepwiki.com/mcp or http://localhost:8000/mcp/ + :param sly_data: Optional dictionary of sly data (client info, token, headers) to use for MCP requests. + + :return: An Auth object for MCP authentication if needed, otherwise None. + """ + # Get client info from env var if not available in sly data + sly_data_client_info: Dict[str, Any] = sly_data.get("mcp_client_info", {}).get(server_url) + client_info: Dict[str, Any] = await self._prepare_client_info(server_url, sly_data_client_info) + token: Dict[str, Any] = {} + if client_info: + # If there is client info, ensure that there is "mcp_tokens" in the sly data + # for the provider to use and update, otherwise just pass an empty dict. + token = await self._prepare_token(server_url, sly_data) + # Create and configure OAuth provider + provider: OauthProviderFactory = await self._create_oauth_provider(server_url, client_info, token) + + return await provider.get_auth() + + async def get_mcp_tools( + self, + server_url: str, + allowed_tools: Optional[List[str]], + sly_data: Dict[str, Any] + ) -> List[BaseTool]: + """ + Fetches tools from the given MCP server and returns them as LangChain-compatible tools. 
+ + The method handles: + - OAuth authentication (client_credentials and refresh_token flows supported) + - Putting env var values for client info and token to sly_data when available + - Tool filtering based on allowed list + - Automatic session management and cleanup :param server_url: URL of the MCP server, e.g. https://mcp.deepwiki.com/mcp or http://localhost:8000/mcp/ :param allowed_tools: Optional list of tool names to filter from the server's available tools. - If None, all tools from the server will be returned. - :param headers: Optional dictionary of HTTP headers to include in the MCP requests. + If None, uses server config tools or all tools from the server will be returned. + :param sly_data: Optional dictionary of sly data (client info, token, headers) to use for MCP requests. :return: A list of LangChain BaseTool instances retrieved from the MCP server. """ - if self._mcp_servers_info is None: - self._load_mcp_servers_info() - + # Prepare MCP tool configuration mcp_tool_dict: Dict[str, Any] = { "url": server_url, "transport": "streamable_http", } - # Try to look up authentication details first from the sly data then from the MCP servers info. - headers_dict: Dict[str, Any] = headers or self._mcp_servers_info.get(server_url, {}).get("http_headers") - if headers_dict: - if isinstance(headers_dict, dict): - # Use a copy to avoid modifying the original headers dictionary. - mcp_tool_dict["headers"] = copy.copy(headers_dict) - else: - self.logger.error("MCP client headers for server %s must be a dictionary.", server_url) - - client = MultiServerMCPClient( - {"server": mcp_tool_dict} - ) - # The get_tools() method returns a list of StructuredTool instances, which are subclasses of BaseTool. - # Internally, it calls load_mcp_tools(), which uses an `async with create_session(...)` block. - # This guarantees that any temporary MCP session created is properly closed when the block exits, - # even if an error is raised during tool loading. 
- # See: https://github.com/langchain-ai/langchain-mcp-adapters/blob/main/langchain_mcp_adapters/tools.py#L164 - # Optimization: - # It's possible we might want to cache these results somehow to minimize tool calls. - mcp_tools: List[BaseTool] = await client.get_tools() - - # If allowed_tools is provided, filter the list to include only those tools. - client_allowed_tools: List[str] = allowed_tools - if client_allowed_tools is None: - # Check if MCP server info has a "tools" field to use as allowed tools. - client_allowed_tools = self._mcp_servers_info.get(server_url, {}).get("tools", []) - # If client allowed tools is an empty list, do not filter the tools. - if client_allowed_tools: - mcp_tools = [tool for tool in mcp_tools if tool.name in client_allowed_tools] + # Add headers if available + sly_data_headers: Dict[str, Any] = sly_data.get("http_headers", {}).get(server_url) + # Prepare headers by prioritizing sly data over MCP server info file and validating format + prepared_headers = await self._prepare_headers(server_url, sly_data_headers) + if prepared_headers: + mcp_tool_dict["headers"] = prepared_headers + + # Add auth if needed + # Prepare auth by prioritizing sly data over env var client info and token and validating format + # and store token from env var in sly data when used + auth: Auth = await self._prepare_auth(server_url, sly_data) + if auth: + mcp_tool_dict["auth"] = auth + + try: + # Create MCP client + client = MultiServerMCPClient({"server": mcp_tool_dict}) + # Fetch tools from server + # The get_tools() method uses `async with create_session(...)` internally, + # which guarantees proper session cleanup even if errors occur. 
+ # See: https://github.com/langchain-ai/langchain-mcp-adapters/blob/main/langchain_mcp_adapters/tools.py#L164 + mcp_tools: List[BaseTool] = await client.get_tools() + + except (OAuthFlowError, OAuthTokenError) as auth_error: + self.logger.error("Authentication failed for MCP server %s: %s", server_url, auth_error, exc_info=True) + mcp_tools = [] + + # Determine which tools are allowed + client_allowed_tools = await self._determine_allowed_tools(server_url, allowed_tools) + + # Store for instance reference self.client_allowed_tools = client_allowed_tools - for tool in mcp_tools: - # Add "langchain_tool" tags so journal callback can idenitify it. - # These MCP tools are treated as Langchain tools and can be reported in the thinking file. - tool.tags = ["langchain_tool"] + # Filter and tag tools + mcp_tools = await self._filter_and_tag_tools(mcp_tools, client_allowed_tools) return mcp_tools diff --git a/neuro_san/internals/run_context/langchain/mcp/mcp_servers_info_restorer.py b/neuro_san/internals/run_context/langchain/mcp/mcp_info_restorer.py similarity index 76% rename from neuro_san/internals/run_context/langchain/mcp/mcp_servers_info_restorer.py rename to neuro_san/internals/run_context/langchain/mcp/mcp_info_restorer.py index 581fa7563..0cfa401e5 100644 --- a/neuro_san/internals/run_context/langchain/mcp/mcp_servers_info_restorer.py +++ b/neuro_san/internals/run_context/langchain/mcp/mcp_info_restorer.py @@ -21,22 +21,25 @@ from neuro_san.internals.persistence.abstract_async_config_restorer import AbstractAsyncConfigRestorer -class McpServersInfoRestorer(AbstractAsyncConfigRestorer): +class McpInfoRestorer(AbstractAsyncConfigRestorer): """ - Implementation of the AbstractAsyncConfigRestorer that reads the MCP servers info file. + Implementation of the AbstractAsyncConfigRestorer that reads the MCP info file. The restore() and async_restore() methods both return a dictionary. 
- - NOTE: This class is highly experimental and implementation of MCP servers - is very likely to change in future releases. """ def __init__(self): - super().__init__(file_purpose="MCP servers info", env_var="MCP_SERVERS_INFO_FILE") + super().__init__( + file_purpose="MCP servers info", + env_var="AGENT_MCP_INFO_FILE", + deprecated_env_var="MCP_SERVERS_INFO_FILE", + # Only necessary if authentication is required. + must_exist=False, + ) def filter_config(self, basis_config: Dict[str, Any], file_path: str = None) -> Dict[str, Any]: """ - :param basis_config: A dictionary with MCP servers information - :param file_path: The path to the MCP servers info file + :param basis_config: A dictionary with MCP headers, clients, and servers information + :param file_path: The path to the MCP info file :return: a dictionary with MCP servers information """ diff --git a/neuro_san/internals/run_context/langchain/mcp/oauth_provider_factory.py b/neuro_san/internals/run_context/langchain/mcp/oauth_provider_factory.py new file mode 100644 index 000000000..df970740e --- /dev/null +++ b/neuro_san/internals/run_context/langchain/mcp/oauth_provider_factory.py @@ -0,0 +1,105 @@ + +# Copyright © 2023-2026 Cognizant Technology Solutions Corp, www.cognizant.com. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# END COPYRIGHT + +from logging import Logger +from logging import getLogger +from typing import Any +from typing import Dict +from typing import Optional + +from httpx import Auth +from mcp.client.auth import TokenStorage + +from neuro_san.internals.run_context.langchain.mcp.client_credentials_oauth_provider import \ + ClientCredentialsOauthProvider +from neuro_san.internals.run_context.langchain.mcp.refresh_token_oauth_provider import RefreshTokenOauthProvider + + +# pylint: disable=too-many-instance-attributes +class OauthProviderFactory: + """ + Factory for creating OAuth providers based on stored credentials and configuration. + Supports machine-to-machine authentication flows: client credentials and refresh token. + """ + # pylint: disable=too-many-arguments + # pylint: disable=too-many-positional-arguments + def __init__( + self, + server_url: str, + storage: TokenStorage, + token_endpoint: Optional[str] = None, + timeout: float = 300.0 + ): + """ + Initialize OAuth provider factory. + """ + self.logger: Logger = getLogger(self.__class__.__name__) + self.server_url = server_url + self.storage = storage + self.token_endpoint = token_endpoint + self.timeout = timeout + + async def get_auth(self) -> Optional[Auth]: + """ + Get appropriate OAuth provider based on stored credentials and tokens. 
+ + Flow implementation: + - Refresh token flow (if refresh token is available) + - Client credentials flow (no refresh token, but client credentials are available) + """ + credentials: Dict[str, Any] = await self.storage.get_client_info_dict() + tokens: Dict[str, Any] = await self.storage.get_tokens_dict() + refresh_token: str = tokens.get("refresh_token") + + if credentials: + # If there is a refresh token, prioritize refresh token flow to reuse existing authorization + if refresh_token: + return await self._create_refresh_token_provider(credentials) + return await self._create_client_credentials_provider(credentials) + + # No client credentials, no auth provider can be created + return None + + async def _create_client_credentials_provider(self, credentials: Dict[str, Any]) -> Auth: + """Create client credentials OAuth provider.""" + self.logger.info("✓ Using client_credentials flow for %s", self.server_url) + + return ClientCredentialsOauthProvider( + server_url=self.server_url, + storage=self.storage, + client_id=credentials.get("client_id"), + client_secret=credentials.get("client_secret"), + token_endpoint=self.token_endpoint, + token_endpoint_auth_method=credentials.get("token_endpoint_auth_method", "client_secret_basic"), + scopes=credentials.get("scope"), + timeout=self.timeout + ) + + async def _create_refresh_token_provider(self, credentials: Dict[str, Any]) -> Auth: + """Create refresh token provider.""" + self.logger.info("✓ Using refresh token flow for %s", self.server_url) + + return RefreshTokenOauthProvider( + server_url=self.server_url, + storage=self.storage, + client_id=credentials.get("client_id"), + client_secret=credentials.get("client_secret"), + token_endpoint=self.token_endpoint, + token_endpoint_auth_method=credentials.get("token_endpoint_auth_method", "client_secret_basic"), + scopes=credentials.get("scope"), + timeout=self.timeout + ) diff --git a/neuro_san/internals/run_context/langchain/mcp/refresh_token_oauth_provider.py 
class RefreshTokenOauthProvider(OAuthClientProvider):
    """OAuth provider for the refresh token grant.

    A provider created specifically for machine-to-machine refresh token flows,
    since the SDK's refresh token logic is designed for the Authorization Code grant type.

    This provider sets client_info directly (bypassing dynamic client registration and user authorization)
    and loads the current tokens from storage.
    Use this when you already have client credentials (client_id and, optionally, client_secret) and
    a stored token that carries a refresh token field.
    In general, the refresh token flow works in tandem with the authorization code flow, but currently neuro-san
    only supports machine-to-machine flows, so the refresh token flow is implemented to work when there is
    a refresh token in sly data.

    The authentication flow proceeds as follows:
    1. When an HTTP request is sent to an MCP server, the async_auth_flow method is triggered.
    2. The provider attempts to load client information and tokens from storage. This logic is implemented in
       the _initialize() method, which is overridden to set client_info directly to prevent invalid parameters.
    3. If the server responds with 401 Unauthorized, the provider attempts metadata discovery, including:
        - Protected Resource Metadata
        - OAuth Authorization Server Metadata
    4. Attempt to get tokens from the token endpoint in the _perform_authorization() method.
       Use the token_endpoint parameter if it was provided, otherwise use the endpoint from discovery.
    5. If the server responds with 403 Forbidden, the provider attempts to update the scopes from
       the protected resource metadata.
    6. Retry the authentication flow with the new scopes or tokens.

    This is adapted from the MCP SDK:
    https://github.com/modelcontextprotocol/python-sdk/blob/main/src/mcp/client/auth/extensions/client_credentials.py
    but overrides the user authorization step to exchange the refresh token and client credentials for a new token
    when the token is expired, and adds token_endpoint as an optional parameter.

    WARNING: This class overrides private methods from the SDK's OAuth implementation.

    FRAGILITY NOTICE:
    - We override _initialize(), _perform_authorization(), and _refresh_token()
    - These are PRIVATE methods that may change without notice
    - Any SDK update could break this implementation

    FUTURE WORK:
    - The following GitHub issue/PR tracks making this properly extensible in the SDK
        - https://github.com/modelcontextprotocol/python-sdk/issues/1250
        - https://github.com/modelcontextprotocol/python-sdk/issues/1318
        - https://github.com/modelcontextprotocol/python-sdk/pull/1743
        - https://github.com/modelcontextprotocol/python-sdk/pull/1784
        - https://github.com/modelcontextprotocol/python-sdk/issues/2121
    - If/when that lands, we should migrate to use official extension points
    - If not accepted, we should build our own OAuth client
    """

    # pylint: disable=too-many-arguments
    # pylint: disable=too-many-positional-arguments
    def __init__(
        self,
        server_url: str,
        storage: TokenStorage,
        client_id: str,
        client_secret: str | None = None,  # may not require client_secret to refresh token
        token_endpoint: str | None = None,
        token_endpoint_auth_method: Literal["client_secret_basic", "client_secret_post", None] = "client_secret_basic",
        scopes: str | None = None,
        timeout: float = 300.0
    ) -> None:
        """
        Constructor

        :param server_url: URL of the MCP server, forwarded to the base OAuthClientProvider.
        :param storage: Token storage, forwarded to the base class; tokens are read from it in _initialize().
        :param client_id: OAuth client id placed in the pre-configured client info.
        :param client_secret: Optional OAuth client secret placed in the pre-configured client info.
        :param token_endpoint: Optional token endpoint URL. When set, _refresh_token() uses it
                               instead of the endpoint returned by _get_token_endpoint().
        :param token_endpoint_auth_method: Client authentication method recorded in both the client metadata
                                           and the pre-configured client info.
        :param scopes: Optional scope string recorded in both the client metadata
                       and the pre-configured client info.
        :param timeout: Timeout value forwarded to the base class.
        """
        # Build minimal client_metadata for the base class
        client_metadata = OAuthClientMetadata(
            redirect_uris=None,
            grant_types=["refresh_token"],
            token_endpoint_auth_method=token_endpoint_auth_method,
            scope=scopes,
        )
        # There is no need for redirect or callback.
        super().__init__(server_url, client_metadata, storage, None, None, timeout)
        # Store client_info to be set during _initialize - no dynamic registration needed
        # Note that client info is obtained through dynamic client registration in the MCP SDK,
        # but since we are bypassing dynamic registration, we need to set client_info directly.
        self._fixed_client_info = OAuthClientInformationFull(
            redirect_uris=None,
            client_id=client_id,
            client_secret=client_secret,
            grant_types=["refresh_token"],
            token_endpoint_auth_method=token_endpoint_auth_method,
            scope=scopes,
        )
        # None means "fall back to the endpoint returned by _get_token_endpoint()"; see _refresh_token().
        self.token_endpoint: str | None = token_endpoint

    @override
    async def _initialize(self) -> None:
        """Load stored tokens and set pre-configured client_info."""
        self.context.current_tokens = await self.context.storage.get_tokens()
        # Client info can also be loaded from the storage here,
        # but we set it directly to prevent invalid or missing needed parameters.
        self.context.client_info = self._fixed_client_info
        self._initialized = True

    @override
    async def _perform_authorization(self) -> httpx.Request:
        """
        Refresh tokens.

        Note that this method originally performs user authorization and exchanges a code for a token.
        Thus, we override it to perform the refresh token flow instead of user authorization and token exchange.

        :return: The token refresh request built by _refresh_token().
        :raises OAuthTokenError: Propagated from _refresh_token() when no refresh token
                                 or no client info is available.
        """
        return await self._refresh_token()

    @override
    async def _refresh_token(self) -> httpx.Request:
        """
        Build token refresh request.

        Overrides to add support for a user-provided token endpoint.

        :return: A POST request to the token endpoint carrying the refresh_token grant parameters.
        :raises OAuthTokenError: If there is no current refresh token, or no client info with a client_id.
        """
        if not self.context.current_tokens or not self.context.current_tokens.refresh_token:
            raise OAuthTokenError("No refresh token available")  # pragma: no cover

        if not self.context.client_info or not self.context.client_info.client_id:
            raise OAuthTokenError("No client info available")  # pragma: no cover

        # Determine token endpoint URL: use provided token_endpoint if available, otherwise use endpoint from discovery
        token_url: str = self.token_endpoint or self._get_token_endpoint()

        refresh_data: dict[str, str] = {
            "grant_type": "refresh_token",
            "refresh_token": self.context.current_tokens.refresh_token,
            "client_id": self.context.client_info.client_id,
        }

        # Only include resource param if conditions are met
        if self.context.should_include_resource_param(self.context.protocol_version):
            refresh_data["resource"] = self.context.get_resource_url()  # RFC 8707

        # Prepare authentication based on preferred method
        headers = {"Content-Type": "application/x-www-form-urlencoded"}
        refresh_data, headers = self.context.prepare_token_auth(refresh_data, headers)

        return httpx.Request("POST", token_url, data=refresh_data, headers=headers)
class SlyDataTokenStorage(TokenStorage):
    """
    Sly data-based token storage that also stores client credentials info.

    The dictionaries handed to the constructor are kept by reference and are updated in place
    by set_tokens() and set_client_info(), so whoever owns those dictionaries sees the new values.
    """

    def __init__(self, client_info: Optional[Dict[str, Any]], tokens: Optional[Dict[str, Any]]):
        """
        Constructor

        :param client_info: Dictionary of client credential fields. None is treated as an empty dictionary.
        :param tokens: Dictionary of token fields. None is treated as an empty dictionary.
        """
        self.logger: Logger = getLogger(self.__class__.__name__)

        # Compare against None (not truthiness) so that an empty dictionary supplied by the caller
        # is still kept by reference and receives the in-place updates done by the setters.
        # Without the None fallback, the setters would fail on .clear() and silently drop the data.
        self.client_info: Dict[str, Any] = client_info if client_info is not None else {}
        self.tokens: Dict[str, Any] = tokens if tokens is not None else {}

    @override
    async def get_tokens(self) -> Optional[OAuthToken]:
        """
        Load OAuth tokens from sly data.

        :return: An OAuthToken built from the stored dictionary, or None when the dictionary is empty
                 or cannot be turned into an OAuthToken (a warning is logged in that case).
        """
        if not self.tokens:
            return None

        try:
            return OAuthToken(**self.tokens)

        except (ValidationError, TypeError, ValueError) as errors:
            self.logger.warning("Failed to load token from sly data: %s", errors)
            return None

    async def get_tokens_dict(self) -> Dict[str, Any]:
        """Get raw token dictionary."""
        return self.tokens

    @override
    async def set_tokens(self, tokens: OAuthToken) -> None:
        """
        Save OAuth tokens to a dictionary in sly data.

        :param tokens: The tokens to store; failures to store are logged, not raised.
        """
        try:
            # Clear and update token in-place
            self.tokens.clear()
            self.tokens.update(tokens.model_dump(mode="json"))
            self.logger.info("Tokens saved (expires in %s s)", tokens.expires_in)
        except (AttributeError, TypeError) as errors:
            self.logger.error("Failed to save tokens in sly data: %s", errors)

    @override
    async def get_client_info(self) -> Optional[OAuthClientInformationFull]:
        """
        Load client information from sly data.

        Note:
        - `OAuthClientInformationFull` requires `redirect_uris` to be provided as a list of
        valid URIs.
        - This method is intended for use with **dynamic client registration**, where
        client metadata (e.g., client ID, client secret, redirect URIs) is loaded
        at runtime rather than being hardcoded.

        :return: An OAuthClientInformationFull built from the stored dictionary, or None when the
                 dictionary is empty or cannot be turned into one (a warning is logged in that case).
        """
        if not self.client_info:
            return None

        try:
            return OAuthClientInformationFull(**self.client_info)
        # Same exception set as get_tokens(): a malformed dictionary (e.g. not a mapping, non-string keys)
        # raises TypeError from the ** unpacking rather than a ValidationError.
        except (ValidationError, TypeError, ValueError) as validation_error:
            self.logger.warning("Failed to instantiate OAuthClientInformationFull: %s", validation_error)
            return None

    async def get_client_info_dict(self) -> Dict[str, Any]:
        """Get raw client info dictionary."""
        return self.client_info

    @override
    async def set_client_info(self, client_info: OAuthClientInformationFull) -> None:
        """
        Save client information to sly data.

        :param client_info: The client information to store; failures to store are logged, not raised.
        """
        try:
            # Clear and update client info in-place
            self.client_info.clear()
            self.client_info.update(client_info.model_dump(mode="json"))
            self.logger.info("Client registered with ID: %s", client_info.client_id)
        except (AttributeError, TypeError) as errors:
            self.logger.error("Failed to save client info in sly data: %s", errors)
+# https://github.com/modelcontextprotocol/python-sdk/releases/tag/v1.25.0 +mcp>=1.23.0,<2.0 diff --git a/tests/neuro_san/internals/run_context/langchain/mcp/test_client_credentials_oauth_provider.py b/tests/neuro_san/internals/run_context/langchain/mcp/test_client_credentials_oauth_provider.py new file mode 100644 index 000000000..b10842875 --- /dev/null +++ b/tests/neuro_san/internals/run_context/langchain/mcp/test_client_credentials_oauth_provider.py @@ -0,0 +1,296 @@ +# Copyright © 2023-2026 Cognizant Technology Solutions Corp, www.cognizant.com. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
class TestClientCredentialsOauthProvider:
    """Test suite for ClientCredentialsOauthProvider class"""

    @pytest.fixture
    def mock_storage(self):
        """Create mock token storage whose get_tokens() is awaitable and returns None"""
        storage = MagicMock()
        storage.get_tokens = AsyncMock(return_value=None)
        return storage

    @pytest.fixture
    def provider(self, mock_storage):
        """Create provider instance with basic credentials"""
        return ClientCredentialsOauthProvider(
            server_url="https://mcp.example.com/mcp",
            storage=mock_storage,
            client_id="test_client_id",
            client_secret="test_client_secret"
        )

    # pylint: disable=protected-access
    def test_init_with_minimal_params(self, mock_storage):
        """Test initialization with minimal parameters"""
        provider = ClientCredentialsOauthProvider(
            server_url="https://mcp.example.com/mcp",
            storage=mock_storage,
            client_id="test_id",
            client_secret="test_secret"
        )

        assert provider._fixed_client_info.client_id == "test_id"
        assert provider._fixed_client_info.client_secret == "test_secret"
        assert provider._fixed_client_info.grant_types == ["client_credentials"]
        assert provider._fixed_client_info.token_endpoint_auth_method == "client_secret_basic"
        assert provider._fixed_client_info.scope is None
        assert provider.token_endpoint is None
        assert provider.context.timeout == 300.0

    def test_init_with_all_params(self, mock_storage):
        """Test initialization with all parameters"""
        provider = ClientCredentialsOauthProvider(
            server_url="https://mcp.example.com/mcp",
            storage=mock_storage,
            client_id="test_id",
            client_secret="test_secret",
            token_endpoint="https://auth.example.com/token",
            token_endpoint_auth_method="client_secret_post",
            scopes="read write",
            timeout=600.0
        )

        assert provider.token_endpoint == "https://auth.example.com/token"
        assert provider._fixed_client_info.token_endpoint_auth_method == "client_secret_post"
        assert provider._fixed_client_info.scope == "read write"
        assert provider.context.timeout == 600.0

    def test_init_with_none_auth_method(self, mock_storage):
        """Test initialization with None auth method"""
        provider = ClientCredentialsOauthProvider(
            server_url="https://mcp.example.com/mcp",
            storage=mock_storage,
            client_id="test_id",
            client_secret="test_secret",
            token_endpoint_auth_method=None
        )

        assert provider._fixed_client_info.token_endpoint_auth_method is None

    def test_fixed_client_info_structure(self, provider):
        """Test that fixed client info has correct structure"""
        client_info = provider._fixed_client_info

        assert client_info.redirect_uris is None
        assert client_info.grant_types == ["client_credentials"]
        # Compare against the fixture's values; hasattr() is always true for declared model fields.
        assert client_info.client_id == "test_client_id"
        assert client_info.client_secret == "test_client_secret"

    def test_client_metadata_structure(self, provider):
        """Test that client metadata is correctly configured"""
        # Access through the parent class context
        assert provider.context.client_metadata.grant_types == ["client_credentials"]
        assert provider.context.client_metadata.redirect_uris is None

    @pytest.mark.asyncio
    async def test_initialize_loads_tokens(self, provider, mock_storage):
        """Test that _initialize loads tokens from storage"""
        mock_token = MagicMock()
        mock_storage.get_tokens.return_value = mock_token

        await provider._initialize()

        assert provider.context.current_tokens == mock_token
        mock_storage.get_tokens.assert_called_once()

    @pytest.mark.asyncio
    async def test_initialize_sets_client_info(self, provider):
        """Test that _initialize sets fixed client info"""
        await provider._initialize()

        assert provider.context.client_info == provider._fixed_client_info
        assert provider._initialized is True

    @pytest.mark.asyncio
    async def test_initialize_does_not_call_dynamic_registration(self, provider, mock_storage):
        """Test that _initialize bypasses dynamic client registration"""
        await provider._initialize()

        # Client info should be set directly, not loaded from storage
        assert provider.context.client_info == provider._fixed_client_info
        # get_tokens should be called, but not any client registration methods
        mock_storage.get_tokens.assert_called_once()
        mock_storage.get_client_info.assert_not_called()
        mock_storage.set_client_info.assert_not_called()

    @pytest.mark.asyncio
    @patch.object(ClientCredentialsOauthProvider, '_exchange_token_client_credentials')
    async def test_perform_authorization_calls_exchange_token(self, mock_exchange, provider):
        """Test that _perform_authorization calls token exchange"""
        mock_request = MagicMock()
        mock_exchange.return_value = mock_request

        result = await provider._perform_authorization()

        assert result == mock_request
        mock_exchange.assert_called_once()

    @pytest.mark.asyncio
    async def test_exchange_token_client_credentials_basic_request(self, provider):
        """Test that token exchange creates correct request"""
        await provider._initialize()

        request = await provider._exchange_token_client_credentials()

        assert request.method == "POST"
        assert "grant_type" in str(request.content)
        assert "client_credentials" in str(request.content)
        assert "client_id" in str(request.content)

    @pytest.mark.asyncio
    async def test_exchange_token_uses_client_secret_basic_auth(self, mock_storage):
        """Test that client_secret_basic method is used correctly"""
        provider = ClientCredentialsOauthProvider(
            server_url="https://mcp.example.com/mcp",
            storage=mock_storage,
            client_id="test_id",
            client_secret="test_secret",
            token_endpoint_auth_method="client_secret_basic"
        )
        await provider._initialize()

        request = await provider._exchange_token_client_credentials()

        # client_secret_basic should use the Authorization header with the Basic scheme.
        # httpx headers are case-insensitive, so a single lookup covers every spelling.
        assert "Authorization" in request.headers
        assert request.headers["Authorization"].startswith("Basic ")

    @pytest.mark.asyncio
    async def test_exchange_token_uses_client_secret_post(self, mock_storage):
        """Test that client_secret_post method includes secret in body"""
        provider = ClientCredentialsOauthProvider(
            server_url="https://mcp.example.com/mcp",
            storage=mock_storage,
            client_id="test_id",
            client_secret="test_secret",
            token_endpoint_auth_method="client_secret_post"
        )
        await provider._initialize()

        request = await provider._exchange_token_client_credentials()

        # client_secret_post should include client_secret in request body
        request_body = str(request.content)
        assert request.method == "POST"
        assert "client_secret=test_secret" in request_body

    @pytest.mark.asyncio
    async def test_exchange_token_includes_scope_when_provided(self, mock_storage):
        """Test that scope is included in token request when provided"""
        provider = ClientCredentialsOauthProvider(
            server_url="https://mcp.example.com/mcp",
            storage=mock_storage,
            client_id="test_id",
            client_secret="test_secret",
            scopes="read write admin"
        )
        await provider._initialize()

        request = await provider._exchange_token_client_credentials()

        request_body = str(request.content)
        assert "scope=" in request_body

    @pytest.mark.asyncio
    async def test_exchange_token_excludes_scope_when_none(self, mock_storage):
        """Test that scope is excluded when not provided"""
        provider = ClientCredentialsOauthProvider(
            server_url="https://mcp.example.com/mcp",
            storage=mock_storage,
            client_id="test_id",
            client_secret="test_secret",
            scopes=None
        )
        await provider._initialize()

        request = await provider._exchange_token_client_credentials()

        # Scope should not be in request when None
        assert request.method == "POST"
        assert "scope" not in str(request.content)

    @pytest.mark.asyncio
    async def test_exchange_token_uses_provided_endpoint(self, mock_storage):
        """Test that provided token_endpoint is used"""
        provider = ClientCredentialsOauthProvider(
            server_url="https://mcp.example.com/mcp",
            storage=mock_storage,
            client_id="test_id",
            client_secret="test_secret",
            token_endpoint="https://custom.auth.com/oauth/token"
        )
        await provider._initialize()

        request = await provider._exchange_token_client_credentials()

        assert str(request.url) == "https://custom.auth.com/oauth/token"

    @pytest.mark.asyncio
    async def test_exchange_token_content_type_header(self, provider):
        """Test that correct Content-Type header is set"""
        await provider._initialize()

        request = await provider._exchange_token_client_credentials()

        assert request.headers.get("Content-Type") == "application/x-www-form-urlencoded"

    @pytest.mark.asyncio
    @patch.object(ClientCredentialsOauthProvider, '_get_token_endpoint')
    async def test_exchange_token_prefers_provided_endpoint(self, mock_get_endpoint, mock_storage):
        """Test that provided endpoint is preferred over discovered endpoint"""
        provider = ClientCredentialsOauthProvider(
            server_url="https://mcp.example.com/mcp",
            storage=mock_storage,
            client_id="test_id",
            client_secret="test_secret",
            token_endpoint="https://fallback.auth.com/token"
        )
        await provider._initialize()

        # Mock that discovery found an endpoint
        mock_get_endpoint.return_value = "https://discovered.auth.com/token"

        request = await provider._exchange_token_client_credentials()

        # Should use the provided one, not the discovered one
        assert str(request.url) == "https://fallback.auth.com/token"

    @pytest.mark.asyncio
    @patch.object(ClientCredentialsOauthProvider, '_get_token_endpoint')
    async def test_exchange_token_falls_back_to_discovered_endpoint(self, mock_get_endpoint, mock_storage):
        """Test fallback to discovered endpoint when provided endpoint is None"""
        provider = ClientCredentialsOauthProvider(
            server_url="https://mcp.example.com/mcp",
            storage=mock_storage,
            client_id="test_id",
            client_secret="test_secret",
            token_endpoint=None
        )
        await provider._initialize()

        mock_get_endpoint.return_value = "https://mcp.example.com/token"

        request = await provider._exchange_token_client_credentials()

        # Should use the MCP token endpoint as fallback
        assert str(request.url) == "https://mcp.example.com/token"
mock_mcp_tool): + async def test_get_mcp_tools_basic(self, mock_client_class, mock_servers_restorer_class, + mock_oauth_factory_class, adapter, mock_mcp_tool): """Test basic retrieval of MCP tools""" + # Setup restorer mocks + mock_servers_restorer_class.return_value.restore.return_value = {} + + # Setup oauth mock to return None (no auth needed) + mock_oauth_factory = mock_oauth_factory_class.return_value + mock_oauth_factory.get_auth = AsyncMock(return_value=None) + mock_client = mock_client_class.return_value mock_client.get_tools = AsyncMock(return_value=[mock_mcp_tool]) server_url = "https://mcp.example.com/mcp" - tools = await adapter.get_mcp_tools(server_url) + tools = await adapter.get_mcp_tools(server_url, allowed_tools=None, sly_data={}) + + # Verify auth was not added to client config + call_args = mock_client_class.call_args[0][0] + assert "auth" not in call_args["server"] assert len(tools) == 1 assert tools[0].name == "test_tool" @@ -71,11 +87,21 @@ async def test_get_mcp_tools_basic(self, mock_client_class, adapter, mock_mcp_to mock_client.get_tools.assert_called_once() @pytest.mark.asyncio + @patch('neuro_san.internals.run_context.langchain.mcp.langchain_mcp_adapter.OauthProviderFactory') + @patch('neuro_san.internals.run_context.langchain.mcp.langchain_mcp_adapter.McpInfoRestorer') @patch('neuro_san.internals.run_context.langchain.mcp.langchain_mcp_adapter.MultiServerMCPClient') async def test_get_mcp_tools_with_allowed_tools_param( - self, mock_client_class, adapter + self, mock_client_class, mock_servers_restorer_class, + mock_oauth_factory_class, adapter ): """Test filtering tools with allowed_tools parameter""" + # Setup restorer mocks + mock_servers_restorer_class.return_value.restore.return_value = {} + + # Setup oauth mock to return None (no auth needed) + mock_oauth_factory = mock_oauth_factory_class.return_value + mock_oauth_factory.get_auth = AsyncMock(return_value=None) + tool1 = MagicMock(spec=StructuredTool) tool1.name = "allowed_tool" 
tool1.tags = [] @@ -89,27 +115,32 @@ async def test_get_mcp_tools_with_allowed_tools_param( server_url = "https://mcp.example.com/mcp" allowed_tools = ["allowed_tool"] - tools = await adapter.get_mcp_tools(server_url, allowed_tools=allowed_tools) + tools = await adapter.get_mcp_tools(server_url, allowed_tools=allowed_tools, sly_data={}) assert len(tools) == 1 assert tools[0].name == "allowed_tool" assert adapter.client_allowed_tools == allowed_tools @pytest.mark.asyncio - @patch('neuro_san.internals.run_context.langchain.mcp.langchain_mcp_adapter.McpServersInfoRestorer') + @patch('neuro_san.internals.run_context.langchain.mcp.langchain_mcp_adapter.OauthProviderFactory') + @patch('neuro_san.internals.run_context.langchain.mcp.langchain_mcp_adapter.McpInfoRestorer') @patch('neuro_san.internals.run_context.langchain.mcp.langchain_mcp_adapter.MultiServerMCPClient') async def test_get_mcp_tools_with_config_allowed_tools( - self, mock_client_class, mock_restorer_class, adapter + self, mock_client_class, mock_servers_restorer_class, + mock_oauth_factory_class, adapter ): """Test filtering tools with allowed_tools from config""" server_url = "https://mcp.example.com/mcp" - mock_restorer = mock_restorer_class.return_value - mock_restorer.restore.return_value = { + mock_servers_restorer_class.return_value.restore.return_value = { server_url: { "tools": ["config_tool"] } } + # Setup oauth mock to return None (no auth needed) + mock_oauth_factory = mock_oauth_factory_class.return_value + mock_oauth_factory.get_auth = AsyncMock(return_value=None) + tool1 = MagicMock(spec=StructuredTool) tool1.name = "config_tool" tool1.tags = [] @@ -121,85 +152,113 @@ async def test_get_mcp_tools_with_config_allowed_tools( mock_client = mock_client_class.return_value mock_client.get_tools = AsyncMock(return_value=[tool1, tool2]) - tools = await adapter.get_mcp_tools(server_url) + tools = await adapter.get_mcp_tools(server_url, allowed_tools=None, sly_data={}) assert len(tools) == 1 assert 
tools[0].name == "config_tool" @pytest.mark.asyncio - @patch('neuro_san.internals.run_context.langchain.mcp.langchain_mcp_adapter.McpServersInfoRestorer') + @patch('neuro_san.internals.run_context.langchain.mcp.langchain_mcp_adapter.OauthProviderFactory') + @patch('neuro_san.internals.run_context.langchain.mcp.langchain_mcp_adapter.McpInfoRestorer') @patch('neuro_san.internals.run_context.langchain.mcp.langchain_mcp_adapter.MultiServerMCPClient') async def test_get_mcp_tools_with_headers_param( - self, mock_client_class, mock_restorer_class, adapter + self, mock_client_class, mock_servers_restorer_class, + mock_oauth_factory_class, adapter ): """Test MCP client initialization with headers parameter""" server_url = "https://mcp.example.com/mcp" - headers = {"Authorization": "Bearer custom_token"} + sly_data = { + "http_headers": { + server_url: {"Authorization": "Bearer custom_token"} + } + } - mock_restorer = mock_restorer_class.return_value - mock_restorer.restore.return_value = {} + mock_servers_restorer_class.return_value.restore.return_value = {} + + # Setup oauth mock to return None (no auth needed) + mock_oauth_factory = mock_oauth_factory_class.return_value + mock_oauth_factory.get_auth = AsyncMock(return_value=None) mock_client = mock_client_class.return_value mock_client.get_tools = AsyncMock(return_value=[]) - await adapter.get_mcp_tools(server_url, headers=headers) + await adapter.get_mcp_tools(server_url, allowed_tools=None, sly_data=sly_data) call_args = mock_client_class.call_args[0][0] assert "headers" in call_args["server"] assert call_args["server"]["headers"]["Authorization"] == "Bearer custom_token" @pytest.mark.asyncio - @patch('neuro_san.internals.run_context.langchain.mcp.langchain_mcp_adapter.McpServersInfoRestorer') + @patch('neuro_san.internals.run_context.langchain.mcp.langchain_mcp_adapter.OauthProviderFactory') + @patch('neuro_san.internals.run_context.langchain.mcp.langchain_mcp_adapter.McpInfoRestorer') 
@patch('neuro_san.internals.run_context.langchain.mcp.langchain_mcp_adapter.MultiServerMCPClient') async def test_get_mcp_tools_with_config_headers( - self, mock_client_class, mock_restorer_class, adapter + self, mock_client_class, mock_servers_restorer_class, + mock_oauth_factory_class, adapter ): """Test MCP client initialization with headers from config""" server_url = "https://mcp.example.com/mcp" - mock_restorer = mock_restorer_class.return_value - mock_restorer.restore.return_value = { + mock_servers_restorer_class.return_value.restore.return_value = { server_url: { "http_headers": {"Authorization": "Bearer config_token"} } } + # Setup oauth mock to return None (no auth needed) + mock_oauth_factory = mock_oauth_factory_class.return_value + mock_oauth_factory.get_auth = AsyncMock(return_value=None) + mock_client = mock_client_class.return_value mock_client.get_tools = AsyncMock(return_value=[]) - await adapter.get_mcp_tools(server_url) + await adapter.get_mcp_tools(server_url, allowed_tools=None, sly_data={}) call_args = mock_client_class.call_args[0][0] assert "headers" in call_args["server"] assert call_args["server"]["headers"]["Authorization"] == "Bearer config_token" @pytest.mark.asyncio + @patch('neuro_san.internals.run_context.langchain.mcp.langchain_mcp_adapter.OauthProviderFactory') @patch('neuro_san.internals.run_context.langchain.mcp.langchain_mcp_adapter.MultiServerMCPClient') async def test_get_mcp_tools_invalid_headers_type( - self, mock_client_class, adapter, caplog + self, mock_client_class, mock_oauth_factory_class, adapter, caplog ): """Test handling of invalid headers type in config""" # pylint: disable=protected-access server_url = "https://mcp.example.com/mcp" - LangChainMcpAdapter._mcp_servers_info = { + LangChainMcpAdapter._mcp_info = { server_url: { "http_headers": "invalid_string_not_dict" } } + # Setup oauth mock to return None (no auth needed) + mock_oauth_factory = mock_oauth_factory_class.return_value + mock_oauth_factory.get_auth 
= AsyncMock(return_value=None) + mock_client = mock_client_class.return_value mock_client.get_tools = AsyncMock(return_value=[]) - await adapter.get_mcp_tools(server_url) + await adapter.get_mcp_tools(server_url, allowed_tools=None, sly_data={}) # Check that error was logged assert "must be a dictionary" in caplog.text @pytest.mark.asyncio + @patch('neuro_san.internals.run_context.langchain.mcp.langchain_mcp_adapter.OauthProviderFactory') + @patch('neuro_san.internals.run_context.langchain.mcp.langchain_mcp_adapter.McpInfoRestorer') @patch('neuro_san.internals.run_context.langchain.mcp.langchain_mcp_adapter.MultiServerMCPClient') async def test_get_mcp_tools_adds_langchain_tool_tags( - self, mock_client_class, adapter + self, mock_client_class, mock_servers_restorer_class, + mock_oauth_factory_class, adapter ): """Test that langchain_tool tags are added to all tools""" + mock_servers_restorer_class.return_value.restore.return_value = {} + + # Setup oauth mock to return None (no auth needed) + mock_oauth_factory = mock_oauth_factory_class.return_value + mock_oauth_factory.get_auth = AsyncMock(return_value=None) + tools = [ MagicMock(spec=StructuredTool, name=f"tool{i}", tags=[]) for i in range(3) @@ -208,7 +267,317 @@ async def test_get_mcp_tools_adds_langchain_tool_tags( mock_client = mock_client_class.return_value mock_client.get_tools = AsyncMock(return_value=tools) - result = await adapter.get_mcp_tools("https://mcp.example.com/mcp") + result = await adapter.get_mcp_tools("https://mcp.example.com/mcp", allowed_tools=None, sly_data={}) for tool in result: assert "langchain_tool" in tool.tags + + @pytest.mark.asyncio + @patch('neuro_san.internals.run_context.langchain.mcp.langchain_mcp_adapter.OauthProviderFactory') + @patch('neuro_san.internals.run_context.langchain.mcp.langchain_mcp_adapter.McpInfoRestorer') + @patch('neuro_san.internals.run_context.langchain.mcp.langchain_mcp_adapter.MultiServerMCPClient') + async def 
test_get_mcp_tools_with_client_info_from_sly_data( + self, mock_client_class, mock_info_restorer_class, + mock_oauth_factory_class, adapter + ): + """Test that client info from sly_data is used for authentication""" + server_url = "https://mcp.example.com/mcp" + sly_data = { + "mcp_client_info": { + server_url: { + "client_id": "sly_data_client_id", + "client_secret": "sly_data_secret" + } + } + } + + mock_info_restorer_class.return_value.restore.return_value = {} + + # Setup oauth mock + mock_oauth_factory = mock_oauth_factory_class.return_value + mock_oauth_factory.get_auth = AsyncMock(return_value=MagicMock(spec=Auth)) + + mock_client = mock_client_class.return_value + mock_client.get_tools = AsyncMock(return_value=[]) + + await adapter.get_mcp_tools(server_url, allowed_tools=None, sly_data=sly_data) + + # Verify OauthProviderFactory was called with correct client_info + mock_oauth_factory_class.assert_called_once() + call_kwargs = mock_oauth_factory_class.call_args[1] + assert call_kwargs["storage"].client_info["client_id"] == "sly_data_client_id" + assert call_kwargs["storage"].client_info["client_secret"] == "sly_data_secret" + + @pytest.mark.asyncio + @patch('neuro_san.internals.run_context.langchain.mcp.langchain_mcp_adapter.OauthProviderFactory') + @patch('neuro_san.internals.run_context.langchain.mcp.langchain_mcp_adapter.McpInfoRestorer') + @patch('neuro_san.internals.run_context.langchain.mcp.langchain_mcp_adapter.MultiServerMCPClient') + async def test_get_mcp_tools_with_client_info_from_config( + self, mock_client_class, mock_info_restorer_class, + mock_oauth_factory_class, adapter + ): + """Test that client info from config is used when not in sly_data""" + server_url = "https://mcp.example.com/mcp" + mock_info_restorer_class.return_value.restore.return_value = { + server_url: { + "mcp_client_info": { + "client_id": "config_client_id", + "client_secret": "config_secret" + } + } + } + + # Setup oauth mock + mock_oauth_factory = 
mock_oauth_factory_class.return_value + mock_oauth_factory.get_auth = AsyncMock(return_value=MagicMock(spec=Auth)) + + mock_client = mock_client_class.return_value + mock_client.get_tools = AsyncMock(return_value=[]) + + await adapter.get_mcp_tools(server_url, allowed_tools=None, sly_data={}) + + # Verify OauthProviderFactory was called with config client_info + mock_oauth_factory_class.assert_called_once() + call_kwargs = mock_oauth_factory_class.call_args[1] + assert call_kwargs["storage"].client_info["client_id"] == "config_client_id" + assert call_kwargs["storage"].client_info["client_secret"] == "config_secret" + + @pytest.mark.asyncio + @patch('neuro_san.internals.run_context.langchain.mcp.langchain_mcp_adapter.OauthProviderFactory') + @patch('neuro_san.internals.run_context.langchain.mcp.langchain_mcp_adapter.McpInfoRestorer') + @patch('neuro_san.internals.run_context.langchain.mcp.langchain_mcp_adapter.MultiServerMCPClient') + async def test_get_mcp_tools_sly_data_client_info_takes_priority( + self, mock_client_class, mock_info_restorer_class, + mock_oauth_factory_class, adapter + ): + """Test that sly_data client_info takes priority over config""" + server_url = "https://mcp.example.com/mcp" + sly_data = { + "mcp_client_info": { + server_url: { + "client_id": "sly_data_client_id", + "client_secret": "sly_data_secret" + } + } + } + + mock_info_restorer_class.return_value.restore.return_value = { + server_url: { + "mcp_client_info": { + "client_id": "config_client_id", + "client_secret": "config_secret" + } + } + } + + # Setup oauth mock + mock_oauth_factory = mock_oauth_factory_class.return_value + mock_oauth_factory.get_auth = AsyncMock(return_value=MagicMock(spec=Auth)) + + mock_client = mock_client_class.return_value + mock_client.get_tools = AsyncMock(return_value=[]) + + await adapter.get_mcp_tools(server_url, allowed_tools=None, sly_data=sly_data) + + # Verify sly_data client_info was used, not config + call_kwargs = 
mock_oauth_factory_class.call_args[1] + assert call_kwargs["storage"].client_info["client_id"] == "sly_data_client_id" + + @pytest.mark.asyncio + @patch('neuro_san.internals.run_context.langchain.mcp.langchain_mcp_adapter.OauthProviderFactory') + @patch('neuro_san.internals.run_context.langchain.mcp.langchain_mcp_adapter.McpInfoRestorer') + @patch('neuro_san.internals.run_context.langchain.mcp.langchain_mcp_adapter.MultiServerMCPClient') + async def test_get_mcp_tools_with_token_in_sly_data( + self, mock_client_class, mock_info_restorer_class, + mock_oauth_factory_class, adapter + ): + """Test that token from sly_data is used for authentication""" + server_url = "https://mcp.example.com/mcp" + sly_data = { + "mcp_client_info": { + server_url: { + "client_id": "test_client_id" + } + }, + "mcp_tokens": { + server_url: { + "access_token": "existing_token", + "refresh_token": "existing_refresh" + } + } + } + + mock_info_restorer_class.return_value.restore.return_value = {} + + # Setup oauth mock + mock_oauth_factory = mock_oauth_factory_class.return_value + mock_oauth_factory.get_auth = AsyncMock(return_value=MagicMock(spec=Auth)) + + mock_client = mock_client_class.return_value + mock_client.get_tools = AsyncMock(return_value=[]) + + await adapter.get_mcp_tools(server_url, allowed_tools=None, sly_data=sly_data) + + # Verify token was passed to OauthProviderFactory + call_kwargs = mock_oauth_factory_class.call_args[1] + assert call_kwargs["storage"].tokens["access_token"] == "existing_token" + assert call_kwargs["storage"].tokens["refresh_token"] == "existing_refresh" + + @pytest.mark.asyncio + @patch('neuro_san.internals.run_context.langchain.mcp.langchain_mcp_adapter.OauthProviderFactory') + @patch('neuro_san.internals.run_context.langchain.mcp.langchain_mcp_adapter.McpInfoRestorer') + @patch('neuro_san.internals.run_context.langchain.mcp.langchain_mcp_adapter.MultiServerMCPClient') + async def test_get_mcp_tools_creates_mcp_tokens_dict_in_sly_data( + self, 
mock_client_class, mock_info_restorer_class, + mock_oauth_factory_class, adapter + ): + """Test that mcp_tokens dict is created in sly_data when client_info exists""" + server_url = "https://mcp.example.com/mcp" + sly_data = { + "mcp_client_info": { + server_url: { + "client_id": "test_client_id" + } + } + } + + mock_info_restorer_class.return_value.restore.return_value = {} + + # Setup oauth mock + mock_oauth_factory = mock_oauth_factory_class.return_value + mock_oauth_factory.get_auth = AsyncMock(return_value=MagicMock(spec=Auth)) + + mock_client = mock_client_class.return_value + mock_client.get_tools = AsyncMock(return_value=[]) + + await adapter.get_mcp_tools(server_url, allowed_tools=None, sly_data=sly_data) + + # Verify mcp_tokens dict was created in sly_data + assert "mcp_tokens" in sly_data + assert server_url in sly_data["mcp_tokens"] + + @pytest.mark.asyncio + @patch('neuro_san.internals.run_context.langchain.mcp.langchain_mcp_adapter.OauthProviderFactory') + @patch('neuro_san.internals.run_context.langchain.mcp.langchain_mcp_adapter.McpInfoRestorer') + @patch('neuro_san.internals.run_context.langchain.mcp.langchain_mcp_adapter.MultiServerMCPClient') + async def test_get_mcp_tools_invalid_client_info_type( + self, mock_client_class, mock_info_restorer_class, + mock_oauth_factory_class, adapter, caplog + ): + """Test handling of invalid client_info type""" + server_url = "https://mcp.example.com/mcp" + mock_info_restorer_class.return_value.restore.return_value = { + server_url: { + "mcp_client_info": "invalid_string_not_dict" + } + } + + # Setup oauth mock + mock_oauth_factory = mock_oauth_factory_class.return_value + mock_oauth_factory.get_auth = AsyncMock(return_value=None) + + mock_client = mock_client_class.return_value + mock_client.get_tools = AsyncMock(return_value=[]) + + await adapter.get_mcp_tools(server_url, allowed_tools=None, sly_data={}) + + # Check that error was logged + assert "MCP client info" in caplog.text + assert "must be a 
dictionary" in caplog.text + + @pytest.mark.asyncio + @patch('neuro_san.internals.run_context.langchain.mcp.langchain_mcp_adapter.OauthProviderFactory') + @patch('neuro_san.internals.run_context.langchain.mcp.langchain_mcp_adapter.McpInfoRestorer') + @patch('neuro_san.internals.run_context.langchain.mcp.langchain_mcp_adapter.MultiServerMCPClient') + async def test_get_mcp_tools_invalid_token_type( + self, mock_client_class, mock_info_restorer_class, + mock_oauth_factory_class, adapter, caplog + ): + """Test handling of invalid token type in sly_data""" + server_url = "https://mcp.example.com/mcp" + sly_data = { + "mcp_client_info": { + server_url: {"client_id": "test_id"} + }, + "mcp_tokens": { + server_url: "invalid_string_not_dict" + } + } + + mock_info_restorer_class.return_value.restore.return_value = {} + + # Setup oauth mock + mock_oauth_factory = mock_oauth_factory_class.return_value + mock_oauth_factory.get_auth = AsyncMock(return_value=MagicMock(spec=Auth)) + + mock_client = mock_client_class.return_value + mock_client.get_tools = AsyncMock(return_value=[]) + + await adapter.get_mcp_tools(server_url, allowed_tools=None, sly_data=sly_data) + + # Check that error was logged and token was reset to empty dict + assert "Token for server" in caplog.text + assert "must be a dictionary" in caplog.text + assert not sly_data["mcp_tokens"][server_url] + + @pytest.mark.asyncio + @patch('neuro_san.internals.run_context.langchain.mcp.langchain_mcp_adapter.OauthProviderFactory') + @patch('neuro_san.internals.run_context.langchain.mcp.langchain_mcp_adapter.McpInfoRestorer') + @patch('neuro_san.internals.run_context.langchain.mcp.langchain_mcp_adapter.MultiServerMCPClient') + async def test_get_mcp_tools_with_token_endpoint_from_config( + self, mock_client_class, mock_info_restorer_class, + mock_oauth_factory_class, adapter + ): + """Test that token_endpoint from config is passed to OauthProviderFactory""" + server_url = "https://mcp.example.com/mcp" + 
mock_info_restorer_class.return_value.restore.return_value = { + server_url: { + "mcp_client_info": {"client_id": "test_id"}, + "mcp_server_info": { + "token_endpoint": "https://auth.example.com/token" + } + } + } + + # Setup oauth mock + mock_oauth_factory = mock_oauth_factory_class.return_value + mock_oauth_factory.get_auth = AsyncMock(return_value=MagicMock(spec=Auth)) + + mock_client = mock_client_class.return_value + mock_client.get_tools = AsyncMock(return_value=[]) + + await adapter.get_mcp_tools(server_url, allowed_tools=None, sly_data={}) + + # Verify token_endpoint was passed to OauthProviderFactory + call_kwargs = mock_oauth_factory_class.call_args[1] + assert call_kwargs["token_endpoint"] == "https://auth.example.com/token" + + @pytest.mark.asyncio + @patch('neuro_san.internals.run_context.langchain.mcp.langchain_mcp_adapter.OauthProviderFactory') + @patch('neuro_san.internals.run_context.langchain.mcp.langchain_mcp_adapter.McpInfoRestorer') + @patch('neuro_san.internals.run_context.langchain.mcp.langchain_mcp_adapter.MultiServerMCPClient') + async def test_get_mcp_tools_with_auth_timeout_from_config( + self, mock_client_class, mock_info_restorer_class, + mock_oauth_factory_class, adapter + ): + """Test that auth_timeout from config is passed to OauthProviderFactory""" + server_url = "https://mcp.example.com/mcp" + mock_info_restorer_class.return_value.restore.return_value = { + server_url: { + "mcp_client_info": {"client_id": "test_id"}, + "auth_timeout": 600.0 + } + } + + # Setup oauth mock + mock_oauth_factory = mock_oauth_factory_class.return_value + mock_oauth_factory.get_auth = AsyncMock(return_value=MagicMock(spec=Auth)) + + mock_client = mock_client_class.return_value + mock_client.get_tools = AsyncMock(return_value=[]) + + await adapter.get_mcp_tools(server_url, allowed_tools=None, sly_data={}) + + # Verify timeout was passed to OauthProviderFactory + call_kwargs = mock_oauth_factory_class.call_args[1] + assert call_kwargs["timeout"] == 600.0 
diff --git a/tests/neuro_san/internals/run_context/langchain/mcp/test_oauth_provider_factory.py b/tests/neuro_san/internals/run_context/langchain/mcp/test_oauth_provider_factory.py new file mode 100644 index 000000000..81b34abfb --- /dev/null +++ b/tests/neuro_san/internals/run_context/langchain/mcp/test_oauth_provider_factory.py @@ -0,0 +1,280 @@ +# Copyright © 2023-2026 Cognizant Technology Solutions Corp, www.cognizant.com. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# END COPYRIGHT + +from unittest.mock import AsyncMock +from unittest.mock import MagicMock +import pytest + +from neuro_san.internals.run_context.langchain.mcp.client_credentials_oauth_provider import \ + ClientCredentialsOauthProvider +from neuro_san.internals.run_context.langchain.mcp.oauth_provider_factory import OauthProviderFactory +from neuro_san.internals.run_context.langchain.mcp.refresh_token_oauth_provider import RefreshTokenOauthProvider + + +class TestOauthProviderFactory: + """Test suite for OauthProviderFactory class""" + + @pytest.fixture + def mock_storage(self): + """Create mock token storage""" + storage = MagicMock() + storage.get_client_info_dict = AsyncMock() + storage.get_tokens_dict = AsyncMock() + return storage + + @pytest.fixture + def factory(self, mock_storage): + """Create factory instance""" + return OauthProviderFactory( + server_url="https://mcp.example.com/mcp", + storage=mock_storage, + token_endpoint="https://auth.example.com/token", + timeout=300.0 + ) + + 
def test_init(self, mock_storage): + """Test factory initialization""" + factory = OauthProviderFactory( + server_url="https://mcp.example.com/mcp", + storage=mock_storage, + token_endpoint="https://auth.example.com/token", + timeout=600.0 + ) + + assert factory.server_url == "https://mcp.example.com/mcp" + assert factory.storage == mock_storage + assert factory.token_endpoint == "https://auth.example.com/token" + assert factory.timeout == 600.0 + assert factory.logger is not None + + def test_init_with_defaults(self, mock_storage): + """Test factory initialization with default values""" + factory = OauthProviderFactory( + server_url="https://mcp.example.com/mcp", + storage=mock_storage + ) + + assert factory.token_endpoint is None + assert factory.timeout == 300.0 + + @pytest.mark.asyncio + async def test_get_auth_returns_none_when_no_credentials(self, factory, mock_storage): + """Test that get_auth returns None when no credentials available""" + mock_storage.get_client_info_dict.return_value = {} + mock_storage.get_tokens_dict.return_value = {} + + auth = await factory.get_auth() + + assert auth is None + + @pytest.mark.asyncio + async def test_get_auth_returns_refresh_token_provider_when_refresh_token_available( + self, factory, mock_storage + ): + """Test that refresh token provider is used when refresh token is available""" + mock_storage.get_client_info_dict.return_value = { + "client_id": "test_client_id", + "client_secret": "test_client_secret" + } + mock_storage.get_tokens_dict.return_value = { + "access_token": "test_access_token", + "refresh_token": "test_refresh_token" + } + + auth = await factory.get_auth() + + assert isinstance(auth, RefreshTokenOauthProvider) + + @pytest.mark.asyncio + async def test_get_auth_returns_client_credentials_provider_when_no_refresh_token( + self, factory, mock_storage + ): + """Test that client credentials provider is used when no refresh token""" + mock_storage.get_client_info_dict.return_value = { + "client_id": 
"test_client_id", + "client_secret": "test_client_secret" + } + mock_storage.get_tokens_dict.return_value = { + "access_token": "test_access_token" + # No refresh_token + } + + auth = await factory.get_auth() + + assert isinstance(auth, ClientCredentialsOauthProvider) + + @pytest.mark.asyncio + async def test_get_auth_prioritizes_refresh_token_over_client_credentials( + self, factory, mock_storage + ): + """Test that refresh token flow is prioritized over client credentials""" + mock_storage.get_client_info_dict.return_value = { + "client_id": "test_client_id", + "client_secret": "test_client_secret" + } + mock_storage.get_tokens_dict.return_value = { + "refresh_token": "test_refresh_token" + } + + auth = await factory.get_auth() + + # Should use refresh token provider, not client credentials + assert isinstance(auth, RefreshTokenOauthProvider) + + # pylint: disable=protected-access + @pytest.mark.asyncio + async def test_create_client_credentials_provider_with_all_params(self, factory): + """Test client credentials provider creation with all parameters""" + credentials = { + "client_id": "test_id", + "client_secret": "test_secret", + "token_endpoint_auth_method": "client_secret_post", + "scope": "read write" + } + + provider = await factory._create_client_credentials_provider(credentials) + + assert isinstance(provider, ClientCredentialsOauthProvider) + assert provider._fixed_client_info.client_id == "test_id" + assert provider._fixed_client_info.client_secret == "test_secret" + assert provider._fixed_client_info.token_endpoint_auth_method == "client_secret_post" + assert provider._fixed_client_info.scope == "read write" + assert provider.token_endpoint == "https://auth.example.com/token" + assert provider.context.timeout == 300.0 + + @pytest.mark.asyncio + async def test_create_client_credentials_provider_with_defaults(self, factory): + """Test client credentials provider with default auth method""" + credentials = { + "client_id": "test_id", + "client_secret": 
"test_secret" + } + + provider = await factory._create_client_credentials_provider(credentials) + + assert provider._fixed_client_info.token_endpoint_auth_method == "client_secret_basic" + assert provider._fixed_client_info.scope is None + + @pytest.mark.asyncio + async def test_create_refresh_token_provider_with_all_params(self, factory): + """Test refresh token provider creation with all parameters""" + credentials = { + "client_id": "test_id", + "client_secret": "test_secret", + "token_endpoint_auth_method": "client_secret_post", + "scope": "read write" + } + + provider = await factory._create_refresh_token_provider(credentials) + + assert isinstance(provider, RefreshTokenOauthProvider) + assert provider._fixed_client_info.client_id == "test_id" + assert provider._fixed_client_info.client_secret == "test_secret" + assert provider._fixed_client_info.token_endpoint_auth_method == "client_secret_post" + assert provider._fixed_client_info.scope == "read write" + assert provider.token_endpoint == "https://auth.example.com/token" + assert provider.context.timeout == 300.0 + + @pytest.mark.asyncio + async def test_create_refresh_token_provider_with_defaults(self, factory): + """Test refresh token provider with default auth method""" + credentials = { + "client_id": "test_id", + "client_secret": "test_secret" + } + + provider = await factory._create_refresh_token_provider(credentials) + + assert provider._fixed_client_info.token_endpoint_auth_method == "client_secret_basic" + assert provider._fixed_client_info.scope is None + + @pytest.mark.asyncio + async def test_get_auth_handles_empty_tokens_dict(self, factory, mock_storage): + """Test that empty tokens dict is handled correctly""" + mock_storage.get_client_info_dict.return_value = { + "client_id": "test_id", + "client_secret": "test_secret" + } + mock_storage.get_tokens_dict.return_value = {} + + auth = await factory.get_auth() + + # Should fall back to client credentials (no refresh token) + assert isinstance(auth, 
ClientCredentialsOauthProvider) + + @pytest.mark.asyncio + async def test_get_auth_handles_none_refresh_token(self, factory, mock_storage): + """Test that None refresh token is handled correctly""" + mock_storage.get_client_info_dict.return_value = { + "client_id": "test_id", + "client_secret": "test_secret" + } + mock_storage.get_tokens_dict.return_value = { + "refresh_token": None + } + + auth = await factory.get_auth() + + # Should fall back to client credentials (None is falsy) + assert isinstance(auth, ClientCredentialsOauthProvider) + + @pytest.mark.asyncio + async def test_get_auth_handles_empty_string_refresh_token(self, factory, mock_storage): + """Test that empty string refresh token is handled correctly""" + mock_storage.get_client_info_dict.return_value = { + "client_id": "test_id", + "client_secret": "test_secret" + } + mock_storage.get_tokens_dict.return_value = { + "refresh_token": "" + } + + auth = await factory.get_auth() + + # Should fall back to client credentials (empty string is falsy) + assert isinstance(auth, ClientCredentialsOauthProvider) + + @pytest.mark.asyncio + async def test_factory_uses_provided_token_endpoint(self, mock_storage): + """Test that factory passes token_endpoint to providers""" + factory = OauthProviderFactory( + server_url="https://mcp.example.com/mcp", + storage=mock_storage, + token_endpoint="https://custom.auth.com/oauth/token" + ) + + mock_storage.get_client_info_dict.return_value = {"client_id": "test_id"} + mock_storage.get_tokens_dict.return_value = {} + + provider = await factory.get_auth() + + assert provider.token_endpoint == "https://custom.auth.com/oauth/token" + + @pytest.mark.asyncio + async def test_factory_uses_provided_timeout(self, mock_storage): + """Test that factory passes timeout to providers""" + factory = OauthProviderFactory( + server_url="https://mcp.example.com/mcp", + storage=mock_storage, + timeout=600.0 + ) + + mock_storage.get_client_info_dict.return_value = {"client_id": "test_id"} + 
mock_storage.get_tokens_dict.return_value = {} + + provider = await factory.get_auth() + + assert provider.context.timeout == 600.0 diff --git a/tests/neuro_san/internals/run_context/langchain/mcp/test_refresh_token_oauth_provider.py b/tests/neuro_san/internals/run_context/langchain/mcp/test_refresh_token_oauth_provider.py new file mode 100644 index 000000000..f635573b2 --- /dev/null +++ b/tests/neuro_san/internals/run_context/langchain/mcp/test_refresh_token_oauth_provider.py @@ -0,0 +1,377 @@ +# Copyright © 2023-2026 Cognizant Technology Solutions Corp, www.cognizant.com. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# END COPYRIGHT + +from unittest.mock import AsyncMock +from unittest.mock import MagicMock +from unittest.mock import patch +import pytest + +from mcp.client.auth.exceptions import OAuthTokenError +from mcp.shared.auth import OAuthToken + +from neuro_san.internals.run_context.langchain.mcp.refresh_token_oauth_provider import RefreshTokenOauthProvider + + +# pylint: disable=too-many-public-methods +class TestRefreshTokenOauthProvider: + """Test suite for RefreshTokenOauthProvider class""" + + @pytest.fixture + def mock_storage(self): + """Create mock token storage""" + storage = MagicMock() + storage.get_tokens = AsyncMock(return_value=None) + return storage + + @pytest.fixture + def mock_token_with_refresh(self): + """Create mock token with refresh token""" + return OAuthToken( + access_token="test_access_token", + token_type="Bearer", + expires_in=3600, + refresh_token="test_refresh_token" + ) + + @pytest.fixture + def provider(self, mock_storage): + """Create provider instance with basic credentials""" + return RefreshTokenOauthProvider( + server_url="https://mcp.example.com/mcp", + storage=mock_storage, + client_id="test_client_id", + client_secret="test_client_secret" + ) + + # pylint: disable=protected-access + def test_init_with_minimal_params(self, mock_storage): + """Test initialization with minimal parameters""" + provider = RefreshTokenOauthProvider( + server_url="https://mcp.example.com/mcp", + storage=mock_storage, + client_id="test_id" + ) + + assert provider._fixed_client_info.client_id == "test_id" + assert provider._fixed_client_info.client_secret is None + assert provider._fixed_client_info.grant_types == ["refresh_token"] + assert provider._fixed_client_info.token_endpoint_auth_method == "client_secret_basic" + assert provider._fixed_client_info.scope is None + assert provider.token_endpoint is None + assert provider.context.timeout == 300.0 + + def test_init_with_all_params(self, mock_storage): + """Test initialization with all parameters""" 
+ provider = RefreshTokenOauthProvider( + server_url="https://mcp.example.com/mcp", + storage=mock_storage, + client_id="test_id", + client_secret="test_secret", + token_endpoint="https://auth.example.com/token", + token_endpoint_auth_method="client_secret_post", + scopes="read write", + timeout=600.0 + ) + + assert provider._fixed_client_info.client_id == "test_id" + assert provider._fixed_client_info.client_secret == "test_secret" + assert provider.token_endpoint == "https://auth.example.com/token" + assert provider._fixed_client_info.token_endpoint_auth_method == "client_secret_post" + assert provider._fixed_client_info.scope == "read write" + assert provider.context.timeout == 600.0 + + def test_init_without_client_secret(self, mock_storage): + """Test initialization without client_secret (some servers don't require it)""" + provider = RefreshTokenOauthProvider( + server_url="https://mcp.example.com/mcp", + storage=mock_storage, + client_id="test_id", + client_secret=None + ) + + assert provider._fixed_client_info.client_secret is None + + def test_init_with_none_auth_method(self, mock_storage): + """Test initialization with None auth method""" + provider = RefreshTokenOauthProvider( + server_url="https://mcp.example.com/mcp", + storage=mock_storage, + client_id="test_id", + token_endpoint_auth_method=None + ) + + assert provider._fixed_client_info.token_endpoint_auth_method is None + + def test_fixed_client_info_structure(self, provider): + """Test that fixed client info has correct structure""" + client_info = provider._fixed_client_info + + assert client_info.redirect_uris is None + assert client_info.grant_types == ["refresh_token"] + assert hasattr(client_info, "client_id") + assert hasattr(client_info, "client_secret") + + def test_client_metadata_structure(self, provider): + """Test that client metadata is correctly configured""" + assert provider.context.client_metadata.grant_types == ["refresh_token"] + assert 
provider.context.client_metadata.redirect_uris is None
+
+    @pytest.mark.asyncio
+    async def test_initialize_loads_tokens(self, provider, mock_storage, mock_token_with_refresh):
+        """Test that _initialize loads tokens from storage"""
+        mock_storage.get_tokens.return_value = mock_token_with_refresh
+
+        await provider._initialize()
+
+        assert provider.context.current_tokens == mock_token_with_refresh
+        mock_storage.get_tokens.assert_called_once()
+
+    @pytest.mark.asyncio
+    async def test_initialize_sets_client_info(self, provider):
+        """Test that _initialize sets fixed client info"""
+        await provider._initialize()
+
+        assert provider.context.client_info == provider._fixed_client_info
+        assert provider._initialized is True
+
+    @pytest.mark.asyncio
+    async def test_initialize_does_not_call_dynamic_registration(self, provider, mock_storage):
+        """Test that _initialize bypasses dynamic client registration"""
+        await provider._initialize()
+
+        # Client info should be set directly, not loaded from storage
+        assert provider.context.client_info == provider._fixed_client_info
+        mock_storage.get_tokens.assert_called_once()
+        mock_storage.get_client_info.assert_not_called()
+
+    @pytest.mark.asyncio
+    @patch.object(RefreshTokenOauthProvider, '_refresh_token')
+    async def test_perform_authorization_calls_refresh_token(self, mock_refresh, provider):
+        """Test that _perform_authorization calls refresh token method"""
+        mock_request = MagicMock()
+        mock_refresh.return_value = mock_request
+
+        result = await provider._perform_authorization()
+
+        assert result == mock_request
+        mock_refresh.assert_called_once()
+
+    @pytest.mark.asyncio
+    async def test_refresh_token_creates_correct_request(self, provider, mock_token_with_refresh):
+        """Test that refresh token creates correct HTTP request"""
+        await provider._initialize()
+        provider.context.current_tokens = mock_token_with_refresh
+
+        request = await provider._refresh_token()
+
+        assert request.method == "POST"
+        request_body = str(request.content)
+        assert "grant_type" in request_body
+        assert "refresh_token" in request_body
+        assert "client_id" in request_body
+
+    @pytest.mark.asyncio
+    async def test_refresh_token_raises_error_when_no_refresh_token(self, provider):
+        """Test that error is raised when no refresh token available"""
+        await provider._initialize()
+        provider.context.current_tokens = OAuthToken(
+            access_token="test_token",
+            token_type="Bearer",
+            expires_in=3600
+            # No refresh_token
+        )
+
+        with pytest.raises(OAuthTokenError, match="No refresh token available"):
+            await provider._refresh_token()
+
+    @pytest.mark.asyncio
+    async def test_refresh_token_raises_error_when_no_tokens(self, provider):
+        """Test that error is raised when no tokens at all"""
+        await provider._initialize()
+        provider.context.current_tokens = None
+
+        with pytest.raises(OAuthTokenError, match="No refresh token available"):
+            await provider._refresh_token()
+
+    @pytest.mark.asyncio
+    async def test_refresh_token_raises_error_when_no_client_info(self, provider, mock_token_with_refresh):
+        """Test that error is raised when no client info available"""
+        await provider._initialize()
+        provider.context.current_tokens = mock_token_with_refresh
+        provider.context.client_info = None
+
+        with pytest.raises(OAuthTokenError, match="No client info available"):
+            await provider._refresh_token()
+
+    @pytest.mark.asyncio
+    async def test_refresh_token_raises_error_when_no_client_id(self, provider, mock_token_with_refresh):
+        """Test that error is raised when client_id is missing"""
+        await provider._initialize()
+        provider.context.current_tokens = mock_token_with_refresh
+        provider.context.client_info.client_id = None
+
+        with pytest.raises(OAuthTokenError, match="No client info available"):
+            await provider._refresh_token()
+
+    @pytest.mark.asyncio
+    async def test_refresh_token_uses_provided_endpoint(self, mock_storage, mock_token_with_refresh):
+        """Test that provided token_endpoint is used"""
+        provider = RefreshTokenOauthProvider(
+            server_url="https://mcp.example.com/mcp",
+            storage=mock_storage,
+            client_id="test_id",
+            token_endpoint="https://custom.auth.com/oauth/token"
+        )
+        await provider._initialize()
+        provider.context.current_tokens = mock_token_with_refresh
+
+        request = await provider._refresh_token()
+
+        assert str(request.url) == "https://custom.auth.com/oauth/token"
+
+    @pytest.mark.asyncio
+    async def test_refresh_token_content_type_header(self, provider, mock_token_with_refresh):
+        """Test that correct Content-Type header is set"""
+        await provider._initialize()
+        provider.context.current_tokens = mock_token_with_refresh
+
+        request = await provider._refresh_token()
+
+        assert request.headers.get("Content-Type") == "application/x-www-form-urlencoded"
+
+    @pytest.mark.asyncio
+    async def test_refresh_token_uses_client_secret_basic_auth(self, mock_storage, mock_token_with_refresh):
+        """Test that client_secret_basic method is used correctly"""
+        provider = RefreshTokenOauthProvider(
+            server_url="https://mcp.example.com/mcp",
+            storage=mock_storage,
+            client_id="test_id",
+            client_secret="test_secret",
+            token_endpoint_auth_method="client_secret_basic"
+        )
+        await provider._initialize()
+        provider.context.current_tokens = mock_token_with_refresh
+
+        request = await provider._refresh_token()
+
+        # client_secret_basic should use Authorization header
+        assert "Authorization" in request.headers or "authorization" in request.headers
+
+    @pytest.mark.asyncio
+    async def test_refresh_token_uses_client_secret_post(self, mock_storage, mock_token_with_refresh):
+        """Test that client_secret_post method includes secret in body"""
+        provider = RefreshTokenOauthProvider(
+            server_url="https://mcp.example.com/mcp",
+            storage=mock_storage,
+            client_id="test_id",
+            client_secret="test_secret",
+            token_endpoint_auth_method="client_secret_post"
+        )
+        await provider._initialize()
+        provider.context.current_tokens = mock_token_with_refresh
+
+        request = await provider._refresh_token()
+
+        # client_secret_post must carry the secret in the form body; checking only the HTTP method proves nothing
+        request_body = str(request.content)
+        assert "client_secret=test_secret" in request_body
+
+    @pytest.mark.asyncio
+    @patch.object(RefreshTokenOauthProvider, '_get_token_endpoint')
+    async def test_refresh_token_prefers_provided_endpoint(
+        self, mock_get_endpoint, mock_storage, mock_token_with_refresh
+    ):
+        """Test that provided endpoint is preferred over discovered endpoint"""
+        provider = RefreshTokenOauthProvider(
+            server_url="https://mcp.example.com/mcp",
+            storage=mock_storage,
+            client_id="test_id",
+            token_endpoint="https://fallback.auth.com/token"
+        )
+        await provider._initialize()
+        provider.context.current_tokens = mock_token_with_refresh
+
+        # Mock that discovery found an endpoint
+        mock_get_endpoint.return_value = "https://discovered.auth.com/token"
+
+        request = await provider._refresh_token()
+
+        # Should use provided endpoint, not discovered one
+        assert str(request.url) == "https://fallback.auth.com/token"
+
+    @pytest.mark.asyncio
+    @patch.object(RefreshTokenOauthProvider, '_get_token_endpoint')
+    async def test_refresh_token_falls_back_to_discovered_endpoint(
+        self, mock_get_endpoint, mock_storage, mock_token_with_refresh
+    ):
+        """Test fallback to discovered endpoint when provided endpoint is None"""
+        provider = RefreshTokenOauthProvider(
+            server_url="https://mcp.example.com/mcp",
+            storage=mock_storage,
+            client_id="test_id",
+            token_endpoint=None
+        )
+        await provider._initialize()
+        provider.context.current_tokens = mock_token_with_refresh
+
+        # Mock the endpoint that discovery resolves to
+        mock_get_endpoint.return_value = "https://mcp.example.com/token"
+
+        request = await provider._refresh_token()
+
+        # Should use discovered endpoint as fallback
+        assert str(request.url) == "https://mcp.example.com/token"
+
+    @pytest.mark.asyncio
+    async def test_refresh_token_includes_client_id(self, provider, mock_token_with_refresh):
+        """Test that client_id is included in refresh
request"""
+        await provider._initialize()
+        provider.context.current_tokens = mock_token_with_refresh
+
+        request = await provider._refresh_token()
+
+        request_body = str(request.content)
+        assert "test_client_id" in request_body
+
+    @pytest.mark.asyncio
+    async def test_refresh_token_includes_refresh_token_value(self, provider, mock_token_with_refresh):
+        """Test that refresh_token value is included in request"""
+        await provider._initialize()
+        provider.context.current_tokens = mock_token_with_refresh
+
+        request = await provider._refresh_token()
+
+        request_body = str(request.content)
+        assert "test_refresh_token" in request_body
+
+    @pytest.mark.asyncio
+    async def test_refresh_token_works_without_client_secret(self, mock_storage, mock_token_with_refresh):
+        """Test that refresh works without client_secret (some OAuth servers allow this)"""
+        provider = RefreshTokenOauthProvider(
+            server_url="https://mcp.example.com/mcp",
+            storage=mock_storage,
+            client_id="test_id",
+            client_secret=None
+        )
+        await provider._initialize()
+        provider.context.current_tokens = mock_token_with_refresh
+
+        # Should not raise an error
+        request = await provider._refresh_token()
+
+        assert request.method == "POST"
+        request_body = str(request.content)
+        assert "grant_type" in request_body
+        assert "refresh_token" in request_body
diff --git a/tests/neuro_san/internals/run_context/langchain/mcp/test_sly_data_token_storage.py b/tests/neuro_san/internals/run_context/langchain/mcp/test_sly_data_token_storage.py
new file mode 100644
index 000000000..3b3e24937
--- /dev/null
+++ b/tests/neuro_san/internals/run_context/langchain/mcp/test_sly_data_token_storage.py
@@ -0,0 +1,259 @@
+# Copyright © 2023-2026 Cognizant Technology Solutions Corp, www.cognizant.com.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# END COPYRIGHT
+
+import pytest
+from mcp.shared.auth import OAuthClientInformationFull
+from mcp.shared.auth import OAuthToken
+
+from neuro_san.internals.run_context.langchain.mcp.sly_data_token_storage import SlyDataTokenStorage
+
+
+class TestSlyDataTokenStorage:
+    """Test suite for SlyDataTokenStorage class"""
+
+    @pytest.fixture
+    def valid_token_dict(self):
+        """Create valid token dictionary"""
+        return {
+            "access_token": "test_access_token",
+            "token_type": "Bearer",
+            "expires_in": 3600,
+            "refresh_token": "test_refresh_token"
+        }
+
+    @pytest.fixture
+    def valid_client_info_dict(self):
+        """Create valid client info dictionary"""
+        return {
+            "client_id": "test_client_id",
+            "client_secret": "test_client_secret",
+            "grant_types": ["client_credentials"],
+            "token_endpoint_auth_method": "client_secret_basic",
+            # Not used in client credentials flow, but required for OAuthClientInformationFull validation
+            "redirect_uris": None
+        }
+
+    @pytest.fixture
+    def storage_with_data(self, valid_client_info_dict, valid_token_dict):
+        """Create storage with valid data"""
+        return SlyDataTokenStorage(
+            client_info=valid_client_info_dict.copy(),
+            tokens=valid_token_dict.copy()
+        )
+
+    @pytest.fixture
+    def storage_empty(self):
+        """Create storage with empty data"""
+        return SlyDataTokenStorage(client_info={}, tokens={})
+
+    def test_init(self, valid_client_info_dict, valid_token_dict):
+        """Test storage initialization"""
+        storage = SlyDataTokenStorage(
+            client_info=valid_client_info_dict,
+            tokens=valid_token_dict
+        )
+
+        assert storage.client_info == valid_client_info_dict
+        assert storage.tokens == valid_token_dict
+        assert storage.logger is not None
+
+    @pytest.mark.asyncio
+    async def test_get_tokens_valid(self, storage_with_data):
+        """Test getting valid tokens"""
+        tokens = await storage_with_data.get_tokens()
+
+        assert isinstance(tokens, OAuthToken)
+        assert tokens.access_token == "test_access_token"
+        assert tokens.token_type == "Bearer"
+        assert tokens.expires_in == 3600
+        assert tokens.refresh_token == "test_refresh_token"
+
+    @pytest.mark.asyncio
+    async def test_get_tokens_empty(self, storage_empty):
+        """Test getting tokens when storage is empty"""
+        tokens = await storage_empty.get_tokens()
+
+        assert tokens is None
+
+    @pytest.mark.asyncio
+    async def test_get_tokens_invalid_data(self, caplog):
+        """Test getting tokens with invalid data"""
+        storage = SlyDataTokenStorage(
+            client_info={},
+            tokens={"invalid": "data"}  # Missing required fields
+        )
+
+        tokens = await storage.get_tokens()
+
+        assert tokens is None
+        assert "Failed to load token from sly data" in caplog.text
+
+    @pytest.mark.asyncio
+    async def test_get_tokens_dict(self, storage_with_data, valid_token_dict):
+        """Test getting raw token dictionary"""
+        tokens_dict = await storage_with_data.get_tokens_dict()
+
+        assert tokens_dict == valid_token_dict
+
+    @pytest.mark.asyncio
+    async def test_set_tokens(self, storage_empty, valid_token_dict):
+        """Test setting tokens"""
+        oauth_token = OAuthToken(**valid_token_dict)
+
+        await storage_empty.set_tokens(oauth_token)
+
+        assert storage_empty.tokens["access_token"] == "test_access_token"
+        assert storage_empty.tokens["token_type"] == "Bearer"
+        assert storage_empty.tokens["expires_in"] == 3600
+        assert storage_empty.tokens["refresh_token"] == "test_refresh_token"
+
+    @pytest.mark.asyncio
+    async def test_set_tokens_clears_existing(self, storage_with_data):
+        """Test that setting tokens clears existing tokens"""
+        new_token = OAuthToken(
+            access_token="new_access_token",
+            token_type="Bearer",
+            expires_in=7200
+        )
+
+        await storage_with_data.set_tokens(new_token)
+
+        assert storage_with_data.tokens["access_token"] == "new_access_token"
+        assert storage_with_data.tokens["expires_in"] == 7200
+        # The new token has no refresh_token, so the previously stored one must not survive
+        assert storage_with_data.tokens.get("refresh_token") is None
+
+    @pytest.mark.asyncio
+    async def test_set_tokens_error_handling(self, storage_empty, caplog):
+        """Test error handling when setting invalid tokens"""
+        # None is not an OAuthToken, so saving it is expected to fail and be logged
+        invalid_token = None
+
+        # This should trigger error handling
+        await storage_empty.set_tokens(invalid_token)
+
+        assert "Failed to save tokens in sly data" in caplog.text
+
+    @pytest.mark.asyncio
+    async def test_get_client_info_valid(self, storage_with_data):
+        """Test getting valid client info"""
+        client_info = await storage_with_data.get_client_info()
+
+        assert isinstance(client_info, OAuthClientInformationFull)
+        assert client_info.client_id == "test_client_id"
+        assert client_info.client_secret == "test_client_secret"
+
+    @pytest.mark.asyncio
+    async def test_get_client_info_empty(self, storage_empty):
+        """Test getting client info when storage is empty"""
+        client_info = await storage_empty.get_client_info()
+
+        assert client_info is None
+
+    @pytest.mark.asyncio
+    async def test_get_client_info_invalid_data(self, caplog):
+        """Test getting client info with invalid data"""
+        storage = SlyDataTokenStorage(
+            client_info={"invalid": "data"},  # Missing required fields
+            tokens={}
+        )
+
+        client_info = await storage.get_client_info()
+
+        assert client_info is None
+        assert "Failed to instantiate OAuthClientInformationFull" in caplog.text
+
+    @pytest.mark.asyncio
+    async def test_get_client_info_dict(self, storage_with_data, valid_client_info_dict):
+        """Test getting raw client info dictionary"""
+        client_info_dict = await storage_with_data.get_client_info_dict()
+
+        assert client_info_dict == valid_client_info_dict
+
+    @pytest.mark.asyncio
+    async def test_set_client_info(self, storage_empty):
+        """Test setting client info"""
+        client_info = OAuthClientInformationFull(
+            client_id="new_client_id",
+            client_secret="new_client_secret",
+            grant_types=["authorization_code"],
+            token_endpoint_auth_method="client_secret_post",
+            # Not used in client credentials flow, but required for OAuthClientInformationFull validation
+            redirect_uris=None
+        )
+
+        await storage_empty.set_client_info(client_info)
+
+        assert storage_empty.client_info["client_id"] == "new_client_id"
+        assert storage_empty.client_info["client_secret"] == "new_client_secret"
+
+    @pytest.mark.asyncio
+    async def test_set_client_info_clears_existing(self, storage_with_data):
+        """Test that setting client info clears existing data"""
+        new_client_info = OAuthClientInformationFull(
+            client_id="new_id",
+            # Not used in client credentials flow, but required for OAuthClientInformationFull validation
+            redirect_uris=None
+        )
+
+        await storage_with_data.set_client_info(new_client_info)
+
+        assert storage_with_data.client_info["client_id"] == "new_id"
+        # The new client info has no secret, so the previously stored one must not survive
+        assert storage_with_data.client_info.get("client_secret") is None
+
+    @pytest.mark.asyncio
+    async def test_set_client_info_error_handling(self, storage_empty, caplog):
+        """Test error handling when setting invalid client info"""
+        # This should trigger error handling
+        await storage_empty.set_client_info(None)
+
+        assert "Failed to save client info in sly data" in caplog.text
+
+    @pytest.mark.asyncio
+    async def test_tokens_reference_preserved(self, valid_token_dict):
+        """Test that tokens dictionary reference is preserved"""
+        token_dict = valid_token_dict.copy()
+        storage = SlyDataTokenStorage(client_info={}, tokens=token_dict)
+
+        # Modify through storage
+        oauth_token = OAuthToken(
+            access_token="updated_token",
+            token_type="Bearer",
+            expires_in=1800
+        )
+        await storage.set_tokens(oauth_token)
+
+        # Check that original reference was updated
+        assert token_dict["access_token"] == "updated_token"
+        assert token_dict["expires_in"] == 1800
+
+    @pytest.mark.asyncio
+    async def test_client_info_reference_preserved(self, valid_client_info_dict):
+        """Test that client info dictionary reference is preserved"""
+        client_dict = valid_client_info_dict.copy()
+        storage = SlyDataTokenStorage(client_info=client_dict, tokens={})
+
+        # Modify through storage
+        new_client_info = OAuthClientInformationFull(
+            client_id="updated_id",
+            # Not used in client credentials flow, but required for OAuthClientInformationFull validation
+            redirect_uris=None
+        )
+        await storage.set_client_info(new_client_info)
+
+        # Check that original reference was updated
+        assert client_dict["client_id"] == "updated_id"