Skip to content

Commit eee6e08

Browse files
Merge pull request #18630 from sIvanovKonstantyn/master
BAEL-9293 - Securing Spring AI MCP servers with OAuth2
2 parents 88737c2 + f1a30cf commit eee6e08

File tree

7 files changed

+239
-0
lines changed

7 files changed

+239
-0
lines changed

spring-ai-3/pom.xml

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,10 @@
5151
<groupId>org.springframework.ai</groupId>
5252
<artifactId>spring-ai-openai-spring-boot-starter</artifactId>
5353
</dependency>
54+
<dependency>
55+
<groupId>org.springframework.ai</groupId>
56+
<artifactId>spring-ai-mcp-server-webmvc-spring-boot-starter</artifactId>
57+
</dependency>
5458
<dependency>
5559
<groupId>org.hsqldb</groupId>
5660
<artifactId>hsqldb</artifactId>
@@ -61,6 +65,16 @@
6165
<artifactId>spring-ai-starter-model-openai</artifactId>
6266
<version>${spring-ai-start-model-openai.version}</version>
6367
</dependency>
68+
<dependency>
69+
<groupId>org.springframework.boot</groupId>
70+
<artifactId>spring-boot-starter-oauth2-resource-server</artifactId>
71+
<version>${oauth2-resource-server.version}</version>
72+
</dependency>
73+
<dependency>
74+
<groupId>org.springframework.boot</groupId>
75+
<artifactId>spring-boot-starter-oauth2-authorization-server</artifactId>
76+
<version>${oauth2-authorization-server.version}</version>
77+
</dependency>
6478

6579
<!-- Test dependencies -->
6680
<dependency>
@@ -151,6 +165,8 @@
151165
<spring-boot.version>3.5.0</spring-boot.version>
152166
<spring-ai.version>1.0.0-M6</spring-ai.version>
153167
<spring-ai-start-model-openai.version>1.0.0-M7</spring-ai-start-model-openai.version>
168+
<oauth2-resource-server.version>3.4.2</oauth2-resource-server.version>
169+
<oauth2-authorization-server.version>3.3.3</oauth2-authorization-server.version>
154170
</properties>
155171

