diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/LifecycleInitializer.java b/mcp/src/main/java/io/modelcontextprotocol/client/LifecycleInitializer.java index 2cc1c5dba..2fc669c15 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/LifecycleInitializer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/LifecycleInitializer.java @@ -289,6 +289,7 @@ public Mono withIntitialization(String actionName, Function this.initializationRef.get()) .timeout(this.initializationTimeout) .onErrorResume(ex -> { + this.initializationRef.compareAndSet(newInit, null); return Mono.error(new RuntimeException("Client failed to initialize " + actionName, ex)); }) .flatMap(operation); diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/LifecycleInitializerTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/LifecycleInitializerTests.java index 02021edbf..19de14c24 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/LifecycleInitializerTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/LifecycleInitializerTests.java @@ -230,7 +230,10 @@ void shouldHandleConcurrentInitializationRequests() { @Test void shouldHandleInitializationFailure() { when(mockClientSession.sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(), any())) - .thenReturn(Mono.error(new RuntimeException("Connection failed"))); + // fail once + .thenReturn(Mono.error(new RuntimeException("Connection failed"))) + // succeeds on the second call + .thenReturn(Mono.just(MOCK_INIT_RESULT)); StepVerifier.create(initializer.withIntitialization("test", init -> Mono.just(init.initializeResult()))) .expectError(RuntimeException.class) @@ -238,6 +241,15 @@ void shouldHandleInitializationFailure() { assertThat(initializer.isInitialized()).isFalse(); assertThat(initializer.currentInitializationResult()).isNull(); + + // The initializer can recover from previous errors + StepVerifier + .create(initializer.withIntitialization("successful init", init -> Mono.just(init.initializeResult()))) + .expectNext(MOCK_INIT_RESULT) + .verifyComplete(); + + assertThat(initializer.isInitialized()).isTrue(); + assertThat(initializer.currentInitializationResult()).isEqualTo(MOCK_INIT_RESULT); } @Test