Skip to content

Commit c7a22da

Browse files
authored
Merge branch 'master' into fix-6138
2 parents 02a2536 + c2efae1 commit c7a22da

File tree

20 files changed

+749
-117
lines changed

20 files changed

+749
-117
lines changed

shenyu-client/shenyu-client-mcp/shenyu-client-mcp-common/src/main/java/org/apache/shenyu/client/mcp/generator/McpRequestConfigGenerator.java

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,11 +60,10 @@ public static JsonObject generateRequestConfig(final JsonObject openApiJson, fin
6060
requestTemplate.addProperty(RequestTemplateConstants.METHOD_KEY, methodType);
6161

6262
// argsPosition
63+
JsonObject argsPosition = new JsonObject();
6364
JsonObject methodTypeJson = method.getAsJsonObject(methodType);
6465
JsonArray parameters = methodTypeJson.getAsJsonArray(OpenApiConstants.OPEN_API_PATH_OPERATION_METHOD_PARAMETERS_KEY);
6566
if (Objects.nonNull(parameters)) {
66-
JsonObject argsPosition = new JsonObject();
67-
6867
for (JsonElement parameter : parameters) {
6968
JsonObject paramObj = parameter.getAsJsonObject();
7069

@@ -77,8 +76,11 @@ public static JsonObject generateRequestConfig(final JsonObject openApiJson, fin
7776
argsPosition.addProperty(name, inValue);
7877
}
7978
}
80-
requestTemplate.add(RequestTemplateConstants.ARGS_POSITION_KEY, argsPosition);
8179
}
80+
// Keep root-level argsPosition as canonical format used by gateway parser.
81+
root.add(RequestTemplateConstants.ARGS_POSITION_KEY, argsPosition.deepCopy());
82+
// Keep requestTemplate-level argsPosition for backward compatibility.
83+
requestTemplate.add(RequestTemplateConstants.ARGS_POSITION_KEY, argsPosition);
8284

8385
// argsToJsonBody
8486
requestTemplate.addProperty(RequestTemplateConstants.BODY_JSON_KEY, shenyuMcpRequestConfig.getBodyToJson());

shenyu-client/shenyu-client-mcp/shenyu-client-mcp-register/src/main/java/org/apache/shenyu/client/mcp/McpServiceEventListener.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,9 @@ private static String concatPaths(final String path1, final String path2) {
304304
@Override
305305
protected String buildApiSuperPath(final Class<?> clazz, final ShenyuMcpTool beanShenyuClient) {
306306
Server[] servers = beanShenyuClient.definition().servers();
307+
if (servers.length == 0) {
308+
return "";
309+
}
307310
if (servers.length != 1) {
308311
log.warn("The shenyuMcp service supports only a single server entry. Please ensure that only one server is configured");
309312
}
@@ -363,7 +366,7 @@ private McpToolsRegisterDTO buildMcpToolsRegisterDTO(final Object bean, final Cl
363366
validateClientConfig(shenyuMcpTool, url);
364367
JsonObject openApiJson = McpOpenApiGenerator.generateOpenApiJson(classShenyuClient, shenyuMcpTool, url);
365368
McpToolsRegisterDTO mcpToolsRegisterDTO = McpToolsRegisterDTOGenerator.generateRegisterDTO(shenyuMcpTool, openApiJson, url, namespaceId);
366-
MetaDataRegisterDTO metaDataRegisterDTO = buildMetaDataDTO(bean, classShenyuClient, superPath, clazz, method, namespaceId);
369+
MetaDataRegisterDTO metaDataRegisterDTO = buildMetaDataDTO(bean, classShenyuClient, url, clazz, method, namespaceId);
367370
metaDataRegisterDTO.setEnabled(shenyuMcpTool.getEnable());
368371
mcpToolsRegisterDTO.setMetaDataRegisterDTO(metaDataRegisterDTO);
369372
return mcpToolsRegisterDTO;

shenyu-loadbalancer/src/main/java/org/apache/shenyu/loadbalancer/cache/UpstreamCacheManager.java

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
import org.apache.shenyu.common.utils.Singleton;
2727
import org.apache.shenyu.loadbalancer.entity.Upstream;
2828

29-
import java.util.ArrayList;
3029
import java.util.HashSet;
3130
import java.util.List;
3231
import java.util.Map;
@@ -150,8 +149,7 @@ public void submit(final String selectorId, final List<Upstream> upstreamList) {
150149

151150
// Check if the list is empty first to avoid unnecessary processing
152151
if (actualUpstreamList.isEmpty()) {
153-
List<Upstream> existUpstreamList = MapUtils.computeIfAbsent(UPSTREAM_MAP, selectorId, k -> Lists.newArrayList());
154-
removeAllUpstreams(selectorId, existUpstreamList);
152+
removeByKey(selectorId);
155153
return;
156154
}
157155

@@ -179,11 +177,6 @@ private void initializeUpstreamHealthStatus(final List<Upstream> upstreamList) {
179177
});
180178
}
181179

182-
private void removeAllUpstreams(final String selectorId, final List<Upstream> existUpstreamList) {
183-
List<Upstream> toRemove = new ArrayList<>(existUpstreamList);
184-
toRemove.forEach(up -> task.triggerRemoveOne(selectorId, up));
185-
}
186-
187180
private void processOfflineUpstreams(final String selectorId, final List<Upstream> offlineUpstreamList,
188181
final List<Upstream> existUpstreamList) {
189182
Map<String, Upstream> currentUnhealthyMap = getCurrentUnhealthyMap(selectorId);

shenyu-loadbalancer/src/main/java/org/apache/shenyu/loadbalancer/entity/Upstream.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -509,7 +509,7 @@ public boolean equals(final Object o) {
509509

510510
@Override
511511
public int hashCode() {
512-
return Objects.hash(protocol, url, weight);
512+
return Objects.hash(protocol, url);
513513
}
514514

515515
@Override

shenyu-loadbalancer/src/test/java/org/apache/shenyu/loadbalancer/cache/UpstreamCacheManagerTest.java

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,88 @@ public void testSubmitWithHealthCheckDisabledAndStatusFalse() {
278278
upstreamCacheManager.removeByKey(testSelectorId);
279279
}
280280

281+
@Test
282+
@Order(10)
283+
public void testSubmitCanRecoverAfterEmptyUpstreamEvent() {
284+
final UpstreamCacheManager upstreamCacheManager = UpstreamCacheManager.getInstance();
285+
final String testSelectorId = "RECOVER_AFTER_EMPTY_EVENT_TEST";
286+
287+
List<Upstream> initialList = new ArrayList<>(1);
288+
initialList.add(Upstream.builder()
289+
.protocol("http://")
290+
.url("recover-upstream:8080")
291+
.status(true)
292+
.healthCheckEnabled(false)
293+
.build());
294+
upstreamCacheManager.submit(testSelectorId, initialList);
295+
List<Upstream> firstSubmitResult = upstreamCacheManager.findUpstreamListBySelectorId(testSelectorId);
296+
Assertions.assertNotNull(firstSubmitResult);
297+
Assertions.assertFalse(firstSubmitResult.isEmpty());
298+
299+
upstreamCacheManager.submit(testSelectorId, new ArrayList<>());
300+
List<Upstream> afterEmptySubmitResult = upstreamCacheManager.findUpstreamListBySelectorId(testSelectorId);
301+
Assertions.assertTrue(Objects.isNull(afterEmptySubmitResult) || afterEmptySubmitResult.isEmpty());
302+
303+
List<Upstream> recoveredList = new ArrayList<>(1);
304+
recoveredList.add(Upstream.builder()
305+
.protocol("http://")
306+
.url("recover-upstream:8080")
307+
.status(true)
308+
.healthCheckEnabled(false)
309+
.build());
310+
upstreamCacheManager.submit(testSelectorId, recoveredList);
311+
List<Upstream> secondSubmitResult = upstreamCacheManager.findUpstreamListBySelectorId(testSelectorId);
312+
Assertions.assertNotNull(secondSubmitResult);
313+
Assertions.assertFalse(secondSubmitResult.isEmpty());
314+
Assertions.assertTrue(secondSubmitResult.stream().anyMatch(upstream -> "recover-upstream:8080".equals(upstream.getUrl())));
315+
316+
upstreamCacheManager.removeByKey(testSelectorId);
317+
}
318+
319+
@Test
320+
@Order(11)
321+
public void testSubmitEmptyEventClearsUnhealthyState() {
322+
final UpstreamCacheManager upstreamCacheManager = UpstreamCacheManager.getInstance();
323+
final String testSelectorId = "EMPTY_EVENT_CLEARS_UNHEALTHY_TEST";
324+
325+
List<Upstream> offlineList = new ArrayList<>(1);
326+
offlineList.add(Upstream.builder()
327+
.protocol("http://")
328+
.url("stale-upstream:8080")
329+
.status(false)
330+
.healthCheckEnabled(true)
331+
.build());
332+
upstreamCacheManager.submit(testSelectorId, offlineList);
333+
334+
UpstreamCheckTask task = getUpstreamCheckTask(upstreamCacheManager);
335+
if (Objects.nonNull(task)) {
336+
List<Upstream> unhealthyBeforeEmpty = task.getUnhealthyUpstream().get(testSelectorId);
337+
Assertions.assertNotNull(unhealthyBeforeEmpty);
338+
Assertions.assertFalse(unhealthyBeforeEmpty.isEmpty());
339+
}
340+
341+
upstreamCacheManager.submit(testSelectorId, new ArrayList<>());
342+
343+
if (Objects.nonNull(task)) {
344+
List<Upstream> unhealthyAfterEmpty = task.getUnhealthyUpstream().get(testSelectorId);
345+
Assertions.assertTrue(Objects.isNull(unhealthyAfterEmpty) || unhealthyAfterEmpty.isEmpty());
346+
}
347+
348+
List<Upstream> recoveredList = new ArrayList<>(1);
349+
recoveredList.add(Upstream.builder()
350+
.protocol("http://")
351+
.url("stale-upstream:8080")
352+
.status(true)
353+
.healthCheckEnabled(false)
354+
.build());
355+
upstreamCacheManager.submit(testSelectorId, recoveredList);
356+
List<Upstream> finalResult = upstreamCacheManager.findUpstreamListBySelectorId(testSelectorId);
357+
Assertions.assertNotNull(finalResult);
358+
Assertions.assertFalse(finalResult.isEmpty());
359+
360+
upstreamCacheManager.removeByKey(testSelectorId);
361+
}
362+
281363
/**
282364
* Helper method to get the UpstreamCheckTask using reflection.
283365
*/

shenyu-loadbalancer/src/test/java/org/apache/shenyu/loadbalancer/entity/UpstreamTest.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,15 @@ public void upstreamTest() {
7373
.weight(1)
7474
.status(true)
7575
.build();
76+
Upstream upstream4 = Upstream.builder()
77+
.protocol("https://")
78+
.url("url")
79+
.weight(2)
80+
.status(true)
81+
.build();
7682
Assertions.assertEquals(upstream2, upstream3);
83+
Assertions.assertEquals(upstream2, upstream4);
84+
Assertions.assertEquals(upstream2.hashCode(), upstream4.hashCode());
7785
Assertions.assertNotNull(upstream2.toString());
7886
Assertions.assertTrue(upstream2.hashCode() >= 0);
7987
}

shenyu-plugin/shenyu-plugin-mcp-server/src/main/java/org/apache/shenyu/plugin/mcp/server/McpServerPlugin.java

Lines changed: 88 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,16 +36,19 @@
3636
import org.apache.shenyu.plugin.mcp.server.transport.MessageHandlingResult;
3737
import org.slf4j.Logger;
3838
import org.slf4j.LoggerFactory;
39+
import org.springframework.http.HttpHeaders;
3940
import org.springframework.http.HttpStatus;
4041
import org.springframework.http.codec.HttpMessageReader;
4142
import org.springframework.web.reactive.function.server.ServerRequest;
4243
import org.springframework.web.server.ServerWebExchange;
4344
import reactor.core.publisher.Mono;
4445
import com.fasterxml.jackson.databind.ObjectMapper;
4546

47+
import java.util.LinkedHashSet;
4648
import java.util.List;
4749
import java.util.Map;
4850
import java.util.Objects;
51+
import java.util.Set;
4952
import java.nio.charset.StandardCharsets;
5053

5154
/**
@@ -109,10 +112,19 @@ public class McpServerPlugin extends AbstractShenyuPlugin {
109112
*/
110113
private static final String BEARER_PREFIX = "Bearer ";
111114

115+
private static final String CORS_ALLOW_METHODS = "GET, POST, OPTIONS";
116+
117+
private static final String CORS_STREAMABLE_ALLOW_METHODS = "POST, OPTIONS";
118+
119+
private static final String CORS_FALLBACK_ALLOW_HEADERS =
120+
"Content-Type, Mcp-Session-Id, Authorization, Last-Event-ID, Mcp-Protocol-Version, X-Request, XRequest, xrequest";
121+
112122
private final ShenyuMcpServerManager shenyuMcpServerManager;
113123

114124
private final List<HttpMessageReader<?>> messageReaders;
115125

126+
private final String configuredCorsAllowHeaders;
127+
116128
/**
117129
* Constructs a new MCP server plugin.
118130
*
@@ -121,8 +133,22 @@ public class McpServerPlugin extends AbstractShenyuPlugin {
121133
*/
122134
public McpServerPlugin(final ShenyuMcpServerManager shenyuMcpServerManager,
123135
final List<HttpMessageReader<?>> messageReaders) {
136+
this(shenyuMcpServerManager, messageReaders, null);
137+
}
138+
139+
/**
140+
* Constructs a new MCP server plugin.
141+
*
142+
* @param shenyuMcpServerManager the MCP server manager for handling transport providers
143+
* @param messageReaders the HTTP message readers for request processing
144+
* @param configuredCorsAllowHeaders CORS allow headers configured by {@code shenyu.cross.allowedHeaders}
145+
*/
146+
public McpServerPlugin(final ShenyuMcpServerManager shenyuMcpServerManager,
147+
final List<HttpMessageReader<?>> messageReaders,
148+
final String configuredCorsAllowHeaders) {
124149
this.shenyuMcpServerManager = shenyuMcpServerManager;
125150
this.messageReaders = messageReaders;
151+
this.configuredCorsAllowHeaders = configuredCorsAllowHeaders;
126152
}
127153

128154
@Override
@@ -203,6 +229,10 @@ private Mono<Void> routeByProtocol(final ServerWebExchange exchange,
203229
final SelectorData selector,
204230
final String uri) {
205231

232+
if ("OPTIONS".equalsIgnoreCase(exchange.getRequest().getMethod().name())) {
233+
return handleCorsPreflight(exchange, uri);
234+
}
235+
206236
if (isStreamableHttpProtocol(uri)) {
207237
return handleStreamableHttpRequest(exchange, chain, request, uri);
208238
} else if (isSseProtocol(uri)) {
@@ -276,6 +306,19 @@ private boolean isSseProtocol(final String uri) {
276306
return uri.contains(SSE_PATH) || uri.endsWith(SSE_PATH) || uri.endsWith(MESSAGE_ENDPOINT);
277307
}
278308

309+
/**
310+
* Handles CORS preflight (OPTIONS) requests.
311+
*
312+
* @param exchange the server web exchange
313+
* @return a Mono representing completion
314+
*/
315+
private Mono<Void> handleCorsPreflight(final ServerWebExchange exchange, final String uri) {
316+
exchange.getResponse().setStatusCode(HttpStatus.OK);
317+
setCorsHeaders(exchange, resolveAllowMethods(uri));
318+
exchange.getResponse().getHeaders().set("Access-Control-Max-Age", "3600");
319+
return exchange.getResponse().setComplete();
320+
}
321+
279322
/**
280323
* Handles Streamable HTTP MCP requests with unified endpoint processing.
281324
*
@@ -579,11 +622,51 @@ private Mono<Void> handleMessageEndpoint(final ServerWebExchange exchange,
579622
* @param exchange the server web exchange
580623
*/
581624
private void setCorsHeaders(final ServerWebExchange exchange) {
582-
exchange.getResponse().getHeaders().set("Access-Control-Allow-Origin", "*");
583-
exchange.getResponse().getHeaders().set("Access-Control-Allow-Headers",
584-
"Content-Type, Mcp-Session-Id, Authorization, Last-Event-ID, Mcp-Protocol-Version");
585-
exchange.getResponse().getHeaders().set("Access-Control-Allow-Methods",
586-
"GET, POST, OPTIONS");
625+
setCorsHeaders(exchange, resolveAllowMethods(exchange.getRequest().getURI().getRawPath()));
626+
}
627+
628+
private void setCorsHeaders(final ServerWebExchange exchange, final String allowMethods) {
629+
exchange.getResponse().getHeaders().set("Access-Control-Allow-Origin", resolveAllowOrigin(exchange));
630+
exchange.getResponse().getHeaders().set("Access-Control-Allow-Headers", resolveAllowHeaders(exchange));
631+
exchange.getResponse().getHeaders().set("Access-Control-Allow-Methods", allowMethods);
632+
mergeVaryHeaders(exchange);
633+
}
634+
635+
private String resolveAllowMethods(final String uri) {
636+
return isStreamableHttpProtocol(uri) ? CORS_STREAMABLE_ALLOW_METHODS : CORS_ALLOW_METHODS;
637+
}
638+
639+
private String resolveAllowOrigin(final ServerWebExchange exchange) {
640+
final String origin = exchange.getRequest().getHeaders().getFirst("Origin");
641+
return Objects.nonNull(origin) && !origin.isBlank() ? origin : "*";
642+
}
643+
644+
private String resolveAllowHeaders(final ServerWebExchange exchange) {
645+
final Set<String> allowedHeaders = new LinkedHashSet<>();
646+
final String allowHeaders = Objects.nonNull(configuredCorsAllowHeaders) && !configuredCorsAllowHeaders.isBlank()
647+
? configuredCorsAllowHeaders : CORS_FALLBACK_ALLOW_HEADERS;
648+
for (String header : allowHeaders.split(",")) {
649+
final String trimmed = header.trim();
650+
if (!trimmed.isEmpty()) {
651+
allowedHeaders.add(trimmed);
652+
}
653+
}
654+
return String.join(", ", allowedHeaders);
655+
}
656+
657+
private void mergeVaryHeaders(final ServerWebExchange exchange) {
658+
final Set<String> varyValues = new LinkedHashSet<>();
659+
for (String varyHeader : exchange.getResponse().getHeaders().getOrEmpty(HttpHeaders.VARY)) {
660+
for (String varyValue : varyHeader.split(",")) {
661+
final String trimmed = varyValue.trim();
662+
if (!trimmed.isEmpty()) {
663+
varyValues.add(trimmed);
664+
}
665+
}
666+
}
667+
varyValues.add("Origin");
668+
varyValues.add("Access-Control-Request-Headers");
669+
exchange.getResponse().getHeaders().setVary(List.copyOf(varyValues));
587670
}
588671

589672
/**

0 commit comments

Comments
 (0)