diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java index 8f0433eb1..40d2a7371 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java @@ -387,7 +387,12 @@ public Mono initialize() { return withSession("by explicit API call", init -> Mono.just(init.get())); } - private Mono doInitialize(McpClientSession mcpClientSession) { + private Mono doInitialize(Initialization initializaiton) { + + initializaiton.setMcpClientSession(this.sessionSupplier.get()); + + McpClientSession mcpClientSession = initializaiton.mcpSession(); + String latestVersion = this.protocolVersions.get(this.protocolVersions.size() - 1); McpSchema.InitializeRequest initializeRequest = new McpSchema.InitializeRequest(// @formatter:off @@ -410,6 +415,9 @@ private Mono doInitialize(McpClientSession mcpClient return mcpClientSession.sendNotification(McpSchema.METHOD_NOTIFICATION_INITIALIZED, null) .thenReturn(initializeResult); + }).doOnNext(initializaiton::complete).onErrorResume(ex -> { + initializaiton.error(ex); + return Mono.error(ex); }); } @@ -477,15 +485,9 @@ private Mono withSession(String actionName, Function initializationJob = needsToInitialize - ? doInitialize(newInit.mcpSession()).doOnNext(newInit::complete).onErrorResume(ex -> { - newInit.error(ex); - return Mono.error(ex); - }) : previous.await(); + Mono initializationJob = needsToInitialize ? doInitialize(newInit) + : previous.await(); return initializationJob.map(initializeResult -> this.initializationRef.get()) .timeout(this.initializationTimeout) diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java index c83992407..59611eacd 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java @@ -229,27 +229,27 @@ private String generateRequestId() { public Mono sendRequest(String method, Object requestParams, TypeReference typeRef) { String requestId = this.generateRequestId(); - return Mono.deferContextual(ctx -> Mono.create(sink -> { + return Mono.deferContextual(ctx -> Mono.create(responseSink -> { logger.debug("Sending message for method {}", method); - this.pendingResponses.put(requestId, sink); + this.pendingResponses.put(requestId, responseSink); McpSchema.JSONRPCRequest jsonrpcRequest = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, method, requestId, requestParams); this.transport.sendMessage(jsonrpcRequest).contextWrite(ctx).subscribe(v -> { }, error -> { this.pendingResponses.remove(requestId); - sink.error(error); + responseSink.error(error); }); - })).timeout(this.requestTimeout).handle((jsonRpcResponse, sink) -> { + })).timeout(this.requestTimeout).handle((jsonRpcResponse, resultSink) -> { if (jsonRpcResponse.error() != null) { logger.error("Error handling request: {}", jsonRpcResponse.error()); - sink.error(new McpError(jsonRpcResponse.error())); + resultSink.error(new McpError(jsonRpcResponse.error())); } else { if (typeRef.getType().equals(Void.class)) { - sink.complete(); + resultSink.complete(); } else { - sink.next(this.transport.unmarshalFrom(jsonRpcResponse.result(), typeRef)); + resultSink.next(this.transport.unmarshalFrom(jsonRpcResponse.result(), typeRef)); } } });