-
Notifications
You must be signed in to change notification settings - Fork 704
feat(ws): adds ws transport client #139
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,221 @@ | ||
/* | ||
* Copyright 2024 - 2024 the original author or authors. | ||
*/ | ||
|
||
package io.modelcontextprotocol.client.transport; | ||
|
||
import java.net.URI; | ||
import java.net.http.HttpClient; | ||
import java.net.http.WebSocket; | ||
import java.time.Duration; | ||
import java.util.concurrent.CompletableFuture; | ||
import java.util.concurrent.CompletionStage; | ||
import java.util.concurrent.atomic.AtomicReference; | ||
import java.util.function.Consumer; | ||
import java.util.function.Function; | ||
|
||
import org.slf4j.Logger; | ||
import org.slf4j.LoggerFactory; | ||
|
||
import com.fasterxml.jackson.core.type.TypeReference; | ||
import com.fasterxml.jackson.databind.ObjectMapper; | ||
|
||
import io.modelcontextprotocol.spec.McpClientTransport; | ||
import io.modelcontextprotocol.spec.McpSchema; | ||
import io.modelcontextprotocol.util.Assert; | ||
import reactor.core.publisher.Mono; | ||
import reactor.core.publisher.Sinks; | ||
import reactor.util.retry.Retry; | ||
|
||
/** | ||
* The WebSocket (WS) implementation of the | ||
* {@link io.modelcontextprotocol.spec.McpTransport} that follows the MCP HTTP with WS | ||
* transport specification, using Java's HttpClient. | ||
* | ||
* @author Aliaksei Darafeyeu | ||
*/ | ||
public class WebSocketClientTransport implements McpClientTransport { | ||
|
||
private static final Logger LOGGER = LoggerFactory.getLogger(WebSocketClientTransport.class); | ||
|
||
private final HttpClient httpClient; | ||
|
||
private final ObjectMapper objectMapper; | ||
|
||
private final URI uri; | ||
|
||
private final AtomicReference<WebSocket> webSocketRef = new AtomicReference<>(); | ||
|
||
private final AtomicReference<TransportState> state = new AtomicReference<>(TransportState.DISCONNECTED); | ||
|
||
private final Sinks.Many<Throwable> errorSink = Sinks.many().multicast().onBackpressureBuffer(); | ||
|
||
/** | ||
* The constructor for the WebSocketClientTransport. | ||
* @param uri the URI to connect to | ||
* @param clientBuilder the HttpClient builder | ||
* @param objectMapper the ObjectMapper for JSON serialization/deserialization | ||
*/ | ||
WebSocketClientTransport(final URI uri, final HttpClient.Builder clientBuilder, final ObjectMapper objectMapper) { | ||
this.uri = uri; | ||
this.httpClient = clientBuilder.build(); | ||
this.objectMapper = objectMapper; | ||
} | ||
|
||
/** | ||
* Creates a new WebSocketClientTransport instance with the specified URI. | ||
* @param uri the URI to connect to | ||
* @return a new Builder instance | ||
*/ | ||
public static Builder builder(final URI uri) { | ||
return new Builder().uri(uri); | ||
} | ||
|
||
/** | ||
* The state of the Transport connection. | ||
*/ | ||
public enum TransportState { | ||
|
||
DISCONNECTED, CONNECTING, CONNECTED, CLOSED | ||
|
||
} | ||
|
||
/** | ||
* A builder for creating instances of WebSocketClientTransport. | ||
*/ | ||
public static class Builder { | ||
|
||
private URI uri; | ||
|
||
private final HttpClient.Builder clientBuilder = HttpClient.newBuilder() | ||
.version(HttpClient.Version.HTTP_1_1) | ||
.connectTimeout(Duration.ofSeconds(10)); | ||
|
||
private ObjectMapper objectMapper = new ObjectMapper(); | ||
|
||
public Builder uri(final URI uri) { | ||
this.uri = uri; | ||
return this; | ||
} | ||
|
||
public Builder customizeClient(final Consumer<HttpClient.Builder> clientCustomizer) { | ||
Assert.notNull(clientCustomizer, "clientCustomizer must not be null"); | ||
clientCustomizer.accept(clientBuilder); | ||
return this; | ||
} | ||
|
||
public Builder objectMapper(final ObjectMapper objectMapper) { | ||
this.objectMapper = objectMapper; | ||
return this; | ||
} | ||
|
||
public WebSocketClientTransport build() { | ||
return new WebSocketClientTransport(uri, clientBuilder, objectMapper); | ||
} | ||
|
||
} | ||
|
||
public Mono<Void> connect(final Function<Mono<McpSchema.JSONRPCMessage>, Mono<McpSchema.JSONRPCMessage>> handler) { | ||
if (!state.compareAndSet(TransportState.DISCONNECTED, TransportState.CONNECTING)) { | ||
return Mono.error(new IllegalStateException("WebSocket is already connecting or connected")); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It might be better to point the actual state in the exception. I always hate error messages like "its either this or that", instead tell me what it was. |
||
} | ||
|
||
return Mono.fromFuture(httpClient.newWebSocketBuilder().buildAsync(uri, new WebSocket.Listener() { | ||
private final StringBuilder messageBuffer = new StringBuilder(); | ||
|
||
@Override | ||
public void onOpen(WebSocket webSocket) { | ||
webSocketRef.set(webSocket); | ||
state.set(TransportState.CONNECTED); | ||
} | ||
|
||
@Override | ||
public CompletionStage<?> onText(WebSocket webSocket, CharSequence data, boolean last) { | ||
messageBuffer.append(data); | ||
if (last) { | ||
final String fullMessage = messageBuffer.toString(); | ||
messageBuffer.setLength(0); | ||
try { | ||
final McpSchema.JSONRPCMessage msg = McpSchema.deserializeJsonRpcMessage(objectMapper, | ||
fullMessage); | ||
handler.apply(Mono.just(msg)).subscribe(); | ||
} | ||
catch (Exception e) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I fixed this in #156. It would be nice to merge that and then we can stop catch Exception use a more fine grained type. |
||
errorSink.tryEmitNext(e); | ||
LOGGER.error("Error processing WS event", e); | ||
} | ||
} | ||
|
||
webSocket.request(1); | ||
return CompletableFuture.completedFuture(null); | ||
} | ||
|
||
@Override | ||
public void onError(WebSocket webSocket, Throwable error) { | ||
errorSink.tryEmitNext(error); | ||
state.set(TransportState.CLOSED); | ||
LOGGER.error("WS connection error", error); | ||
} | ||
|
||
@Override | ||
public CompletionStage<?> onClose(WebSocket webSocket, int statusCode, String reason) { | ||
state.set(TransportState.CLOSED); | ||
return CompletableFuture.completedFuture(null); | ||
} | ||
|
||
})).then(); | ||
} | ||
|
||
@Override | ||
public Mono<Void> sendMessage(McpSchema.JSONRPCMessage message) { | ||
|
||
return Mono.defer(() -> { | ||
WebSocket ws = webSocketRef.get(); | ||
if (ws == null && state.get() == TransportState.CONNECTING) { | ||
return Mono.error(new IllegalStateException("WebSocket is connecting.")); | ||
} | ||
|
||
if (ws == null || state.get() == TransportState.DISCONNECTED || state.get() == TransportState.CLOSED) { | ||
return Mono.error(new IllegalStateException("WebSocket is closed.")); | ||
} | ||
|
||
try { | ||
String json = objectMapper.writeValueAsString(message); | ||
return Mono.fromFuture(ws.sendText(json, true)).then(); | ||
} | ||
catch (Exception e) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. UncheckedIOException seems nice here. |
||
return Mono.error(e); | ||
} | ||
}).retryWhen(Retry.backoff(3, Duration.ofSeconds(3)).filter(err -> { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Externalize or at least put in constants. |
||
if (err instanceof IllegalStateException) { | ||
return err.getMessage().equals("WebSocket is connecting."); | ||
} | ||
return true; | ||
})).onErrorResume(e -> { | ||
LOGGER.error("Failed to send message after retries", e); | ||
errorSink.tryEmitNext(e); | ||
return Mono.error(new IllegalStateException("WebSocket send failed after retries", e)); | ||
}); | ||
|
||
} | ||
|
||
@Override | ||
public Mono<Void> closeGracefully() { | ||
WebSocket webSocket = webSocketRef.getAndSet(null); | ||
if (webSocket != null && state.get() == TransportState.CONNECTED) { | ||
state.set(TransportState.CLOSED); | ||
return Mono.fromFuture(webSocket.sendClose(WebSocket.NORMAL_CLOSURE, "Closing")).then(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. A constant here would be better possibly. |
||
} | ||
return Mono.empty(); | ||
} | ||
|
||
@Override | ||
public <T> T unmarshalFrom(Object data, TypeReference<T> typeRef) { | ||
return objectMapper.convertValue(data, typeRef); | ||
} | ||
|
||
public TransportState getState() { | ||
return state.get(); | ||
} | ||
|
||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
/* | ||
* Copyright 2024-2024 the original author or authors. | ||
*/ | ||
|
||
package io.modelcontextprotocol.client.transport; | ||
|
||
import static org.junit.jupiter.api.Assertions.assertEquals; | ||
|
||
import java.net.URI; | ||
import java.util.List; | ||
|
||
import org.junit.jupiter.api.AfterAll; | ||
import org.junit.jupiter.api.BeforeAll; | ||
import org.junit.jupiter.api.BeforeEach; | ||
import org.junit.jupiter.api.Test; | ||
import org.testcontainers.containers.GenericContainer; | ||
import org.testcontainers.images.builder.ImageFromDockerfile; | ||
|
||
import io.modelcontextprotocol.spec.McpSchema; | ||
import reactor.core.publisher.Mono; | ||
import reactor.test.StepVerifier; | ||
|
||
/** | ||
* Tests for the {@link WebSocketClientTransport} class. | ||
* | ||
* @author Aliaksei Darafeyeu | ||
*/ | ||
class WebSocketClientTransportTest { | ||
|
||
private static GenericContainer<?> wsContainer; | ||
|
||
private static URI websocketUri; | ||
|
||
private WebSocketClientTransport transport; | ||
|
||
@BeforeAll | ||
static void startContainer() { | ||
wsContainer = new GenericContainer<>( | ||
new ImageFromDockerfile().withFileFromClasspath("server.js", "ws/server.js") | ||
.withFileFromClasspath("Dockerfile", "ws/Dockerfile")) | ||
.withExposedPorts(8080); | ||
|
||
wsContainer.start(); | ||
|
||
int port = wsContainer.getMappedPort(8080); | ||
websocketUri = URI.create("ws://localhost:" + port); | ||
} | ||
|
||
@BeforeEach | ||
public void setUp() { | ||
transport = WebSocketClientTransport.builder(websocketUri).build(); | ||
} | ||
|
||
@AfterAll | ||
static void tearDown() { | ||
wsContainer.stop(); | ||
} | ||
|
||
@Test | ||
void testConnectSuccessfully() { | ||
// Try to connect to the WebSocket server | ||
Mono<Void> connection = transport.connect(message -> Mono.empty()); | ||
|
||
// Wait for the connection to complete | ||
StepVerifier.create(connection).expectComplete().verify(); | ||
|
||
// Ensure that connection is established | ||
assertEquals(WebSocketClientTransport.TransportState.CONNECTED, transport.getState()); | ||
} | ||
|
||
@Test | ||
void testSendMessage() { | ||
// Connect to the server | ||
Mono<Void> connection = transport.connect(message -> Mono.empty()); | ||
|
||
// Ensure connection is successful | ||
StepVerifier.create(connection).expectComplete().verify(); | ||
|
||
// Create a simple message to send | ||
var messageRequest = new McpSchema.CreateMessageRequest( | ||
List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, new McpSchema.TextContent("Test message"))), | ||
null, null, null, null, 0, null, null); | ||
McpSchema.JSONRPCMessage message = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, | ||
McpSchema.METHOD_SAMPLING_CREATE_MESSAGE, "test-id", messageRequest); | ||
|
||
// Send a message to the server | ||
Mono<Void> sendMessage = transport.sendMessage(message); | ||
|
||
// Ensure message is sent successfully | ||
StepVerifier.create(sendMessage).expectComplete().verify(); | ||
} | ||
|
||
@Test | ||
void testCloseConnectionGracefully() { | ||
Mono<Void> connection = transport.connect(message -> Mono.empty()); | ||
|
||
StepVerifier.create(connection).expectComplete().verify(); | ||
|
||
// Close the connection gracefully | ||
Mono<Void> closeConnection = transport.closeGracefully(); | ||
|
||
// Verify that the connection is closed successfully | ||
StepVerifier.create(closeConnection).expectComplete().verify(); | ||
|
||
assertEquals(WebSocketClientTransport.TransportState.CLOSED, transport.getState()); | ||
} | ||
|
||
@Test | ||
void testSendMessageAfterConnectionClosed() { | ||
// Send a message before connection is established | ||
// Create a simple message to send | ||
var messageRequest = new McpSchema.CreateMessageRequest( | ||
List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, new McpSchema.TextContent("Test message"))), | ||
null, null, null, null, 0, null, null); | ||
McpSchema.JSONRPCMessage message = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, | ||
McpSchema.METHOD_SAMPLING_CREATE_MESSAGE, "test-id", messageRequest); | ||
|
||
Mono<Void> sendMessageBeforeConnect = transport.sendMessage(message); | ||
|
||
// Verify that the transport returns an error because the connection is closed | ||
StepVerifier.create(sendMessageBeforeConnect).expectError(IllegalStateException.class).verify(); | ||
} | ||
|
||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
# Use a Node.js base image | ||
FROM node:14 | ||
|
||
# Set the working directory inside the container | ||
WORKDIR /usr/src/app | ||
|
||
# Copy the server.js file into the container | ||
COPY server.js /usr/src/app/ | ||
|
||
# Install dependencies (e.g., the ws package) | ||
RUN npm init -y && npm install ws | ||
|
||
# Expose the port for WebSocket (e.g., 8080) | ||
EXPOSE 8080 | ||
|
||
# Command to run the WebSocket server | ||
CMD ["node", "server.js"] |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
// Import the WebSocket package | ||
const WebSocket = require('ws'); | ||
|
||
// Set up the WebSocket server to listen on port 8080 | ||
const wss = new WebSocket.Server({ port: 8080 }); | ||
|
||
// When a new WebSocket connection is established | ||
wss.on('connection', function connection(ws) { | ||
console.log('New client connected'); | ||
|
||
// When a message is received from the client | ||
ws.on('message', function incoming(message) { | ||
console.log('received: %s', message); | ||
}); | ||
|
||
// Send a welcome message to the client | ||
ws.send('Welcome to the WebSocket server!'); | ||
}); | ||
|
||
// Log the WebSocket server start | ||
console.log('WebSocket server is listening on ws://localhost:8080'); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Lots of object mappers defined in lots of files. It would seem like it might be time to build a Jackson module and lock in on one that can be used across the code base.