Skip to content

Commit 6249b1b

Browse files
Copilotphrocker
andcommitted
Implement MCP proxy with zero trust security integration
Co-authored-by: phrocker <[email protected]>
1 parent 6ce7f19 commit 6249b1b

File tree

9 files changed

+1202
-0
lines changed

9 files changed

+1202
-0
lines changed
Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
package io.sentrius.sso.controllers.api.mcp;
2+
3+
import io.sentrius.sso.core.controllers.BaseController;
4+
import io.sentrius.sso.core.annotations.LimitAccess;
5+
import io.sentrius.sso.core.model.security.enums.ApplicationAccessEnum;
6+
import io.sentrius.sso.core.services.UserService;
7+
import io.sentrius.sso.core.config.SystemOptions;
8+
import io.sentrius.sso.core.services.ErrorOutputService;
9+
import io.sentrius.sso.core.services.security.KeycloakService;
10+
import io.sentrius.sso.mcp.model.MCPRequest;
11+
import io.sentrius.sso.mcp.model.MCPResponse;
12+
import io.sentrius.sso.mcp.model.MCPError;
13+
import io.sentrius.sso.mcp.service.MCPProxyService;
14+
15+
import com.fasterxml.jackson.core.JsonProcessingException;
16+
import com.fasterxml.jackson.databind.ObjectMapper;
17+
import lombok.extern.slf4j.Slf4j;
18+
import org.springframework.http.HttpStatus;
19+
import org.springframework.http.ResponseEntity;
20+
import org.springframework.web.bind.annotation.*;
21+
22+
import jakarta.servlet.http.HttpServletRequest;
23+
import jakarta.servlet.http.HttpServletResponse;
24+
import java.util.Map;
25+
26+
/**
27+
* MCP (Model Context Protocol) Proxy Controller with Zero Trust Security
28+
*
29+
* Provides secure MCP endpoints with the same security controls as other Sentrius services:
30+
* - JWT authentication via Keycloak
31+
* - Zero Trust Access Token (ZTAT) validation
32+
* - Access control via @LimitAccess annotations
33+
* - Provenance tracking for audit trails
34+
*/
35+
@RestController
36+
@RequestMapping("/api/v1/mcp")
37+
@Slf4j
38+
public class MCPProxyController extends BaseController {
39+
40+
private final KeycloakService keycloakService;
41+
private final MCPProxyService mcpProxyService;
42+
private final ObjectMapper objectMapper;
43+
44+
public MCPProxyController(
45+
UserService userService,
46+
SystemOptions systemOptions,
47+
ErrorOutputService errorOutputService,
48+
KeycloakService keycloakService,
49+
MCPProxyService mcpProxyService,
50+
ObjectMapper objectMapper
51+
) {
52+
super(userService, systemOptions, errorOutputService);
53+
this.keycloakService = keycloakService;
54+
this.mcpProxyService = mcpProxyService;
55+
this.objectMapper = objectMapper;
56+
}
57+
58+
/**
59+
* Handle MCP requests via HTTP POST
60+
*/
61+
@PostMapping("/")
62+
@LimitAccess(applicationAccess = {ApplicationAccessEnum.CAN_LOG_IN})
63+
public ResponseEntity<?> handleMCPRequest(
64+
@RequestHeader("Authorization") String token,
65+
@RequestHeader("communication_id") String communicationId,
66+
HttpServletRequest request,
67+
HttpServletResponse response,
68+
@RequestBody String rawBody) {
69+
70+
log.info("Received MCP request with communication_id: {}", communicationId);
71+
72+
String compactJwt = extractJwtToken(token);
73+
74+
// Validate JWT token
75+
if (!keycloakService.validateJwt(compactJwt)) {
76+
log.warn("Invalid Keycloak token for MCP request");
77+
return ResponseEntity.status(HttpStatus.UNAUTHORIZED)
78+
.body(createErrorResponse(null, MCPError.unauthorized("Invalid Keycloak token")));
79+
}
80+
81+
// Get operating user
82+
var operatingUser = getOperatingUser(request, response);
83+
if (operatingUser == null) {
84+
log.warn("No operating user found for MCP request");
85+
return ResponseEntity.status(HttpStatus.UNAUTHORIZED)
86+
.body(createErrorResponse(null, MCPError.unauthorized("No operating user found")));
87+
}
88+
89+
try {
90+
// Parse MCP request
91+
MCPRequest mcpRequest = objectMapper.readValue(rawBody, MCPRequest.class);
92+
93+
// Validate MCP request structure
94+
if (mcpRequest.getMethod() == null || mcpRequest.getId() == null) {
95+
return ResponseEntity.badRequest()
96+
.body(createErrorResponse(mcpRequest.getId(), MCPError.invalidRequest("Missing required fields")));
97+
}
98+
99+
// Process the request through the service layer
100+
MCPResponse mcpResponse = mcpProxyService.processRequest(
101+
mcpRequest, compactJwt, communicationId, operatingUser.getUsername()
102+
);
103+
104+
return ResponseEntity.ok(mcpResponse);
105+
106+
} catch (JsonProcessingException e) {
107+
log.error("Failed to parse MCP request", e);
108+
return ResponseEntity.badRequest()
109+
.body(createErrorResponse(null, MCPError.parseError("Invalid JSON format")));
110+
} catch (Exception e) {
111+
log.error("Unexpected error processing MCP request", e);
112+
return ResponseEntity.status(HttpStatus.INTERNAL_SERVER_ERROR)
113+
.body(createErrorResponse(null, MCPError.internalError("Internal server error")));
114+
}
115+
}
116+
117+
/**
118+
* Handle MCP capability discovery
119+
*/
120+
@GetMapping("/capabilities")
121+
@LimitAccess(applicationAccess = {ApplicationAccessEnum.CAN_LOG_IN})
122+
public ResponseEntity<?> getCapabilities(
123+
@RequestHeader("Authorization") String token,
124+
HttpServletRequest request,
125+
HttpServletResponse response) {
126+
127+
String compactJwt = extractJwtToken(token);
128+
129+
if (!keycloakService.validateJwt(compactJwt)) {
130+
log.warn("Invalid Keycloak token for MCP capabilities request");
131+
return ResponseEntity.status(HttpStatus.UNAUTHORIZED)
132+
.body(MCPError.unauthorized("Invalid Keycloak token"));
133+
}
134+
135+
var operatingUser = getOperatingUser(request, response);
136+
if (operatingUser == null) {
137+
return ResponseEntity.status(HttpStatus.UNAUTHORIZED)
138+
.body(MCPError.unauthorized("No operating user found"));
139+
}
140+
141+
try {
142+
// Create an initialize request to get capabilities
143+
MCPRequest initRequest = MCPRequest.create("capabilities", "initialize", null);
144+
MCPResponse mcpResponse = mcpProxyService.processRequest(
145+
initRequest, compactJwt, "capabilities", operatingUser.getUsername()
146+
);
147+
148+
return ResponseEntity.ok(mcpResponse.getResult());
149+
150+
} catch (Exception e) {
151+
log.error("Error getting MCP capabilities", e);
152+
return ResponseEntity.status(HttpStatus.INTERNAL_SERVER_ERROR)
153+
.body(MCPError.internalError("Failed to get capabilities"));
154+
}
155+
}
156+
157+
/**
158+
* Health check endpoint for MCP proxy
159+
*/
160+
@GetMapping("/health")
161+
public ResponseEntity<?> health() {
162+
return ResponseEntity.ok(Map.of(
163+
"status", "healthy",
164+
"service", "mcp-proxy",
165+
"timestamp", java.time.Instant.now().toString()
166+
));
167+
}
168+
169+
/**
170+
* Extract JWT token from Authorization header
171+
*/
172+
private String extractJwtToken(String authHeader) {
173+
return authHeader != null && authHeader.startsWith("Bearer ") ?
174+
authHeader.substring(7) : authHeader;
175+
}
176+
177+
/**
178+
* Create error response in MCP format
179+
*/
180+
private MCPResponse createErrorResponse(String id, MCPError error) {
181+
return MCPResponse.error(id, error);
182+
}
183+
}
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
package io.sentrius.sso.mcp.config;
2+
3+
import org.springframework.context.annotation.Bean;
4+
import org.springframework.context.annotation.Configuration;
5+
import org.springframework.web.client.RestTemplate;
6+
import com.fasterxml.jackson.databind.ObjectMapper;
7+
8+
/**
9+
* Configuration for MCP proxy services
10+
*/
11+
@Configuration
12+
public class MCPProxyConfig {
13+
14+
@Bean
15+
public RestTemplate mcpRestTemplate() {
16+
return new RestTemplate();
17+
}
18+
19+
@Bean
20+
public ObjectMapper mcpObjectMapper() {
21+
return new ObjectMapper();
22+
}
23+
}
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
package io.sentrius.sso.mcp.config;
2+
3+
import io.sentrius.sso.mcp.websocket.MCPWebSocketHandler;
4+
import io.sentrius.sso.core.services.security.KeycloakService;
5+
6+
import lombok.RequiredArgsConstructor;
7+
import org.springframework.context.annotation.Configuration;
8+
import org.springframework.web.socket.config.annotation.EnableWebSocket;
9+
import org.springframework.web.socket.config.annotation.WebSocketConfigurer;
10+
import org.springframework.web.socket.config.annotation.WebSocketHandlerRegistry;
11+
import org.springframework.web.socket.server.HandshakeInterceptor;
12+
import org.springframework.http.server.ServerHttpRequest;
13+
import org.springframework.http.server.ServerHttpResponse;
14+
import org.springframework.web.socket.WebSocketHandler;
15+
16+
import java.util.Map;
17+
18+
/**
19+
* WebSocket configuration for MCP (Model Context Protocol) endpoints
20+
*/
21+
@Configuration
22+
@EnableWebSocket
23+
@RequiredArgsConstructor
24+
public class MCPWebSocketConfig implements WebSocketConfigurer {
25+
26+
private final MCPWebSocketHandler mcpWebSocketHandler;
27+
private final KeycloakService keycloakService;
28+
29+
@Override
30+
public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
31+
registry.addHandler(mcpWebSocketHandler, "/api/v1/mcp/ws")
32+
.addInterceptors(new MCPHandshakeInterceptor())
33+
.setAllowedOrigins("*"); // Configure as needed for security
34+
}
35+
36+
/**
37+
* Handshake interceptor to validate authentication before WebSocket connection
38+
*/
39+
private class MCPHandshakeInterceptor implements HandshakeInterceptor {
40+
41+
@Override
42+
public boolean beforeHandshake(
43+
ServerHttpRequest request,
44+
ServerHttpResponse response,
45+
WebSocketHandler wsHandler,
46+
Map<String, Object> attributes) throws Exception {
47+
48+
// Extract authentication parameters from query params or headers
49+
String token = extractToken(request);
50+
String communicationId = extractParameter(request, "communication_id");
51+
String userId = extractParameter(request, "user_id");
52+
53+
if (token == null || communicationId == null || userId == null) {
54+
return false; // Reject connection
55+
}
56+
57+
// Validate JWT token
58+
String jwt = token.startsWith("Bearer ") ? token.substring(7) : token;
59+
if (!keycloakService.validateJwt(jwt)) {
60+
return false; // Invalid token
61+
}
62+
63+
// Store validated parameters in session attributes
64+
attributes.put("token", token);
65+
attributes.put("communication_id", communicationId);
66+
attributes.put("user_id", userId);
67+
68+
return true; // Allow connection
69+
}
70+
71+
@Override
72+
public void afterHandshake(
73+
ServerHttpRequest request,
74+
ServerHttpResponse response,
75+
WebSocketHandler wsHandler,
76+
Exception exception) {
77+
// No additional processing needed
78+
}
79+
80+
private String extractToken(ServerHttpRequest request) {
81+
// Try Authorization header first
82+
String authHeader = request.getHeaders().getFirst("Authorization");
83+
if (authHeader != null) {
84+
return authHeader;
85+
}
86+
87+
// Fall back to query parameter
88+
return extractParameter(request, "token");
89+
}
90+
91+
private String extractParameter(ServerHttpRequest request, String paramName) {
92+
String query = request.getURI().getQuery();
93+
if (query == null) {
94+
return null;
95+
}
96+
97+
for (String param : query.split("&")) {
98+
String[] parts = param.split("=", 2);
99+
if (parts.length == 2 && paramName.equals(parts[0])) {
100+
return parts[1];
101+
}
102+
}
103+
return null;
104+
}
105+
}
106+
}
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
package io.sentrius.sso.mcp.model;
2+
3+
import com.fasterxml.jackson.annotation.JsonProperty;
4+
import lombok.Data;
5+
import lombok.NoArgsConstructor;
6+
import lombok.AllArgsConstructor;
7+
8+
/**
9+
* Represents an MCP error
10+
*/
11+
@Data
12+
@NoArgsConstructor
13+
@AllArgsConstructor
14+
public class MCPError {
15+
16+
@JsonProperty("code")
17+
private int code;
18+
19+
@JsonProperty("message")
20+
private String message;
21+
22+
@JsonProperty("data")
23+
private Object data;
24+
25+
// Standard MCP error codes
26+
public static final int PARSE_ERROR = -32700;
27+
public static final int INVALID_REQUEST = -32600;
28+
public static final int METHOD_NOT_FOUND = -32601;
29+
public static final int INVALID_PARAMS = -32602;
30+
public static final int INTERNAL_ERROR = -32603;
31+
public static final int UNAUTHORIZED = -32001;
32+
public static final int FORBIDDEN = -32002;
33+
34+
public static MCPError parseError(String message) {
35+
return new MCPError(PARSE_ERROR, message, null);
36+
}
37+
38+
public static MCPError invalidRequest(String message) {
39+
return new MCPError(INVALID_REQUEST, message, null);
40+
}
41+
42+
public static MCPError methodNotFound(String method) {
43+
return new MCPError(METHOD_NOT_FOUND, "Method not found: " + method, null);
44+
}
45+
46+
public static MCPError invalidParams(String message) {
47+
return new MCPError(INVALID_PARAMS, message, null);
48+
}
49+
50+
public static MCPError internalError(String message) {
51+
return new MCPError(INTERNAL_ERROR, message, null);
52+
}
53+
54+
public static MCPError unauthorized(String message) {
55+
return new MCPError(UNAUTHORIZED, message != null ? message : "Unauthorized", null);
56+
}
57+
58+
public static MCPError forbidden(String message) {
59+
return new MCPError(FORBIDDEN, message != null ? message : "Forbidden", null);
60+
}
61+
}

0 commit comments

Comments
 (0)