156172
</project>
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
package com.baeldung.springai.mcp.oauth2;
2+
3+
import org.springframework.ai.autoconfigure.chat.client.ChatClientAutoConfiguration;
4+
import org.springframework.ai.autoconfigure.mistralai.MistralAiAutoConfiguration;
5+
import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration;
6+
import org.springframework.ai.model.openai.autoconfigure.*;
7+
import org.springframework.boot.SpringApplication;
8+
import org.springframework.boot.autoconfigure.SpringBootApplication;
9+
import org.springframework.boot.autoconfigure.data.mongo.MongoDataAutoConfiguration;
10+
import org.springframework.boot.autoconfigure.mongo.MongoAutoConfiguration;
11+
12+
@SpringBootApplication(exclude = {
13+
ChatClientAutoConfiguration.class,
14+
MongoAutoConfiguration.class,
15+
MistralAiAutoConfiguration.class,
16+
MongoDataAutoConfiguration.class,
17+
org.springframework.ai.autoconfigure.vectorstore.mongo.MongoDBAtlasVectorStoreAutoConfiguration.class,
18+
org.springframework.ai.vectorstore.mongodb.autoconfigure.MongoDBAtlasVectorStoreAutoConfiguration.class,
19+
OpenAiAudioSpeechAutoConfiguration.class,
20+
OpenAiAutoConfiguration.class,
21+
OpenAiAudioTranscriptionAutoConfiguration.class,
22+
OpenAiChatAutoConfiguration.class,
23+
OpenAiEmbeddingAutoConfiguration.class,
24+
OpenAiImageAutoConfiguration.class,
25+
OpenAiModerationAutoConfiguration.class})
26+
class McpServerApplication {
27+
28+
public static void main(String[] args) {
29+
SpringApplication app = new SpringApplication(McpServerApplication.class);
30+
app.setAdditionalProfiles("mcp");
31+
app.run(args);
32+
}
33+
}
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
package com.baeldung.springai.mcp.oauth2;
2+
3+
import org.springframework.ai.tool.annotation.Tool;
4+
import org.springframework.ai.tool.annotation.ToolParam;
5+
6+
public class StockInformationHolder {
7+
@Tool(description = "Get stock price for a company symbol")
8+
public String getStockPrice(@ToolParam String symbol) {
9+
if ("AAPL".equalsIgnoreCase(symbol)) {
10+
return "AAPL: $150.00";
11+
} else if ("GOOGL".equalsIgnoreCase(symbol)) {
12+
return "GOOGL: $2800.00";
13+
} else {
14+
return symbol + ": Data not available";
15+
}
16+
}
17+
}
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
package com.baeldung.springai.mcp.oauth2.configuration;
2+
3+
import com.baeldung.springai.mcp.oauth2.StockInformationHolder;
4+
import org.springframework.ai.tool.ToolCallbackProvider;
5+
import org.springframework.ai.tool.method.MethodToolCallbackProvider;
6+
import org.springframework.context.annotation.Bean;
7+
import org.springframework.context.annotation.Configuration;
8+
import org.springframework.context.annotation.Profile;
9+
10+
@Profile("mcp")
11+
@Configuration
12+
public class McpServerConfiguration {
13+
14+
@Bean
15+
public ToolCallbackProvider stockTools() {
16+
return MethodToolCallbackProvider
17+
.builder()
18+
.toolObjects(new StockInformationHolder())
19+
.build();
20+
}
21+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
package com.baeldung.springai.mcp.oauth2.configuration;
2+
3+
import org.springframework.context.annotation.Bean;
4+
import org.springframework.context.annotation.Configuration;
5+
import org.springframework.security.config.Customizer;
6+
import org.springframework.security.config.annotation.web.builders.HttpSecurity;
7+
import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
8+
import org.springframework.security.config.annotation.web.configurers.CsrfConfigurer;
9+
import org.springframework.security.oauth2.server.authorization.config.annotation.web.configurers.OAuth2AuthorizationServerConfigurer;
10+
import org.springframework.security.web.SecurityFilterChain;
11+
12+
@Configuration
13+
@EnableWebSecurity
14+
public class McpServerSecurityConfiguration {
15+
@Bean
16+
public SecurityFilterChain filterChain(HttpSecurity http) throws Exception {
17+
return http
18+
.authorizeHttpRequests(auth -> auth
19+
.requestMatchers("/mcp/**").authenticated()
20+
.requestMatchers("/sse").authenticated()
21+
.anyRequest().permitAll())
22+
.with(OAuth2AuthorizationServerConfigurer.authorizationServer(), Customizer.withDefaults())
23+
.oauth2ResourceServer(oauth2 -> oauth2.jwt(Customizer.withDefaults()))
24+
.csrf(CsrfConfigurer::disable)
25+
.cors(Customizer.withDefaults())
26+
.build();
27+
}
28+
}
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
spring:
2+
security:
3+
oauth2:
4+
authorizationserver:
5+
client:
6+
oidc-client:
7+
registration:
8+
client-id: mcp-client
9+
client-secret: "{noop}secret"
10+
client-authentication-methods: client_secret_basic
11+
authorization-grant-types: client_credentials
12+
# Avoid starting docker from the shared codebase
13+
docker:
14+
compose:
15+
enabled: false
16+
17+
logging:
18+
level:
19+
org.springframework.ai.mcp: DEBUG
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
package com.baeldung.springai.mcp.oauth2;
2+
3+
import com.fasterxml.jackson.databind.JsonNode;
4+
import org.junit.jupiter.api.BeforeEach;
5+
import org.junit.jupiter.api.Test;
6+
import org.slf4j.Logger;
7+
import org.slf4j.LoggerFactory;
8+
import org.springframework.boot.test.context.SpringBootTest;
9+
import org.springframework.boot.test.web.server.LocalServerPort;
10+
import org.springframework.http.HttpHeaders;
11+
import org.springframework.http.MediaType;
12+
import org.springframework.test.context.ActiveProfiles;
13+
import org.springframework.web.reactive.function.BodyInserters;
14+
import org.springframework.web.reactive.function.client.WebClient;
15+
import reactor.core.publisher.Flux;
16+
17+
import java.nio.charset.StandardCharsets;
18+
import java.time.Duration;
19+
import java.util.Base64;
20+
21+
import static org.assertj.core.api.Assertions.assertThat;
22+
23+
@ActiveProfiles("mcp")
24+
@SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT)
25+
class McpServerOAuth2LiveTest {
26+
27+
private static final Logger log = LoggerFactory.getLogger(McpServerOAuth2LiveTest.class);
28+
29+
@LocalServerPort
30+
private int port;
31+
32+
private WebClient webClient;
33+
34+
@BeforeEach
35+
void setup() {
36+
webClient = WebClient.create("http://localhost:" + port);
37+
}
38+
39+
@Test
40+
void givenSecuredMcpServer_whenCallingTheEndpointsWithValidAuthorizationHeader_thenExpectedResponseShouldBeObtained() {
41+
Flux<String> eventStream = webClient.get()
42+
.uri("/sse")
43+
.header("Authorization", obtainAccessToken())
44+
.accept(MediaType.TEXT_EVENT_STREAM)
45+
.retrieve()
46+
.bodyToFlux(String.class);
47+
48+
eventStream.subscribe(
49+
data -> {
50+
log.info("Response received: {}", data);
51+
if(!isRequestMessage(data)) {
52+
assertThat(data).containsSequence("AAPL", "$150");
53+
}
54+
},
55+
error -> log.error(error.getMessage()),
56+
() -> log.info("Stream completed"));
57+
58+
Flux<String> sendMessage = webClient.post()
59+
.uri("/mcp/message")
60+
.header("Authorization", obtainAccessToken())
61+
.contentType(MediaType.APPLICATION_JSON)
62+
.accept(MediaType.TEXT_EVENT_STREAM)
63+
.bodyValue("""
64+
{
65+
"jsonrpc": "2.0",
66+
"id": "1",
67+
"method": "tools/call",
68+
"params": {
69+
"name": "getStockPrice",
70+
"arguments": {
71+
"arg0": "AAPL"
72+
}
73+
}
74+
}
75+
""")
76+
.retrieve()
77+
.bodyToFlux(String.class);
78+
79+
sendMessage.blockLast();
80+
eventStream.blockLast();
81+
}
82+
83+
private boolean isRequestMessage(String data) {
84+
return data.contains("/mcp/message");
85+
}
86+
87+
public String obtainAccessToken() {
88+
String clientId = "mcp-client";
89+
String clientSecret = "secret";
90+
String basicToken = Base64.getEncoder()
91+
.encodeToString((clientId + ":" + clientSecret).getBytes(StandardCharsets.UTF_8));
92+
93+
return "Bearer " + webClient.post()
94+
.uri("/oauth2/token")
95+
.header(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_FORM_URLENCODED_VALUE)
96+
.header(HttpHeaders.AUTHORIZATION, "Basic " + basicToken)
97+
.body(BodyInserters
98+
.fromFormData("grant_type", "client_credentials")
99+
)
100+
.retrieve()
101+
.bodyToMono(JsonNode.class)
102+
.map(node -> node.get("access_token").asText())
103+
.block(Duration.ofSeconds(5));
104+
}
105+
}

0 commit comments

Comments
 (0)