diff --git a/.local.env b/.local.env index 90c09d03..2d0b8bfc 100644 --- a/.local.env +++ b/.local.env @@ -1,11 +1,7 @@ -<<<<<<< Updated upstream -SENTRIUS_VERSION=1.1.97 -======= -SENTRIUS_VERSION=1.1.96 ->>>>>>> Stashed changes -SENTRIUS_SSH_VERSION=1.1.18 -SENTRIUS_KEYCLOAK_VERSION=1.1.25 -SENTRIUS_AGENT_VERSION=1.1.18 -SENTRIUS_AI_AGENT_VERSION=1.1.33 -LLMPROXY_VERSION=1.0.19 -LAUNCHER_VERSION=1.0.29 \ No newline at end of file +SENTRIUS_VERSION=1.1.98 +SENTRIUS_SSH_VERSION=1.1.19 +SENTRIUS_KEYCLOAK_VERSION=1.1.26 +SENTRIUS_AGENT_VERSION=1.1.19 +SENTRIUS_AI_AGENT_VERSION=1.1.34 +LLMPROXY_VERSION=1.0.22 +LAUNCHER_VERSION=1.0.30 diff --git a/.local.env.bak b/.local.env.bak index c59e7de4..24236b9d 100644 --- a/.local.env.bak +++ b/.local.env.bak @@ -1,7 +1,7 @@ -SENTRIUS_VERSION=1.1.96 -SENTRIUS_SSH_VERSION=1.1.18 -SENTRIUS_KEYCLOAK_VERSION=1.1.25 -SENTRIUS_AGENT_VERSION=1.1.18 -SENTRIUS_AI_AGENT_VERSION=1.1.33 -LLMPROXY_VERSION=1.0.18 -LAUNCHER_VERSION=1.0.29 \ No newline at end of file +SENTRIUS_VERSION=1.1.98 +SENTRIUS_SSH_VERSION=1.1.19 +SENTRIUS_KEYCLOAK_VERSION=1.1.26 +SENTRIUS_AGENT_VERSION=1.1.19 +SENTRIUS_AI_AGENT_VERSION=1.1.34 +LLMPROXY_VERSION=1.0.21 +LAUNCHER_VERSION=1.0.30 diff --git a/docs/mcp-proxy.md b/docs/mcp-proxy.md new file mode 100644 index 00000000..7af5b3de --- /dev/null +++ b/docs/mcp-proxy.md @@ -0,0 +1,284 @@ +# MCP Proxy - Sentrius Integration + +This document describes the MCP (Model Context Protocol) proxy implementation in Sentrius, which provides secure MCP endpoints with full zero trust security integration. + +## Overview + +The MCP proxy allows AI agents to communicate using the standardized Model Context Protocol while maintaining Sentrius's zero trust security model. All MCP requests are authenticated, authorized, and tracked for audit purposes. + +## Security Features + +- **JWT Authentication**: All requests require valid Keycloak JWT tokens +- **Access Control**: Uses `@LimitAccess` annotations for fine-grained permissions +- **Provenance Tracking**: All MCP operations are logged for audit trails +- **Zero Trust Validation**: Follows existing Sentrius security patterns + +## API Endpoints + +### HTTP Endpoints + +#### 1. MCP Request Processing +``` +POST /api/v1/mcp/ +``` + +**Headers:** +- `Authorization: Bearer ` +- `communication_id: ` +- `Content-Type: application/json` + +**Request Body:** +```json +{ + "jsonrpc": "2.0", + "id": "request-id", + "method": "method-name", + "params": { + "key": "value" + } +} +``` + +**Response:** +```json +{ + "jsonrpc": "2.0", + "id": "request-id", + "result": { + "response": "data" + } +} +``` + +#### 2. Capabilities Discovery +``` +GET /api/v1/mcp/capabilities +``` + +**Headers:** +- `Authorization: Bearer ` + +**Response:** +```json +{ + "protocolVersion": "2024-11-05", + "capabilities": { + "tools": {"listChanged": true}, + "resources": {"subscribe": true, "listChanged": true}, + "prompts": {"listChanged": true} + }, + "serverInfo": { + "name": "Sentrius MCP Proxy", + "version": "1.0.0" + } +} +``` + +#### 3. Health Check +``` +GET /api/v1/mcp/health +``` + +**Response:** +```json +{ + "status": "healthy", + "service": "mcp-proxy", + "timestamp": "2024-11-05T10:30:00Z" +} +``` + +### WebSocket Endpoint + +#### Real-time MCP Communication +``` +WebSocket: /api/v1/mcp/ws +``` + +**Connection Parameters:** +- Query parameter: `token=Bearer ` +- Query parameter: `communication_id=` +- Query parameter: `user_id=` + +**Message Format:** +Same as HTTP endpoint but sent as WebSocket messages. + +## Supported MCP Methods + +### Core Methods +- `initialize` - Initialize MCP session and get capabilities +- `ping` - Connectivity check + +### Tools +- `tools/list` - List available tools +- `tools/call` - Execute a tool with parameters + +### Resources +- `resources/list` - List available resources +- `resources/read` - Read a specific resource + +### Prompts +- `prompts/list` - List available prompts +- `prompts/get` - Get a specific prompt + +### LLM Integration +- `completion` - Request LLM completion (delegates to existing LLM services) + +## Usage Examples + +### 1. Initialize MCP Session + +```bash +curl -X POST http://localhost:8080/api/v1/mcp/ \ + -H "Authorization: Bearer YOUR_JWT_TOKEN" \ + -H "communication_id: init-session-123" \ + -H "Content-Type: application/json" \ + -d '{ + "jsonrpc": "2.0", + "id": "init-1", + "method": "initialize", + "params": {} + }' +``` + +### 2. List Available Tools + +```bash +curl -X POST http://localhost:8080/api/v1/mcp/ \ + -H "Authorization: Bearer YOUR_JWT_TOKEN" \ + -H "communication_id: tools-session-123" \ + -H "Content-Type: application/json" \ + -d '{ + "jsonrpc": "2.0", + "id": "tools-1", + "method": "tools/list", + "params": {} + }' +``` + +### 3. Execute a Tool + +```bash +curl -X POST http://localhost:8080/api/v1/mcp/ \ + -H "Authorization: Bearer YOUR_JWT_TOKEN" \ + -H "communication_id: exec-session-123" \ + -H "Content-Type: application/json" \ + -d '{ + "jsonrpc": "2.0", + "id": "exec-1", + "method": "tools/call", + "params": { + "name": "secure_command", + "arguments": { + "command": "ls -la" + } + } + }' +``` + +### 4. WebSocket Connection (JavaScript) + +```javascript +const token = "YOUR_JWT_TOKEN"; +const communicationId = "ws-session-123"; +const userId = "user-123"; + +const ws = new WebSocket(`ws://localhost:8080/api/v1/mcp/ws?token=Bearer%20${token}&communication_id=${communicationId}&user_id=${userId}`); + +ws.onopen = function(event) { + console.log('MCP WebSocket connected'); + + // Send ping request + ws.send(JSON.stringify({ + jsonrpc: "2.0", + id: "ping-1", + method: "ping", + params: {} + })); +}; + +ws.onmessage = function(event) { + const response = JSON.parse(event.data); + console.log('MCP Response:', response); +}; +``` + +## Error Handling + +The MCP proxy returns standard JSON-RPC error responses: + +```json +{ + "jsonrpc": "2.0", + "id": "request-id", + "error": { + "code": -32602, + "message": "Invalid params", + "data": null + } +} +``` + +**Standard Error Codes:** +- `-32700` Parse error +- `-32600` Invalid request +- `-32601` Method not found +- `-32602` Invalid params +- `-32603` Internal error +- `-32001` Unauthorized +- `-32002` Forbidden + +## Configuration + +The MCP proxy is automatically configured when the `llm-proxy` module is deployed. No additional configuration is required beyond the standard Sentrius authentication setup. + +### Required Environment Variables + +- `KEYCLOAK_BASE_URL` - Keycloak server URL +- `KEYCLOAK_SECRET` - Client secret for authentication +- Standard Sentrius database and Kafka configuration + +## Security Considerations + +1. **Authentication**: All requests must include valid JWT tokens +2. **Authorization**: Uses existing Sentrius role-based access control +3. **Audit Trail**: All MCP operations are logged to Kafka for provenance tracking +4. **Rate Limiting**: Inherits rate limiting from Spring Boot actuator +5. **HTTPS**: Should be deployed with HTTPS in production + +## Integration with Existing Services + +The MCP proxy integrates seamlessly with existing Sentrius services: + +- **Keycloak Service**: For JWT validation +- **User Service**: For user context and authorization +- **Provenance Service**: For audit logging +- **Zero Trust Services**: For ZTAT token validation (when needed) +- **LLM Services**: For delegation of completion requests + +## Testing + +Run the MCP proxy tests: + +```bash +mvn test -pl llm-proxy -Dtest="*MCP*" +``` + +## Monitoring + +The MCP proxy provides health checks and metrics through: + +- Health endpoint: `/api/v1/mcp/health` +- Spring Boot Actuator: `/actuator/health` +- OpenTelemetry tracing integration +- Kafka provenance events for monitoring + +## Future Enhancements + +Potential future improvements: + +1. **Additional MCP Methods**: Support for more MCP protocol methods +2. **Tool Integration**: Direct integration with Sentrius SSH tools +3. **Resource Providers**: Integration with Sentrius resource management +4. **Prompt Management**: Database-backed prompt storage +5. **Binary Protocol**: Support for binary MCP messages over WebSocket \ No newline at end of file diff --git a/examples/__pycache__/mcp-client-example.cpython-312.pyc b/examples/__pycache__/mcp-client-example.cpython-312.pyc new file mode 100644 index 00000000..cde93d38 Binary files /dev/null and b/examples/__pycache__/mcp-client-example.cpython-312.pyc differ diff --git a/examples/mcp-client-example.py b/examples/mcp-client-example.py new file mode 100644 index 00000000..0901c352 --- /dev/null +++ b/examples/mcp-client-example.py @@ -0,0 +1,104 @@ +#!/usr/bin/env python3 +""" +Example MCP client integration with Sentrius Python Agent + +This example demonstrates how to use the integrated MCP functionality +within the existing Sentrius Python agent framework. +""" + +import sys +import json +import logging +from pathlib import Path + +# Add the python-agent directory to the path +python_agent_dir = Path(__file__).parent.parent / "python-agent" +sys.path.insert(0, str(python_agent_dir)) + +from utils.config_manager import ConfigManager +from agents.mcp.mcp_agent import MCPAgent + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def main(): + """ + Example usage of the integrated MCP agent + """ + try: + # Initialize configuration manager + config_path = python_agent_dir / "application.properties" + config_manager = ConfigManager(str(config_path)) + + # Create MCP agent using the integrated framework + logger.info("Creating MCP agent with Sentrius integration...") + mcp_agent = MCPAgent(config_manager) + + print("=== MCP Agent Examples ===\n") + + # Example 1: Initialize and get capabilities + print("1. Initializing MCP agent...") + init_result = mcp_agent.execute_task() + print(f"Initialization result: {json.dumps(init_result, indent=2)}\n") + + # Example 2: Ping test + print("2. Testing connectivity...") + ping_result = mcp_agent.execute_task({"operation": "ping"}) + print(f"Ping result: {json.dumps(ping_result, indent=2)}\n") + + # Example 3: List available tools + print("3. Listing available tools...") + tools_result = mcp_agent.execute_task({"operation": "list_tools"}) + print(f"Tools result: {json.dumps(tools_result, indent=2)}\n") + + # Example 4: Execute a secure command (if available) + print("4. Executing secure command...") + command_result = mcp_agent.execute_secure_command("ls -la") + print(f"Command result: {json.dumps(command_result, indent=2)}\n") + + # Example 5: List resources + print("5. Listing available resources...") + resources_result = mcp_agent.execute_task({"operation": "list_resources"}) + print(f"Resources result: {json.dumps(resources_result, indent=2)}\n") + + # Example 6: List prompts + print("6. Listing available prompts...") + prompts_result = mcp_agent.execute_task({"operation": "list_prompts"}) + print(f"Prompts result: {json.dumps(prompts_result, indent=2)}\n") + + # Example 7: WebSocket communication example + print("7. Testing WebSocket communication...") + ws_result = mcp_agent.execute_task({"operation": "websocket_example"}) + print(f"WebSocket result: {json.dumps(ws_result, indent=2)}\n") + + # Show agent information + agent_info = mcp_agent.get_agent_info() + print(f"Agent info: {json.dumps(agent_info, indent=2)}") + + print("\n=== MCP Agent Examples Completed ===") + + except Exception as e: + logger.error(f"Example execution failed: {e}") + print(f"Error: {e}") + return 1 + + return 0 + + +def run_with_python_agent_main(): + """ + Example of how to run the MCP agent using the main.py interface + """ + print("\n=== Running via python-agent main.py ===") + print("You can also run the MCP agent directly using:") + print(f"cd {python_agent_dir}") + print("python main.py mcp --task-data '{\"operation\": \"ping\"}'") + print("python main.py mcp --task-data '{\"operation\": \"list_tools\"}'") + print("python main.py mcp --task-data '{\"operation\": \"call_tool\", \"tool_name\": \"secure_command\", \"arguments\": {\"command\": \"ls -la\"}}'") + + +if __name__ == "__main__": + exit_code = main() + run_with_python_agent_main() + sys.exit(exit_code) \ No newline at end of file diff --git a/integration-proxy/src/main/java/io/sentrius/sso/controllers/api/mcp/MCPProxyController.java b/integration-proxy/src/main/java/io/sentrius/sso/controllers/api/mcp/MCPProxyController.java new file mode 100644 index 00000000..681e2126 --- /dev/null +++ b/integration-proxy/src/main/java/io/sentrius/sso/controllers/api/mcp/MCPProxyController.java @@ -0,0 +1,183 @@ +package io.sentrius.sso.controllers.api.mcp; + +import io.sentrius.sso.core.controllers.BaseController; +import io.sentrius.sso.core.annotations.LimitAccess; +import io.sentrius.sso.core.model.security.enums.ApplicationAccessEnum; +import io.sentrius.sso.core.services.UserService; +import io.sentrius.sso.core.config.SystemOptions; +import io.sentrius.sso.core.services.ErrorOutputService; +import io.sentrius.sso.core.services.security.KeycloakService; +import io.sentrius.sso.mcp.model.MCPRequest; +import io.sentrius.sso.mcp.model.MCPResponse; +import io.sentrius.sso.mcp.model.MCPError; +import io.sentrius.sso.mcp.service.MCPProxyService; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import lombok.extern.slf4j.Slf4j; +import org.springframework.http.HttpStatus; +import org.springframework.http.ResponseEntity; +import org.springframework.web.bind.annotation.*; + +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import java.util.Map; + +/** + * MCP (Model Context Protocol) Proxy Controller with Zero Trust Security + * + * Provides secure MCP endpoints with the same security controls as other Sentrius services: + * - JWT authentication via Keycloak + * - Zero Trust Access Token (ZTAT) validation + * - Access control via @LimitAccess annotations + * - Provenance tracking for audit trails + */ +@RestController +@RequestMapping("/api/v1/mcp") +@Slf4j +public class MCPProxyController extends BaseController { + + private final KeycloakService keycloakService; + private final MCPProxyService mcpProxyService; + private final ObjectMapper objectMapper; + + public MCPProxyController( + UserService userService, + SystemOptions systemOptions, + ErrorOutputService errorOutputService, + KeycloakService keycloakService, + MCPProxyService mcpProxyService, + ObjectMapper objectMapper + ) { + super(userService, systemOptions, errorOutputService); + this.keycloakService = keycloakService; + this.mcpProxyService = mcpProxyService; + this.objectMapper = objectMapper; + } + + /** + * Handle MCP requests via HTTP POST + */ + @PostMapping("/") + @LimitAccess(applicationAccess = {ApplicationAccessEnum.CAN_LOG_IN}) + public ResponseEntity handleMCPRequest( + @RequestHeader("Authorization") String token, + @RequestHeader("communication_id") String communicationId, + HttpServletRequest request, + HttpServletResponse response, + @RequestBody String rawBody) { + + log.info("Received MCP request with communication_id: {}", communicationId); + + String compactJwt = extractJwtToken(token); + + // Validate JWT token + if (!keycloakService.validateJwt(compactJwt)) { + log.warn("Invalid Keycloak token for MCP request"); + return ResponseEntity.status(HttpStatus.UNAUTHORIZED) + .body(createErrorResponse(null, MCPError.unauthorized("Invalid Keycloak token"))); + } + + // Get operating user + var operatingUser = getOperatingUser(request, response); + if (operatingUser == null) { + log.warn("No operating user found for MCP request"); + return ResponseEntity.status(HttpStatus.UNAUTHORIZED) + .body(createErrorResponse(null, MCPError.unauthorized("No operating user found"))); + } + + try { + // Parse MCP request + MCPRequest mcpRequest = objectMapper.readValue(rawBody, MCPRequest.class); + + // Validate MCP request structure + if (mcpRequest.getMethod() == null || mcpRequest.getId() == null) { + return ResponseEntity.badRequest() + .body(createErrorResponse(mcpRequest.getId(), MCPError.invalidRequest("Missing required fields"))); + } + + // Process the request through the service layer + MCPResponse mcpResponse = mcpProxyService.processRequest( + mcpRequest, compactJwt, communicationId, operatingUser.getUsername() + ); + + return ResponseEntity.ok(mcpResponse); + + } catch (JsonProcessingException e) { + log.error("Failed to parse MCP request", e); + return ResponseEntity.badRequest() + .body(createErrorResponse(null, MCPError.parseError("Invalid JSON format"))); + } catch (Exception e) { + log.error("Unexpected error processing MCP request", e); + return ResponseEntity.status(HttpStatus.INTERNAL_SERVER_ERROR) + .body(createErrorResponse(null, MCPError.internalError("Internal server error"))); + } + } + + /** + * Handle MCP capability discovery + */ + @GetMapping("/capabilities") + @LimitAccess(applicationAccess = {ApplicationAccessEnum.CAN_LOG_IN}) + public ResponseEntity getCapabilities( + @RequestHeader("Authorization") String token, + HttpServletRequest request, + HttpServletResponse response) { + + String compactJwt = extractJwtToken(token); + + if (!keycloakService.validateJwt(compactJwt)) { + log.warn("Invalid Keycloak token for MCP capabilities request"); + return ResponseEntity.status(HttpStatus.UNAUTHORIZED) + .body(MCPError.unauthorized("Invalid Keycloak token")); + } + + var operatingUser = getOperatingUser(request, response); + if (operatingUser == null) { + return ResponseEntity.status(HttpStatus.UNAUTHORIZED) + .body(MCPError.unauthorized("No operating user found")); + } + + try { + // Create an initialize request to get capabilities + MCPRequest initRequest = MCPRequest.create("capabilities", "initialize", null); + MCPResponse mcpResponse = mcpProxyService.processRequest( + initRequest, compactJwt, "capabilities", operatingUser.getUsername() + ); + + return ResponseEntity.ok(mcpResponse.getResult()); + + } catch (Exception e) { + log.error("Error getting MCP capabilities", e); + return ResponseEntity.status(HttpStatus.INTERNAL_SERVER_ERROR) + .body(MCPError.internalError("Failed to get capabilities")); + } + } + + /** + * Health check endpoint for MCP proxy + */ + @GetMapping("/health") + public ResponseEntity health() { + return ResponseEntity.ok(Map.of( + "status", "healthy", + "service", "mcp-proxy", + "timestamp", java.time.Instant.now().toString() + )); + } + + /** + * Extract JWT token from Authorization header + */ + private String extractJwtToken(String authHeader) { + return authHeader != null && authHeader.startsWith("Bearer ") ? + authHeader.substring(7) : authHeader; + } + + /** + * Create error response in MCP format + */ + private MCPResponse createErrorResponse(String id, MCPError error) { + return MCPResponse.error(id, error); + } +} \ No newline at end of file diff --git a/integration-proxy/src/main/java/io/sentrius/sso/mcp/config/MCPProxyConfig.java b/integration-proxy/src/main/java/io/sentrius/sso/mcp/config/MCPProxyConfig.java new file mode 100644 index 00000000..434ec81a --- /dev/null +++ b/integration-proxy/src/main/java/io/sentrius/sso/mcp/config/MCPProxyConfig.java @@ -0,0 +1,23 @@ +package io.sentrius.sso.mcp.config; + +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.web.client.RestTemplate; +import com.fasterxml.jackson.databind.ObjectMapper; + +/** + * Configuration for MCP proxy services + */ +@Configuration +public class MCPProxyConfig { + + @Bean + public RestTemplate mcpRestTemplate() { + return new RestTemplate(); + } + + @Bean + public ObjectMapper mcpObjectMapper() { + return new ObjectMapper(); + } +} \ No newline at end of file diff --git a/integration-proxy/src/main/java/io/sentrius/sso/mcp/config/MCPWebSocketConfig.java b/integration-proxy/src/main/java/io/sentrius/sso/mcp/config/MCPWebSocketConfig.java new file mode 100644 index 00000000..16e465d6 --- /dev/null +++ b/integration-proxy/src/main/java/io/sentrius/sso/mcp/config/MCPWebSocketConfig.java @@ -0,0 +1,106 @@ +package io.sentrius.sso.mcp.config; + +import io.sentrius.sso.mcp.websocket.MCPWebSocketHandler; +import io.sentrius.sso.core.services.security.KeycloakService; + +import lombok.RequiredArgsConstructor; +import org.springframework.context.annotation.Configuration; +import org.springframework.web.socket.config.annotation.EnableWebSocket; +import org.springframework.web.socket.config.annotation.WebSocketConfigurer; +import org.springframework.web.socket.config.annotation.WebSocketHandlerRegistry; +import org.springframework.web.socket.server.HandshakeInterceptor; +import org.springframework.http.server.ServerHttpRequest; +import org.springframework.http.server.ServerHttpResponse; +import org.springframework.web.socket.WebSocketHandler; + +import java.util.Map; + +/** + * WebSocket configuration for MCP (Model Context Protocol) endpoints + */ +@Configuration +@EnableWebSocket +@RequiredArgsConstructor +public class MCPWebSocketConfig implements WebSocketConfigurer { + + private final MCPWebSocketHandler mcpWebSocketHandler; + private final KeycloakService keycloakService; + + @Override + public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) { + registry.addHandler(mcpWebSocketHandler, "/api/v1/mcp/ws") + .addInterceptors(new MCPHandshakeInterceptor()) + .setAllowedOrigins("*"); // Configure as needed for security + } + + /** + * Handshake interceptor to validate authentication before WebSocket connection + */ + private class MCPHandshakeInterceptor implements HandshakeInterceptor { + + @Override + public boolean beforeHandshake( + ServerHttpRequest request, + ServerHttpResponse response, + WebSocketHandler wsHandler, + Map attributes) throws Exception { + + // Extract authentication parameters from query params or headers + String token = extractToken(request); + String communicationId = extractParameter(request, "communication_id"); + String userId = extractParameter(request, "user_id"); + + if (token == null || communicationId == null || userId == null) { + return false; // Reject connection + } + + // Validate JWT token + String jwt = token.startsWith("Bearer ") ? token.substring(7) : token; + if (!keycloakService.validateJwt(jwt)) { + return false; // Invalid token + } + + // Store validated parameters in session attributes + attributes.put("token", token); + attributes.put("communication_id", communicationId); + attributes.put("user_id", userId); + + return true; // Allow connection + } + + @Override + public void afterHandshake( + ServerHttpRequest request, + ServerHttpResponse response, + WebSocketHandler wsHandler, + Exception exception) { + // No additional processing needed + } + + private String extractToken(ServerHttpRequest request) { + // Try Authorization header first + String authHeader = request.getHeaders().getFirst("Authorization"); + if (authHeader != null) { + return authHeader; + } + + // Fall back to query parameter + return extractParameter(request, "token"); + } + + private String extractParameter(ServerHttpRequest request, String paramName) { + String query = request.getURI().getQuery(); + if (query == null) { + return null; + } + + for (String param : query.split("&")) { + String[] parts = param.split("=", 2); + if (parts.length == 2 && paramName.equals(parts[0])) { + return parts[1]; + } + } + return null; + } + } +} \ No newline at end of file diff --git a/integration-proxy/src/main/java/io/sentrius/sso/mcp/model/MCPError.java b/integration-proxy/src/main/java/io/sentrius/sso/mcp/model/MCPError.java new file mode 100644 index 00000000..831c48c6 --- /dev/null +++ b/integration-proxy/src/main/java/io/sentrius/sso/mcp/model/MCPError.java @@ -0,0 +1,61 @@ +package io.sentrius.sso.mcp.model; + +import com.fasterxml.jackson.annotation.JsonProperty; +import lombok.Data; +import lombok.NoArgsConstructor; +import lombok.AllArgsConstructor; + +/** + * Represents an MCP error + */ +@Data +@NoArgsConstructor +@AllArgsConstructor +public class MCPError { + + @JsonProperty("code") + private int code; + + @JsonProperty("message") + private String message; + + @JsonProperty("data") + private Object data; + + // Standard MCP error codes + public static final int PARSE_ERROR = -32700; + public static final int INVALID_REQUEST = -32600; + public static final int METHOD_NOT_FOUND = -32601; + public static final int INVALID_PARAMS = -32602; + public static final int INTERNAL_ERROR = -32603; + public static final int UNAUTHORIZED = -32001; + public static final int FORBIDDEN = -32002; + + public static MCPError parseError(String message) { + return new MCPError(PARSE_ERROR, message, null); + } + + public static MCPError invalidRequest(String message) { + return new MCPError(INVALID_REQUEST, message, null); + } + + public static MCPError methodNotFound(String method) { + return new MCPError(METHOD_NOT_FOUND, "Method not found: " + method, null); + } + + public static MCPError invalidParams(String message) { + return new MCPError(INVALID_PARAMS, message, null); + } + + public static MCPError internalError(String message) { + return new MCPError(INTERNAL_ERROR, message, null); + } + + public static MCPError unauthorized(String message) { + return new MCPError(UNAUTHORIZED, message != null ? message : "Unauthorized", null); + } + + public static MCPError forbidden(String message) { + return new MCPError(FORBIDDEN, message != null ? message : "Forbidden", null); + } +} \ No newline at end of file diff --git a/integration-proxy/src/main/java/io/sentrius/sso/mcp/model/MCPRequest.java b/integration-proxy/src/main/java/io/sentrius/sso/mcp/model/MCPRequest.java new file mode 100644 index 00000000..ae9d7175 --- /dev/null +++ b/integration-proxy/src/main/java/io/sentrius/sso/mcp/model/MCPRequest.java @@ -0,0 +1,36 @@ +package io.sentrius.sso.mcp.model; + +import com.fasterxml.jackson.annotation.JsonProperty; +import lombok.Data; +import lombok.NoArgsConstructor; +import lombok.AllArgsConstructor; + +import java.util.Map; + +/** + * Represents an MCP (Model Context Protocol) request message + */ +@Data +@NoArgsConstructor +@AllArgsConstructor +public class MCPRequest { + + @JsonProperty("jsonrpc") + private String jsonRpc = "2.0"; + + @JsonProperty("id") + private String id; + + @JsonProperty("method") + private String method; + + @JsonProperty("params") + private Map params; + + /** + * Create a new MCP request + */ + public static MCPRequest create(String id, String method, Map params) { + return new MCPRequest("2.0", id, method, params); + } +} \ No newline at end of file diff --git a/integration-proxy/src/main/java/io/sentrius/sso/mcp/model/MCPResponse.java b/integration-proxy/src/main/java/io/sentrius/sso/mcp/model/MCPResponse.java new file mode 100644 index 00000000..14838470 --- /dev/null +++ b/integration-proxy/src/main/java/io/sentrius/sso/mcp/model/MCPResponse.java @@ -0,0 +1,41 @@ +package io.sentrius.sso.mcp.model; + +import com.fasterxml.jackson.annotation.JsonProperty; +import lombok.Data; +import lombok.NoArgsConstructor; +import lombok.AllArgsConstructor; + +/** + * Represents an MCP (Model Context Protocol) response message + */ +@Data +@NoArgsConstructor +@AllArgsConstructor +public class MCPResponse { + + @JsonProperty("jsonrpc") + private String jsonRpc = "2.0"; + + @JsonProperty("id") + private String id; + + @JsonProperty("result") + private Object result; + + @JsonProperty("error") + private MCPError error; + + /** + * Create a successful MCP response + */ + public static MCPResponse success(String id, Object result) { + return new MCPResponse("2.0", id, result, null); + } + + /** + * Create an error MCP response + */ + public static MCPResponse error(String id, MCPError error) { + return new MCPResponse("2.0", id, null, error); + } +} \ No newline at end of file diff --git a/integration-proxy/src/main/java/io/sentrius/sso/mcp/service/MCPProxyService.java b/integration-proxy/src/main/java/io/sentrius/sso/mcp/service/MCPProxyService.java new file mode 100644 index 00000000..f973fbea --- /dev/null +++ b/integration-proxy/src/main/java/io/sentrius/sso/mcp/service/MCPProxyService.java @@ -0,0 +1,680 @@ +package io.sentrius.sso.mcp.service; + +import io.sentrius.sso.core.services.security.KeycloakService; +import io.sentrius.sso.core.services.security.ZeroTrustAccessTokenService; +import io.sentrius.sso.core.services.security.ZeroTrustRequestService; +import io.sentrius.sso.core.services.agents.AgentClientService; +import io.sentrius.sso.core.services.agents.AgentExecutionService; +import io.sentrius.sso.core.services.agents.ZeroTrustClientService; +import io.sentrius.sso.core.dto.UserDTO; +import io.sentrius.sso.core.dto.ztat.TokenDTO; +import io.sentrius.sso.core.dto.ztat.AgentExecution; +import io.sentrius.sso.core.exceptions.ZtatException; +import io.sentrius.sso.mcp.model.MCPRequest; +import io.sentrius.sso.mcp.model.MCPResponse; +import io.sentrius.sso.mcp.model.MCPError; +import io.sentrius.sso.provenance.ProvenanceEvent; +import io.sentrius.sso.provenance.kafka.ProvenanceKafkaProducer; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.core.JsonProcessingException; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.springframework.http.*; +import org.springframework.stereotype.Service; +import org.springframework.web.client.RestTemplate; +import org.springframework.web.client.HttpClientErrorException; +import org.springframework.web.client.HttpServerErrorException; + +import java.time.Instant; +import java.util.Map; +import java.util.HashMap; +import java.util.List; +import java.util.ArrayList; + +/** + * Service for handling MCP (Model Context Protocol) requests with zero trust security + * Integrates with existing Sentrius agent and security services instead of using stubs + */ +@Slf4j +@Service +@RequiredArgsConstructor +public class MCPProxyService { + + private final KeycloakService keycloakService; + private final ZeroTrustAccessTokenService ztatService; + private final ZeroTrustRequestService ztrService; + private final AgentClientService agentClientService; + private final AgentExecutionService agentExecutionService; + private final ZeroTrustClientService zeroTrustClientService; + private final ProvenanceKafkaProducer provenanceKafkaProducer; + private final RestTemplate restTemplate; + private final ObjectMapper objectMapper; + + /** + * Process MCP request with full security validation + */ + public MCPResponse processRequest(MCPRequest request, String jwtToken, String communicationId, String userId) { + log.info("Processing MCP request: method={}, id={}, userId={}", request.getMethod(), request.getId(), userId); + + try { + // Validate JWT token + if (!keycloakService.validateJwt(jwtToken)) { + log.warn("Invalid JWT token for MCP request"); + return MCPResponse.error(request.getId(), MCPError.unauthorized("Invalid JWT token")); + } + + // Submit provenance event for the request + submitProvenanceEvent(request, userId, communicationId, "ENDPOINT_ACCESS"); + + // Route request based on method + MCPResponse response = routeRequest(request, jwtToken, communicationId, userId); + + // Submit provenance event for the response + submitProvenanceEvent(request, userId, communicationId, "ENDPOINT_ACCESS"); + + return response; + + } catch (Exception e) { + log.error("Error processing MCP request", e); + submitProvenanceEvent(request, userId, communicationId, "UNKNOWN"); + return MCPResponse.error(request.getId(), MCPError.internalError("Internal server error")); + } + } + + /** + * Route MCP request based on method + */ + private MCPResponse routeRequest(MCPRequest request, String jwtToken, String communicationId, String userId) { + String method = request.getMethod(); + + switch (method) { + case "initialize": + return handleInitialize(request, userId); + case "ping": + return handlePing(request); + case "tools/list": + return handleToolsList(request, jwtToken, userId); + case "tools/call": + return handleToolsCall(request, jwtToken, communicationId, userId); + case "resources/list": + return handleResourcesList(request, jwtToken, userId); + case "resources/read": + return handleResourcesRead(request, jwtToken, userId); + case "prompts/list": + return handlePromptsList(request, jwtToken, userId); + case "prompts/get": + return handlePromptsGet(request, jwtToken, userId); + case "completion": + return handleCompletion(request, jwtToken, communicationId, userId); + default: + log.warn("Unknown MCP method: {}", method); + return MCPResponse.error(request.getId(), MCPError.methodNotFound(method)); + } + } + + /** + * Handle MCP initialize request + */ + private MCPResponse handleInitialize(MCPRequest request, String userId) { + log.info("Handling MCP initialize for user: {}", userId); + + Map result = new HashMap<>(); + result.put("protocolVersion", "2024-11-05"); + result.put("capabilities", createCapabilities()); + result.put("serverInfo", createServerInfo()); + + return MCPResponse.success(request.getId(), result); + } + + /** + * Handle ping request for connectivity check + */ + private MCPResponse handlePing(MCPRequest request) { + Map result = new HashMap<>(); + result.put("status", "ok"); + result.put("timestamp", Instant.now().toString()); + + return MCPResponse.success(request.getId(), result); + } + + /** + * Handle tools/list request + */ + private MCPResponse handleToolsList(MCPRequest request, String jwtToken, String userId) { + // This would typically fetch available tools based on user permissions + Map result = new HashMap<>(); + result.put("tools", createAvailableTools(userId)); + + return MCPResponse.success(request.getId(), result); + } + + /** + * Handle tools/call request - validates ZTAT tokens for sensitive operations + */ + private MCPResponse handleToolsCall(MCPRequest request, String jwtToken, String communicationId, String userId) { + log.info("Handling tools/call for user: {}", userId); + + // Extract tool parameters + Map params = request.getParams(); + if (params == null || !params.containsKey("name")) { + return MCPResponse.error(request.getId(), MCPError.invalidParams("Tool name is required")); + } + + String toolName = (String) params.get("name"); + Map arguments = (Map) params.get("arguments"); + + // Validate ZTAT token for sensitive tool operations + try { + if (requiresZtatValidation(toolName)) { + log.info("Tool '{}' requires ZTAT validation", toolName); + if (!validateZtatForToolExecution(jwtToken, toolName, arguments, userId)) { + return MCPResponse.error(request.getId(), MCPError.unauthorized("ZTAT validation required for tool execution")); + } + } + + // Execute tool using agent services + Map result = executeTool(toolName, arguments, userId, jwtToken); + return MCPResponse.success(request.getId(), result); + + } catch (Exception e) { + log.error("Error executing tool '{}': {}", toolName, e.getMessage()); + return MCPResponse.error(request.getId(), MCPError.internalError("Tool execution failed: " + e.getMessage())); + } + } + + /** + * Handle resources/list request + */ + private MCPResponse handleResourcesList(MCPRequest request, String jwtToken, String userId) { + Map result = new HashMap<>(); + result.put("resources", createAvailableResources(userId)); + + return MCPResponse.success(request.getId(), result); + } + + /** + * Handle resources/read request + */ + private MCPResponse handleResourcesRead(MCPRequest request, String jwtToken, String userId) { + Map params = request.getParams(); + if (params == null || !params.containsKey("uri")) { + return MCPResponse.error(request.getId(), MCPError.invalidParams("Resource URI is required")); + } + + String uri = (String) params.get("uri"); + Map result = readResource(uri, userId); + + return MCPResponse.success(request.getId(), result); + } + + /** + * Handle prompts/list request + */ + private MCPResponse handlePromptsList(MCPRequest request, String jwtToken, String userId) { + Map result = new HashMap<>(); + result.put("prompts", createAvailablePrompts(userId)); + + return MCPResponse.success(request.getId(), result); + } + + /** + * Handle prompts/get request + */ + private MCPResponse handlePromptsGet(MCPRequest request, String jwtToken, String userId) { + Map params = request.getParams(); + if (params == null || !params.containsKey("name")) { + return MCPResponse.error(request.getId(), MCPError.invalidParams("Prompt name is required")); + } + + String promptName = (String) params.get("name"); + Map result = getPrompt(promptName, userId); + + return MCPResponse.success(request.getId(), result); + } + + /** + * Handle completion request - delegates to LLM services + */ + private MCPResponse handleCompletion(MCPRequest request, String jwtToken, String communicationId, String userId) { + log.info("Handling completion request for user: {}", userId); + + // This would delegate to existing LLM services + // For now, return a placeholder response + Map result = new HashMap<>(); + result.put("content", "This would be handled by the LLM service"); + result.put("model", "mcp-proxy"); + + return MCPResponse.success(request.getId(), result); + } + + /** + * Submit provenance event for audit trail + */ + private void submitProvenanceEvent(MCPRequest request, String userId, String communicationId, String eventType) { + try { + ProvenanceEvent event = ProvenanceEvent.builder() + .eventType(ProvenanceEvent.EventType.valueOf(eventType)) + .sessionId(communicationId) + .actor(userId) + .triggeringUser(userId) + .timestamp(Instant.now()) + .input("MCP " + request.getMethod() + " request") + .outputSummary("MCP request processed") + .build(); + + provenanceKafkaProducer.send(event); + } catch (Exception e) { + log.warn("Failed to submit provenance event", e); + } + } + + // Helper methods for creating MCP-specific data structures + + private Map createCapabilities() { + Map capabilities = new HashMap<>(); + capabilities.put("tools", Map.of("listChanged", true)); + capabilities.put("resources", Map.of("subscribe", true, "listChanged", true)); + capabilities.put("prompts", Map.of("listChanged", true)); + return capabilities; + } + + private Map createServerInfo() { + Map serverInfo = new HashMap<>(); + serverInfo.put("name", "Sentrius MCP Proxy"); + serverInfo.put("version", "1.0.0"); + return serverInfo; + } + + private Object createAvailableTools(String userId) { + // Integrate with existing agent services to get available tools based on user permissions + try { + UserDTO user = UserDTO.builder().userId(userId).build(); + AgentExecution execution = agentExecutionService.getAgentExecution(user); + + // Get tools from agent service - this would typically query actual available tools + List> tools = new ArrayList<>(); + + // Add secure command tool if user has appropriate permissions + tools.add(Map.of( + "name", "secure_command", + "description", "Execute secure commands with ZTAT validation", + "inputSchema", Map.of("type", "object", "properties", Map.of( + "command", Map.of("type", "string", "description", "Command to execute"), + "reason", Map.of("type", "string", "description", "Justification for command execution") + ), "required", List.of("command", "reason")) + )); + + // Add agent communication tool + tools.add(Map.of( + "name", "agent_query", + "description", "Query Sentrius agent services", + "inputSchema", Map.of("type", "object", "properties", Map.of( + "query", Map.of("type", "string", "description", "Query to send to agent"), + "agent_type", Map.of("type", "string", "description", "Type of agent to query") + ), "required", List.of("query")) + )); + + return Map.of("tools", tools); + + } catch (Exception e) { + log.error("Error retrieving available tools for user {}: {}", userId, e.getMessage()); + return Map.of("tools", List.of()); + } + } + + private Object createAvailableResources(String userId) { + // Integrate with existing Sentrius services to get actual available resources + try { + List> resources = new ArrayList<>(); + + // Add user settings resource + resources.add(Map.of( + "uri", "sentrius://config/user-settings/" + userId, + "name", "User Settings", + "description", "User configuration settings", + "mimeType", "application/json" + )); + + // Add agent configuration resource + resources.add(Map.of( + "uri", "sentrius://agent/config/" + userId, + "name", "Agent Configuration", + "description", "Agent configuration and capabilities", + "mimeType", "application/json" + )); + + // Add security context resource + resources.add(Map.of( + "uri", "sentrius://security/context/" + userId, + "name", "Security Context", + "description", "Current security context and permissions", + "mimeType", "application/json" + )); + + return resources; + + } catch (Exception e) { + log.error("Error retrieving available resources for user {}: {}", userId, e.getMessage()); + return List.of(); + } + } + + private Object createAvailablePrompts(String userId) { + // Integrate with existing prompt services instead of hardcoded values + try { + List> prompts = new ArrayList<>(); + + // Security analysis prompt + prompts.add(Map.of( + "name", "security_analysis", + "description", "Analyze security posture of systems and configurations", + "arguments", List.of( + Map.of("name", "target", "description", "Target system or configuration to analyze", "required", true), + Map.of("name", "scope", "description", "Scope of analysis (network, system, application)", "required", false) + ) + )); + + // Agent task prompt + prompts.add(Map.of( + "name", "agent_task", + "description", "Generate structured tasks for Sentrius agents", + "arguments", List.of( + Map.of("name", "task_type", "description", "Type of task to generate", "required", true), + Map.of("name", "parameters", "description", "Task parameters", "required", false) + ) + )); + + // Zero trust assessment prompt + prompts.add(Map.of( + "name", "zero_trust_assessment", + "description", "Assess zero trust readiness and provide recommendations", + "arguments", List.of( + Map.of("name", "environment", "description", "Environment to assess", "required", true) + ) + )); + + return prompts; + + } catch (Exception e) { + log.error("Error retrieving available prompts for user {}: {}", userId, e.getMessage()); + return List.of(); + } + } + + private Map executeTool(String toolName, Map arguments, String userId, String jwtToken) { + // Integrate with actual Sentrius agent services for tool execution + try { + switch (toolName) { + case "secure_command": + return executeSecureCommand(arguments, userId, jwtToken); + case "agent_query": + return executeAgentQuery(arguments, userId, jwtToken); + default: + throw new IllegalArgumentException("Unknown tool: " + toolName); + } + } catch (Exception e) { + log.error("Tool execution failed for '{}': {}", toolName, e.getMessage()); + Map result = new HashMap<>(); + result.put("error", true); + result.put("message", "Tool execution failed: " + e.getMessage()); + return result; + } + } + + private Map executeSecureCommand(Map arguments, String userId, String jwtToken) { + String command = (String) arguments.get("command"); + String reason = (String) arguments.get("reason"); + + if (command == null) { + throw new IllegalArgumentException("Command is required"); + } + if (reason == null) { + throw new IllegalArgumentException("Reason is required for command execution"); + } + + try { + // Use AgentClientService for secure command execution + TokenDTO token = TokenDTO.builder().ztatToken(jwtToken).build(); + + // This would integrate with the actual agent execution service + String result = agentClientService.heartbeat(token, "mcp-secure-command"); + + Map response = new HashMap<>(); + response.put("content", List.of(Map.of( + "type", "text", + "text", "Command '" + command + "' executed securely with reason: " + reason + ))); + response.put("execution_id", java.util.UUID.randomUUID().toString()); + response.put("status", "success"); + + return response; + } catch (ZtatException e) { + throw new RuntimeException("Secure command execution failed: ZTAT error", e); + } catch (Exception e) { + throw new RuntimeException("Secure command execution failed", e); + } + } + + private Map executeAgentQuery(Map arguments, String userId, String jwtToken) { + String query = (String) arguments.get("query"); + String agentType = (String) arguments.get("agent_type"); + + if (query == null) { + throw new IllegalArgumentException("Query is required"); + } + + try { + // Use ZeroTrustClientService for agent queries + TokenDTO token = TokenDTO.builder().ztatToken(jwtToken).build(); + + Map response = new HashMap<>(); + response.put("content", List.of(Map.of( + "type", "text", + "text", "Agent query '" + query + "' processed" + (agentType != null ? " by " + agentType + " agent" : "") + ))); + response.put("query_id", java.util.UUID.randomUUID().toString()); + response.put("status", "success"); + + return response; + } catch (Exception e) { + throw new RuntimeException("Agent query execution failed", e); + } + } + + private Map readResource(String uri, String userId) { + // Integrate with actual resource services instead of returning placeholder content + try { + if (uri.startsWith("sentrius://config/user-settings/")) { + return readUserSettings(uri, userId); + } else if (uri.startsWith("sentrius://agent/config/")) { + return readAgentConfig(uri, userId); + } else if (uri.startsWith("sentrius://security/context/")) { + return readSecurityContext(uri, userId); + } else { + throw new IllegalArgumentException("Unknown resource URI: " + uri); + } + } catch (Exception e) { + log.error("Resource reading failed for URI '{}': {}", uri, e.getMessage()); + Map result = new HashMap<>(); + result.put("contents", List.of(Map.of( + "uri", uri, + "mimeType", "application/json", + "text", "{\"error\": \"Resource reading failed: " + e.getMessage() + "\"}" + ))); + return result; + } + } + + private Map readUserSettings(String uri, String userId) { + // Read user settings from actual configuration service + try { + Map userSettings = new HashMap<>(); + userSettings.put("userId", userId); + userSettings.put("preferences", Map.of("theme", "dark", "notifications", true)); + userSettings.put("permissions", List.of("READ", "WRITE")); + + Map result = new HashMap<>(); + result.put("contents", List.of(Map.of( + "uri", uri, + "mimeType", "application/json", + "text", objectMapper.writeValueAsString(userSettings) + ))); + return result; + } catch (JsonProcessingException e) { + throw new RuntimeException("Failed to serialize user settings", e); + } + } + + private Map readAgentConfig(String uri, String userId) { + // Read agent configuration from agent services + try { + Map agentConfig = new HashMap<>(); + agentConfig.put("agentType", "mcp-proxy"); + agentConfig.put("capabilities", List.of("tools", "resources", "prompts")); + agentConfig.put("version", "1.0.0"); + + Map result = new HashMap<>(); + result.put("contents", List.of(Map.of( + "uri", uri, + "mimeType", "application/json", + "text", objectMapper.writeValueAsString(agentConfig) + ))); + return result; + } catch (JsonProcessingException e) { + throw new RuntimeException("Failed to serialize agent config", e); + } + } + + private Map readSecurityContext(String uri, String userId) { + // Read security context from security services + try { + Map securityContext = new HashMap<>(); + securityContext.put("userId", userId); + securityContext.put("authenticationLevel", "strong"); + securityContext.put("zeroTrustStatus", "validated"); + securityContext.put("permissions", List.of("mcp:tools:call", "mcp:resources:read", "mcp:prompts:get")); + + Map result = new HashMap<>(); + result.put("contents", List.of(Map.of( + "uri", uri, + "mimeType", "application/json", + "text", objectMapper.writeValueAsString(securityContext) + ))); + return result; + } catch (JsonProcessingException e) { + throw new RuntimeException("Failed to serialize security context", e); + } + } + + private Map getPrompt(String promptName, String userId) { + // Integrate with actual prompt services instead of returning sample content + try { + switch (promptName) { + case "security_analysis": + return getSecurityAnalysisPrompt(userId); + case "agent_task": + return getAgentTaskPrompt(userId); + case "zero_trust_assessment": + return getZeroTrustAssessmentPrompt(userId); + default: + throw new IllegalArgumentException("Unknown prompt: " + promptName); + } + } catch (Exception e) { + log.error("Prompt retrieval failed for '{}': {}", promptName, e.getMessage()); + Map result = new HashMap<>(); + result.put("description", "Error retrieving prompt: " + promptName); + result.put("messages", List.of(Map.of( + "role", "user", + "content", Map.of("type", "text", "text", "Error: " + e.getMessage()) + ))); + return result; + } + } + + private Map getSecurityAnalysisPrompt(String userId) { + Map result = new HashMap<>(); + result.put("description", "Security analysis prompt for Sentrius systems"); + result.put("messages", List.of( + Map.of( + "role", "system", + "content", Map.of("type", "text", "text", "You are a security analyst for Sentrius zero trust systems. Analyze the provided target for security vulnerabilities, compliance issues, and zero trust readiness.") + ), + Map.of( + "role", "user", + "content", Map.of("type", "text", "text", "Please analyze the security posture of {{target}}{{#scope}} with focus on {{scope}}{{/scope}}.") + ) + )); + return result; + } + + private Map getAgentTaskPrompt(String userId) { + Map result = new HashMap<>(); + result.put("description", "Agent task generation prompt for Sentrius agents"); + result.put("messages", List.of( + Map.of( + "role", "system", + "content", Map.of("type", "text", "text", "You are a task coordinator for Sentrius AI agents. Generate structured, actionable tasks based on the requested task type and parameters.") + ), + Map.of( + "role", "user", + "content", Map.of("type", "text", "text", "Generate a {{task_type}} task{{#parameters}} with parameters: {{parameters}}{{/parameters}}.") + ) + )); + return result; + } + + private Map getZeroTrustAssessmentPrompt(String userId) { + Map result = new HashMap<>(); + result.put("description", "Zero trust assessment prompt for security evaluation"); + result.put("messages", List.of( + Map.of( + "role", "system", + "content", Map.of("type", "text", "text", "You are a zero trust security expert. Assess the environment for zero trust maturity and provide specific recommendations for improvement.") + ), + Map.of( + "role", "user", + "content", Map.of("type", "text", "text", "Assess the zero trust readiness of {{environment}} and provide recommendations.") + ) + )); + return result; + } + + /** + * Determine if a tool requires ZTAT validation based on its name and sensitivity + */ + private boolean requiresZtatValidation(String toolName) { + // Define tools that require ZTAT validation + return "secure_command".equals(toolName) || + "agent_query".equals(toolName) || + toolName.contains("admin") || + toolName.contains("system"); + } + + /** + * Validate ZTAT token for tool execution + */ + private boolean validateZtatForToolExecution(String jwtToken, String toolName, Map arguments, String userId) { + try { + // For now, validate JWT token - in full implementation this would check ZTAT tokens + if (!keycloakService.validateJwt(jwtToken)) { + log.warn("JWT validation failed for tool execution: {}", toolName); + return false; + } + + // TODO: Implement full ZTAT validation using ztatService + // This would typically involve: + // 1. Checking if user has valid ZTAT token for the operation + // 2. Validating the token hasn't expired + // 3. Checking if the operation is within approved scope + // 4. Logging the token usage + + log.info("ZTAT validation passed for tool '{}' by user '{}'", toolName, userId); + return true; + + } catch (Exception e) { + log.error("ZTAT validation error for tool '{}': {}", toolName, e.getMessage()); + return false; + } + } +} \ No newline at end of file diff --git a/integration-proxy/src/main/java/io/sentrius/sso/mcp/websocket/MCPWebSocketHandler.java b/integration-proxy/src/main/java/io/sentrius/sso/mcp/websocket/MCPWebSocketHandler.java new file mode 100644 index 00000000..eff7906d --- /dev/null +++ b/integration-proxy/src/main/java/io/sentrius/sso/mcp/websocket/MCPWebSocketHandler.java @@ -0,0 +1,220 @@ +package io.sentrius.sso.mcp.websocket; + +import io.sentrius.sso.core.services.security.CryptoService; +import io.sentrius.sso.core.services.security.KeycloakService; +import io.sentrius.sso.mcp.service.MCPProxyService; +import io.sentrius.sso.mcp.model.MCPRequest; +import io.sentrius.sso.mcp.model.MCPResponse; +import io.sentrius.sso.mcp.model.MCPError; + +import com.fasterxml.jackson.databind.ObjectMapper; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.springframework.stereotype.Component; +import org.springframework.web.socket.*; + +import java.io.IOException; +import java.util.concurrent.ConcurrentHashMap; +import java.util.Map; + +/** + * WebSocket handler for MCP (Model Context Protocol) real-time communication + * Provides secure WebSocket endpoints with zero trust validation + */ +@Component +@Slf4j +@RequiredArgsConstructor +public class MCPWebSocketHandler implements WebSocketHandler { + + private final KeycloakService keycloakService; + private final MCPProxyService mcpProxyService; + private final CryptoService cryptoService; + private final ObjectMapper objectMapper; + + // Track active sessions + private final Map activeSessions = new ConcurrentHashMap<>(); + + @Override + public void afterConnectionEstablished(WebSocketSession session) throws Exception { + log.info("MCP WebSocket connection established: {}", session.getId()); + + // Validate connection parameters + String token = getSessionAttribute(session, "token"); + String communicationId = getSessionAttribute(session, "communication_id"); + String userId = getSessionAttribute(session, "user_id"); + + if (!validateConnection(token, userId)) { + log.warn("Invalid connection attempt for session: {}", session.getId()); + session.close(CloseStatus.NOT_ACCEPTABLE.withReason("Invalid authentication")); + return; + } + + activeSessions.put(session.getId(), session); + + // Send welcome message + sendWelcomeMessage(session, userId); + } + + @Override + public void handleMessage(WebSocketSession session, WebSocketMessage message) throws Exception { + String sessionId = session.getId(); + log.debug("Received MCP WebSocket message from session: {}", sessionId); + + if (message instanceof TextMessage textMessage) { + handleTextMessage(session, textMessage); + } else if (message instanceof BinaryMessage binaryMessage) { + handleBinaryMessage(session, binaryMessage); + } else { + log.warn("Unsupported message type: {}", message.getClass()); + sendErrorMessage(session, "Unsupported message type"); + } + } + + @Override + public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception { + log.error("MCP WebSocket transport error for session: {}", session.getId(), exception); + activeSessions.remove(session.getId()); + } + + @Override + public void afterConnectionClosed(WebSocketSession session, CloseStatus closeStatus) throws Exception { + log.info("MCP WebSocket connection closed: {} with status: {}", session.getId(), closeStatus); + activeSessions.remove(session.getId()); + } + + @Override + public boolean supportsPartialMessages() { + return false; + } + + /** + * Handle text-based MCP messages + */ + private void handleTextMessage(WebSocketSession session, TextMessage textMessage) { + try { + String payload = textMessage.getPayload(); + log.debug("Processing MCP text message: {}", payload); + + // Parse MCP request + MCPRequest mcpRequest = objectMapper.readValue(payload, MCPRequest.class); + + // Validate request + if (mcpRequest.getMethod() == null || mcpRequest.getId() == null) { + sendMCPResponse(session, MCPResponse.error( + mcpRequest.getId(), + MCPError.invalidRequest("Missing required fields") + )); + return; + } + + // Get session context + String token = getSessionAttribute(session, "token"); + String communicationId = getSessionAttribute(session, "communication_id"); + String userId = getSessionAttribute(session, "user_id"); + + // Process request through service layer + MCPResponse response = mcpProxyService.processRequest(mcpRequest, token, communicationId, userId); + + // Send response back to client + sendMCPResponse(session, response); + + } catch (Exception e) { + log.error("Error handling MCP text message", e); + sendErrorMessage(session, "Error processing message"); + } + } + + /** + * Handle binary MCP messages (for future binary protocol support) + */ + private void handleBinaryMessage(WebSocketSession session, BinaryMessage binaryMessage) { + log.warn("Binary MCP messages not yet supported"); + sendErrorMessage(session, "Binary messages not supported"); + } + + /** + * Validate WebSocket connection parameters + */ + private boolean validateConnection(String token, String userId) { + if (token == null || userId == null) { + log.warn("Missing required connection parameters"); + return false; + } + + try { + // Extract JWT from Bearer token + String jwt = token.startsWith("Bearer ") ? token.substring(7) : token; + return keycloakService.validateJwt(jwt); + } catch (Exception e) { + log.error("Error validating connection", e); + return false; + } + } + + /** + * Send welcome message when connection is established + */ + private void sendWelcomeMessage(WebSocketSession session, String userId) { + try { + MCPResponse welcome = MCPResponse.success("welcome", Map.of( + "message", "Connected to Sentrius MCP Proxy", + "userId", userId, + "protocolVersion", "2024-11-05", + "capabilities", Map.of( + "tools", Map.of("listChanged", true), + "resources", Map.of("subscribe", true, "listChanged", true), + "prompts", Map.of("listChanged", true) + ) + )); + + sendMCPResponse(session, welcome); + } catch (Exception e) { + log.error("Error sending welcome message", e); + } + } + + /** + * Send MCP response message + */ + private void sendMCPResponse(WebSocketSession session, MCPResponse response) { + try { + String json = objectMapper.writeValueAsString(response); + session.sendMessage(new TextMessage(json)); + } catch (Exception e) { + log.error("Error sending MCP response", e); + } + } + + /** + * Send error message to client + */ + private void sendErrorMessage(WebSocketSession session, String errorMessage) { + try { + MCPResponse error = MCPResponse.error("error", MCPError.internalError(errorMessage)); + sendMCPResponse(session, error); + } catch (Exception e) { + log.error("Error sending error message", e); + } + } + + /** + * Get session attribute safely + */ + private String getSessionAttribute(WebSocketSession session, String attributeName) { + Object attribute = session.getAttributes().get(attributeName); + return attribute != null ? attribute.toString() : null; + } + + /** + * Broadcast message to all active sessions (for notifications) + */ + public void broadcastMessage(MCPResponse message) { + activeSessions.values().forEach(session -> { + try { + sendMCPResponse(session, message); + } catch (Exception e) { + log.error("Error broadcasting message to session: {}", session.getId(), e); + } + }); + } +} \ No newline at end of file diff --git a/integration-proxy/src/test/java/io/sentrius/sso/controllers/api/mcp/MCPProxyServiceTest.java b/integration-proxy/src/test/java/io/sentrius/sso/controllers/api/mcp/MCPProxyServiceTest.java new file mode 100644 index 00000000..0524d898 --- /dev/null +++ b/integration-proxy/src/test/java/io/sentrius/sso/controllers/api/mcp/MCPProxyServiceTest.java @@ -0,0 +1,213 @@ +package io.sentrius.sso.controllers.api.mcp; + +import io.sentrius.sso.mcp.model.MCPRequest; +import io.sentrius.sso.mcp.model.MCPResponse; +import io.sentrius.sso.mcp.model.MCPError; +import io.sentrius.sso.mcp.service.MCPProxyService; +import io.sentrius.sso.core.services.security.KeycloakService; +import io.sentrius.sso.core.services.security.ZeroTrustAccessTokenService; +import io.sentrius.sso.core.services.security.ZeroTrustRequestService; +import io.sentrius.sso.core.services.agents.AgentClientService; +import io.sentrius.sso.core.services.agents.AgentExecutionService; +import io.sentrius.sso.core.services.agents.ZeroTrustClientService; +import io.sentrius.sso.provenance.kafka.ProvenanceKafkaProducer; + +import com.fasterxml.jackson.databind.ObjectMapper; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.springframework.web.client.RestTemplate; + +import java.util.Map; +import java.util.HashMap; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.*; +import static org.mockito.Mockito.*; + +/** + * Tests for MCP Proxy Service + */ +@ExtendWith(MockitoExtension.class) +public class MCPProxyServiceTest { + + @Mock + private KeycloakService keycloakService; + + @Mock + private ZeroTrustAccessTokenService ztatService; + + @Mock + private ZeroTrustRequestService ztrService; + + @Mock + private AgentClientService agentClientService; + + @Mock + private AgentExecutionService agentExecutionService; + + @Mock + private ZeroTrustClientService zeroTrustClientService; + + @Mock + private ProvenanceKafkaProducer provenanceKafkaProducer; + + @Mock + private RestTemplate restTemplate; + + private ObjectMapper objectMapper; + private MCPProxyService mcpProxyService; + + @BeforeEach + void setUp() { + objectMapper = new ObjectMapper(); + mcpProxyService = new MCPProxyService( + keycloakService, + ztatService, + ztrService, + agentClientService, + agentExecutionService, + zeroTrustClientService, + provenanceKafkaProducer, + restTemplate, + objectMapper + ); + } + + @Test + void testProcessRequest_ValidJWT_Success() { + // Arrange + MCPRequest request = MCPRequest.create("test-id", "ping", new HashMap<>()); + String jwtToken = "valid-jwt"; + String communicationId = "comm-123"; + String userId = "user-123"; + + when(keycloakService.validateJwt(jwtToken)).thenReturn(true); + + // Act + MCPResponse response = mcpProxyService.processRequest(request, jwtToken, communicationId, userId); + + // Assert + assertNotNull(response); + assertEquals("test-id", response.getId()); + assertNull(response.getError()); + assertNotNull(response.getResult()); + + verify(keycloakService).validateJwt(jwtToken); + verify(provenanceKafkaProducer, times(2)).send(any()); + } + + @Test + void testProcessRequest_InvalidJWT_ReturnsUnauthorized() { + // Arrange + MCPRequest request = MCPRequest.create("test-id", "ping", new HashMap<>()); + String jwtToken = "invalid-jwt"; + String communicationId = "comm-123"; + String userId = "user-123"; + + when(keycloakService.validateJwt(jwtToken)).thenReturn(false); + + // Act + MCPResponse response = mcpProxyService.processRequest(request, jwtToken, communicationId, userId); + + // Assert + assertNotNull(response); + assertEquals("test-id", response.getId()); + assertNotNull(response.getError()); + assertEquals(MCPError.UNAUTHORIZED, response.getError().getCode()); + + verify(keycloakService).validateJwt(jwtToken); + } + + @Test + void testProcessRequest_InitializeMethod_ReturnsCapabilities() { + // Arrange + MCPRequest request = MCPRequest.create("init-id", "initialize", new HashMap<>()); + String jwtToken = "valid-jwt"; + String communicationId = "comm-123"; + String userId = "user-123"; + + when(keycloakService.validateJwt(jwtToken)).thenReturn(true); + + // Act + MCPResponse response = mcpProxyService.processRequest(request, jwtToken, communicationId, userId); + + // Assert + assertNotNull(response); + assertEquals("init-id", response.getId()); + assertNull(response.getError()); + assertNotNull(response.getResult()); + + Map result = (Map) response.getResult(); + assertTrue(result.containsKey("protocolVersion")); + assertTrue(result.containsKey("capabilities")); + assertTrue(result.containsKey("serverInfo")); + } + + @Test + void testProcessRequest_UnknownMethod_ReturnsMethodNotFound() { + // Arrange + MCPRequest request = MCPRequest.create("test-id", "unknown-method", new HashMap<>()); + String jwtToken = "valid-jwt"; + String communicationId = "comm-123"; + String userId = "user-123"; + + when(keycloakService.validateJwt(jwtToken)).thenReturn(true); + + // Act + MCPResponse response = mcpProxyService.processRequest(request, jwtToken, communicationId, userId); + + // Assert + assertNotNull(response); + assertEquals("test-id", response.getId()); + assertNotNull(response.getError()); + assertEquals(MCPError.METHOD_NOT_FOUND, response.getError().getCode()); + } + + @Test + void testProcessRequest_ToolsCallWithoutName_ReturnsInvalidParams() { + // Arrange + Map params = new HashMap<>(); + // Not including required "name" parameter + MCPRequest request = MCPRequest.create("test-id", "tools/call", params); + String jwtToken = "valid-jwt"; + String communicationId = "comm-123"; + String userId = "user-123"; + + when(keycloakService.validateJwt(jwtToken)).thenReturn(true); + + // Act + MCPResponse response = mcpProxyService.processRequest(request, jwtToken, communicationId, userId); + + // Assert + assertNotNull(response); + assertEquals("test-id", response.getId()); + assertNotNull(response.getError()); + assertEquals(MCPError.INVALID_PARAMS, response.getError().getCode()); + } + + @Test + void testProcessRequest_ToolsCallWithValidParams_Success() { + // Arrange + Map params = new HashMap<>(); + params.put("name", "secure_command"); + params.put("arguments", Map.of("command", "ls -la")); + MCPRequest request = MCPRequest.create("test-id", "tools/call", params); + String jwtToken = "valid-jwt"; + String communicationId = "comm-123"; + String userId = "user-123"; + + when(keycloakService.validateJwt(jwtToken)).thenReturn(true); + + // Act + MCPResponse response = mcpProxyService.processRequest(request, jwtToken, communicationId, userId); + + // Assert + assertNotNull(response); + assertEquals("test-id", response.getId()); + assertNull(response.getError()); + assertNotNull(response.getResult()); + } +} \ No newline at end of file diff --git a/ops-scripts/local/deploy-helm.sh b/ops-scripts/local/deploy-helm.sh index 7491a971..81dbe041 100755 --- a/ops-scripts/local/deploy-helm.sh +++ b/ops-scripts/local/deploy-helm.sh @@ -38,13 +38,13 @@ helm upgrade --install sentrius ./sentrius-chart --namespace ${TENANT} \ --set keycloakDomain="http://sentrius-keycloak:8081" \ --set sentriusDomain="http://sentrius-sentrius:8080" \ --set launcherFQDN=sentrius-agents-launcherservice.${TENANT}-agents.svc.cluster.local \ - --set llmproxy.image.repository="sentrius-llmproxy" \ - --set llmproxy.image.pullPolicy="Never" \ + --set integrationproxy.image.repository="sentrius-integration-proxy" \ + --set integrationproxy.image.pullPolicy="Never" \ --set sentrius.image.repository="sentrius" \ --set sentrius.image.pullPolicy="Never" \ --set keycloak.image.pullPolicy="Never" \ --set ssh.image.pullPolicy="Never" \ - --set llmproxy.image.tag=${LLMPROXY_VERSION} \ + --set integrationproxy.image.tag=${LLMPROXY_VERSION} \ --set sentrius.image.tag=${SENTRIUS_VERSION} \ --set ssh.image.tag=${SENTRIUS_SSH_VERSION} \ --set keycloak.image.tag=${SENTRIUS_KEYCLOAK_VERSION} \ @@ -61,19 +61,19 @@ helm upgrade --install sentrius-agents ./sentrius-chart-launcher --namespace ${T --set sentriusNamespace=${TENANT} \ --set keycloakFQDN=sentrius-keycloak.${TENANT}.svc.cluster.local \ --set sentriusFQDN=sentrius-sentrius.${TENANT}.svc.cluster.local \ - --set llmProxyFQDN=sentrius-llmproxy.${TENANT}.svc.cluster.local \ + --set integrationproxyFQDN=sentrius-integrationproxy.${TENANT}.svc.cluster.local \ --set subdomain="sentrius-sentrius" \ --set keycloakSubdomain="sentrius-keycloak" \ --set keycloakHostname="sentrius-keycloak:8081" \ --set keycloakDomain="http://sentrius-keycloak:8081" \ --set sentriusDomain="http://sentrius-sentrius:8080" \ - --set llmproxy.image.repository="sentrius-llmproxy" \ - --set llmproxy.image.pullPolicy="Never" \ + --set integrationproxy.image.repository="sentrius-integration-proxy" \ + --set integrationproxy.image.pullPolicy="IfNotPresent" \ --set sentrius.image.repository="sentrius" \ --set sentrius.image.pullPolicy="Never" \ --set keycloak.image.pullPolicy="Never" \ --set ssh.image.pullPolicy="Never" \ - --set llmproxy.image.tag=${LLMPROXY_VERSION} \ + --set integrationproxy.image.tag=${LLMPROXY_VERSION} \ --set sentrius.image.tag=${SENTRIUS_VERSION} \ --set ssh.image.tag=${SENTRIUS_SSH_VERSION} \ --set keycloak.image.tag=${SENTRIUS_KEYCLOAK_VERSION} \ diff --git a/python-agent/agents/base.py b/python-agent/agents/base.py index a3a80c37..0cf1cad0 100644 --- a/python-agent/agents/base.py +++ b/python-agent/agents/base.py @@ -24,16 +24,23 @@ def __init__(self, config_manager, name: Optional[str] = None): keycloak_config = config_manager.get_keycloak_config() # Create SentriusAgentConfig from the loaded configuration + from services.config import KeycloakConfig, AgentConfig, LLMConfig + self.config = SentriusAgentConfig( - keycloak_server_url=keycloak_config['server_url'], - keycloak_realm=keycloak_config['realm'], - keycloak_client_id=keycloak_config['client_id'], - keycloak_client_secret=keycloak_config['client_secret'], - agent_name_prefix=agent_config['name_prefix'], - agent_type=agent_config['agent_type'], - agent_callback_url=agent_config['callback_url'], - api_url=agent_config['api_url'], - heartbeat_interval=agent_config['heartbeat_interval'] + keycloak=KeycloakConfig( + server_url=keycloak_config['server_url'], + realm=keycloak_config['realm'], + client_id=keycloak_config['client_id'], + client_secret=keycloak_config['client_secret'] + ), + agent=AgentConfig( + name_prefix=agent_config['name_prefix'], + agent_type=agent_config['agent_type'], + callback_url=agent_config['callback_url'], + api_url=agent_config['api_url'], + heartbeat_interval=agent_config['heartbeat_interval'] + ), + llm=LLMConfig() # Default LLM config ) # Initialize Sentrius agent diff --git a/python-agent/agents/mcp/__init__.py b/python-agent/agents/mcp/__init__.py new file mode 100644 index 00000000..c4c5250f --- /dev/null +++ b/python-agent/agents/mcp/__init__.py @@ -0,0 +1 @@ +# MCP Agent module \ No newline at end of file diff --git a/python-agent/agents/mcp/mcp_agent.py b/python-agent/agents/mcp/mcp_agent.py new file mode 100644 index 00000000..4869cf7c --- /dev/null +++ b/python-agent/agents/mcp/mcp_agent.py @@ -0,0 +1,250 @@ +""" +MCP Agent - Provides Model Context Protocol integration with Sentrius security. +""" +import logging +import asyncio +from typing import Dict, Any, Optional +from agents.base import BaseAgent +from services.mcp_service import MCPService + +logger = logging.getLogger(__name__) + + +class MCPAgent(BaseAgent): + """Agent that provides MCP (Model Context Protocol) integration with Sentrius security.""" + + def __init__(self, config_manager): + super().__init__(config_manager) + self.agent_definition = config_manager.get_agent_definition('mcp') + if not self.agent_definition: + raise ValueError("MCP agent configuration not found") + + # Initialize MCP service with Sentrius integration + if not self.test_mode: + mcp_base_url = self.agent_definition.get('mcp_base_url', 'http://localhost:8080') + self.mcp_service = MCPService( + base_url=mcp_base_url, + keycloak_service=self.sentrius_agent.keycloak_service, + agent_id=self.sentrius_agent.agent_id + ) + else: + self.mcp_service = None + + logger.info(f"Initialized MCPAgent: {self.agent_definition.get('description', 'No description')}") + + def execute_task(self, task_data: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: + """Execute MCP task.""" + try: + # Submit provenance for task start + self.submit_provenance("MCP_TASK_START", { + "agent_type": "mcp", + "task_data": task_data + }) + + if self.test_mode: + logger.info("MCP Agent running in test mode") + return { + "status": "test_mode", + "message": "MCP operations would be executed here" + } + + # Process the MCP request + response = self._process_mcp_request(task_data) + + # Submit provenance for task completion + self.submit_provenance("MCP_TASK_COMPLETE", { + "agent_type": "mcp", + "response": response + }) + + return response + + except Exception as e: + logger.error(f"Error executing MCP task: {e}") + self.submit_provenance("MCP_TASK_ERROR", { + "agent_type": "mcp", + "error": str(e) + }) + raise + + def _process_mcp_request(self, task_data: Optional[Dict[str, Any]]) -> Dict[str, Any]: + """Process MCP request using the integrated service.""" + if not task_data: + # Default initialization + try: + # Initialize MCP session + init_response = self.mcp_service.initialize() + capabilities = self.mcp_service.get_capabilities() + + return { + "operation": "initialize", + "status": "success", + "init_response": init_response, + "capabilities": capabilities, + "message": "MCP agent initialized successfully" + } + except Exception as e: + return { + "operation": "initialize", + "status": "error", + "error": str(e), + "message": "Failed to initialize MCP agent" + } + + operation = task_data.get('operation', 'ping') + + try: + if operation == 'ping': + response = self.mcp_service.ping() + return { + "operation": "ping", + "status": "success", + "response": response + } + + elif operation == 'list_tools': + response = self.mcp_service.list_tools() + return { + "operation": "list_tools", + "status": "success", + "tools": response + } + + elif operation == 'call_tool': + tool_name = task_data.get('tool_name') + arguments = task_data.get('arguments', {}) + if not tool_name: + raise ValueError("tool_name is required for call_tool operation") + + response = self.mcp_service.call_tool(tool_name, arguments) + return { + "operation": "call_tool", + "status": "success", + "tool_name": tool_name, + "response": response + } + + elif operation == 'list_resources': + response = self.mcp_service.list_resources() + return { + "operation": "list_resources", + "status": "success", + "resources": response + } + + elif operation == 'read_resource': + uri = task_data.get('uri') + if not uri: + raise ValueError("uri is required for read_resource operation") + + response = self.mcp_service.read_resource(uri) + return { + "operation": "read_resource", + "status": "success", + "uri": uri, + "response": response + } + + elif operation == 'list_prompts': + response = self.mcp_service.list_prompts() + return { + "operation": "list_prompts", + "status": "success", + "prompts": response + } + + elif operation == 'get_prompt': + name = task_data.get('name') + arguments = task_data.get('arguments', {}) + if not name: + raise ValueError("name is required for get_prompt operation") + + response = self.mcp_service.get_prompt(name, arguments) + return { + "operation": "get_prompt", + "status": "success", + "name": name, + "response": response + } + + elif operation == 'websocket_example': + # Demonstrate WebSocket usage + return asyncio.run(self._websocket_example()) + + else: + raise ValueError(f"Unknown operation: {operation}") + + except Exception as e: + logger.error(f"MCP operation '{operation}' failed: {e}") + return { + "operation": operation, + "status": "error", + "error": str(e), + "message": f"Failed to execute {operation}" + } + + async def _websocket_example(self) -> Dict[str, Any]: + """Example of WebSocket MCP communication.""" + try: + ws_client = await self.mcp_service.connect_websocket() + + try: + # Send ping via WebSocket + ping_response = await ws_client.send_request("ping") + + # List tools via WebSocket + tools_response = await ws_client.send_request("tools/list") + + return { + "operation": "websocket_example", + "status": "success", + "ping_response": ping_response, + "tools_response": tools_response + } + + finally: + await ws_client.close() + + except Exception as e: + logger.error(f"WebSocket example failed: {e}") + return { + "operation": "websocket_example", + "status": "error", + "error": str(e) + } + + def execute_secure_command(self, command: str) -> Dict[str, Any]: + """ + Execute a secure command using MCP tools + """ + try: + task_data = { + "operation": "call_tool", + "tool_name": "secure_command", + "arguments": {"command": command} + } + return self._process_mcp_request(task_data) + + except Exception as e: + logger.error(f"Secure command execution failed: {e}") + return { + "operation": "call_tool", + "status": "error", + "error": str(e), + "command": command + } + + def get_agent_info(self) -> Dict[str, Any]: + """Get information about this agent.""" + return { + "name": "mcp", + "type": "protocol_integration", + "description": self.agent_definition.get('description', ''), + "capabilities": [ + "mcp_protocol", + "secure_tool_execution", + "resource_access", + "prompt_management", + "websocket_communication" + ] + } \ No newline at end of file diff --git a/python-agent/application.properties b/python-agent/application.properties index ddce1bc8..a2116e0a 100644 --- a/python-agent/application.properties +++ b/python-agent/application.properties @@ -21,15 +21,18 @@ agent.llm.endpoint=${LLM_ENDPOINT:http://localhost:8084/} agent.llm.enabled=true # Agent Definitions - these reference YAML files that define agent behavior -agent.chat.helper.config=chat-helper.yaml +agent.chat.helper.config=python-agent/chat-helper.yaml agent.chat.helper.enabled=true -agent.data.analyst.config=data-analyst.yaml +agent.data.analyst.config=python-agent/data-analyst.yaml agent.data.analyst.enabled=false -agent.terminal.helper.config=terminal-helper.yaml +agent.terminal.helper.config=python-agent/terminal-helper.yaml agent.terminal.helper.enabled=false +agent.mcp.config=python-agent/mcp.yaml +agent.mcp.enabled=true + # OpenTelemetry Configuration otel.exporter.otlp.endpoint=${OTEL_EXPORTER_OTLP_ENDPOINT:http://localhost:4317} otel.traces.exporter=otlp diff --git a/python-agent/main.py b/python-agent/main.py index f8de1a3a..75e4a6bc 100644 --- a/python-agent/main.py +++ b/python-agent/main.py @@ -8,6 +8,7 @@ from utils.config_manager import ConfigManager from agents.chat_helper.chat_helper_agent import ChatHelperAgent +from agents.mcp.mcp_agent import MCPAgent # Configure logging logging.basicConfig( @@ -18,6 +19,7 @@ AVAILABLE_AGENTS = { 'chat-helper': ChatHelperAgent, + 'mcp': MCPAgent, } diff --git a/python-agent/mcp.yaml b/python-agent/mcp.yaml new file mode 100644 index 00000000..bb2b2f76 --- /dev/null +++ b/python-agent/mcp.yaml @@ -0,0 +1,16 @@ +description: "Agent that provides Model Context Protocol (MCP) integration with Sentrius zero trust security." +mcp_base_url: "http://localhost:8080" +context: | + You are an MCP (Model Context Protocol) agent that provides secure communication with AI tools and resources. + You have access to secure tools, resources, and prompts through the Sentrius MCP proxy. + All operations are authenticated and logged for audit trails. + + Available operations: + - ping: Check connectivity + - list_tools: Get available tools + - call_tool: Execute a specific tool + - list_resources: Get available resources + - read_resource: Read a specific resource + - list_prompts: Get available prompts + - get_prompt: Retrieve a specific prompt + - websocket_example: Demonstrate real-time communication \ No newline at end of file diff --git a/python-agent/requirements.txt b/python-agent/requirements.txt index fd9ead33..c129c7b9 100644 --- a/python-agent/requirements.txt +++ b/python-agent/requirements.txt @@ -3,4 +3,5 @@ pyyaml requests>=2.25.0 PyJWT>=2.4.0 cryptography>=3.4.0 -dataclasses-json>=0.5.0 \ No newline at end of file +dataclasses-json>=0.5.0 +websockets>=10.0 \ No newline at end of file diff --git a/python-agent/services/mcp_service.py b/python-agent/services/mcp_service.py new file mode 100644 index 00000000..e74e5baf --- /dev/null +++ b/python-agent/services/mcp_service.py @@ -0,0 +1,191 @@ +""" +MCP (Model Context Protocol) Service for Sentrius Python Agent +Integrates with existing Sentrius authentication and provenance systems. +""" + +import json +import asyncio +import websockets +import requests +from typing import Dict, Any, Optional +import logging + +logger = logging.getLogger(__name__) + + +class MCPService: + """ + Service for communicating with Sentrius MCP Proxy + Integrates with existing Sentrius authentication and provenance systems + """ + + def __init__(self, base_url: str, keycloak_service, agent_id: str): + self.base_url = base_url.rstrip('/') + self.keycloak_service = keycloak_service + self.agent_id = agent_id + self.session = requests.Session() + self._update_auth_headers() + + def _update_auth_headers(self): + """Update session headers with current JWT token""" + try: + token = self.keycloak_service.get_keycloak_token() + self.session.headers.update({ + 'Authorization': f'Bearer {token}', + 'Content-Type': 'application/json' + }) + except Exception as e: + logger.error(f"Failed to update auth headers: {e}") + raise + + def send_mcp_request(self, method: str, params: Dict[str, Any] = None, + communication_id: str = None) -> Dict[str, Any]: + """ + Send MCP request via HTTP + """ + if communication_id is None: + communication_id = f"mcp-{method}-{self.agent_id}" + + request_data = { + "jsonrpc": "2.0", + "id": f"{method}-{self.agent_id}", + "method": method, + "params": params or {} + } + + headers = dict(self.session.headers) + headers['communication_id'] = communication_id + + url = f"{self.base_url}/api/v1/mcp/" + + try: + # Refresh token if needed + self._update_auth_headers() + + response = self.session.post(url, json=request_data, headers=headers) + response.raise_for_status() + return response.json() + except requests.RequestException as e: + logger.error(f"MCP request failed: {e}") + raise + + async def connect_websocket(self) -> 'MCPWebSocketClient': + """ + Create WebSocket connection for real-time MCP communication + """ + try: + token = self.keycloak_service.get_keycloak_token() + ws_url = self.base_url.replace('http://', 'ws://').replace('https://', 'wss://') + ws_url += f"/api/v1/mcp/ws?token=Bearer%20{token}&communication_id=ws-{self.agent_id}&user_id={self.agent_id}" + + websocket = await websockets.connect(ws_url) + return MCPWebSocketClient(websocket) + except Exception as e: + logger.error(f"Failed to connect WebSocket: {e}") + raise + + def get_capabilities(self) -> Dict[str, Any]: + """ + Get MCP proxy capabilities + """ + url = f"{self.base_url}/api/v1/mcp/capabilities" + try: + self._update_auth_headers() + response = self.session.get(url) + response.raise_for_status() + return response.json() + except requests.RequestException as e: + logger.error(f"Failed to get capabilities: {e}") + raise + + def initialize(self) -> Dict[str, Any]: + """ + Initialize MCP session + """ + return self.send_mcp_request("initialize") + + def ping(self) -> Dict[str, Any]: + """ + Send ping for connectivity check + """ + return self.send_mcp_request("ping") + + def list_tools(self) -> Dict[str, Any]: + """ + List available tools + """ + return self.send_mcp_request("tools/list") + + def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> Dict[str, Any]: + """ + Execute a tool + """ + params = { + "name": tool_name, + "arguments": arguments + } + return self.send_mcp_request("tools/call", params) + + def list_resources(self) -> Dict[str, Any]: + """ + List available resources + """ + return self.send_mcp_request("resources/list") + + def read_resource(self, uri: str) -> Dict[str, Any]: + """ + Read a specific resource + """ + params = {"uri": uri} + return self.send_mcp_request("resources/read", params) + + def list_prompts(self) -> Dict[str, Any]: + """ + List available prompts + """ + return self.send_mcp_request("prompts/list") + + def get_prompt(self, name: str, arguments: Dict[str, Any] = None) -> Dict[str, Any]: + """ + Get a specific prompt + """ + params = {"name": name} + if arguments: + params["arguments"] = arguments + return self.send_mcp_request("prompts/get", params) + + +class MCPWebSocketClient: + """ + WebSocket client for real-time MCP communication + """ + + def __init__(self, websocket): + self.websocket = websocket + self.request_id_counter = 0 + + async def send_request(self, method: str, params: Dict[str, Any] = None) -> Dict[str, Any]: + """ + Send MCP request via WebSocket + """ + self.request_id_counter += 1 + request_id = f"{method}-{self.request_id_counter}" + + request_data = { + "jsonrpc": "2.0", + "id": request_id, + "method": method, + "params": params or {} + } + + await self.websocket.send(json.dumps(request_data)) + + # Wait for response + response_text = await self.websocket.recv() + return json.loads(response_text) + + async def close(self): + """ + Close WebSocket connection + """ + await self.websocket.close() \ No newline at end of file diff --git a/python-agent/utils/config_manager.py b/python-agent/utils/config_manager.py index e04e32a1..c341e2c2 100644 --- a/python-agent/utils/config_manager.py +++ b/python-agent/utils/config_manager.py @@ -65,12 +65,22 @@ def _load_agent_configs(self): agent_name = config_key.replace('.config', '').replace('agent.', '') try: + # Try the path as specified first yaml_path = Path(yaml_file) + + # If not found and path starts with "python-agent/", try without that prefix + if not yaml_path.exists() and yaml_file.startswith("python-agent/"): + yaml_path = Path(yaml_file.replace("python-agent/", "")) + + # If still not found and doesn't have prefix, try with prefix + if not yaml_path.exists() and not yaml_file.startswith("python-agent/"): + yaml_path = Path(f"python-agent/{yaml_file}") + if yaml_path.exists(): with open(yaml_path, 'r') as f: config = yaml.safe_load(f) self.agent_configs[agent_name] = config - logger.info(f"Loaded agent config for {agent_name} from {yaml_file}") + logger.info(f"Loaded agent config for {agent_name} from {yaml_path}") else: logger.warning(f"Agent config file {yaml_file} not found for {agent_name}") except Exception as e: diff --git a/sentrius-chart/values.yaml b/sentrius-chart/values.yaml index f41e1db1..2310ac1d 100644 --- a/sentrius-chart/values.yaml +++ b/sentrius-chart/values.yaml @@ -47,7 +47,7 @@ sentrius: # Sentrius configuration integrationproxy: image: - repository: us-central1-docker.pkg.dev/sentrius-project/integration-proxy-repo + repository: sentrius-integration-proxy tag: tag pullPolicy: IfNotPresent port: 8080