From cdd48df449f9f3f1c6c30c4a585fef2a0e309f2f Mon Sep 17 00:00:00 2001 From: Scott Lewis Date: Fri, 1 Aug 2025 15:27:16 -0700 Subject: [PATCH 01/25] Initial checkin to address feature request https://github.com/modelcontextprotocol/java-sdk/issues/415 i.mcp.client/server.transport packages contains new transport providers classes : UDSClient/ServerTransportProvider. The name UDS refers to UnixDomainSocket as that's the SocketChannel type being used. These transport providers use the new classses in util: UDSClientNonBlockingSocketChannel and UDSServerNonBlockingSocketChannel. These further depend upon super classes ClientNonBlockingSocketChannel and ServerNonBlockingSocketChannel which both depend upon superclass NonBlockSocketChannel, which has most of the actual implementation of the single-threaded/Selector based non-blocking read and write. This subclass/superclass structure means that Inet4 and Inet6 client/server SocketChannel classes are also present. These will work just the same as the UDSClient/ServerSocketChannel classses but rather will use Inet4 and Inet6 connections rather than UnixDomainSockets. It will be very easy to create server/client transport providers that use inet4 or inet6 tcp stacks for localhost or non localhost connections. But for the moment, I've only created UDSServer/ClientTransportProviders for testing and review. Signed-off-by: Scott Lewis --- .../transport/UDSClientTransportProvider.java | 183 +++++++++ .../transport/UDSServerTransportProvider.java | 239 +++++++++++ .../util/ClientNonBlockingSocketChannel.java | 74 ++++ .../Inet4ClientNonBlockingSocketChannel.java | 36 ++ .../Inet4ServerNonBlockingSocketChannel.java | 34 ++ .../Inet6ClientNonBlockingSocketChannel.java | 36 ++ .../Inet6ServerNonBlockingSocketChannel.java | 34 ++ .../util/NonBlockingSocketChannel.java | 387 ++++++++++++++++++ .../util/ServerNonBlockingSocketChannel.java | 96 +++++ .../UDSClientNonBlockingSocketChannel.java | 34 ++ .../UDSServerNonBlockingSocketChannel.java | 33 ++ 11 files changed, 1186 insertions(+) create mode 100644 mcp/src/main/java/io/modelcontextprotocol/client/transport/UDSClientTransportProvider.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/server/transport/UDSServerTransportProvider.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/util/ClientNonBlockingSocketChannel.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/util/Inet4ClientNonBlockingSocketChannel.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/util/Inet4ServerNonBlockingSocketChannel.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/util/Inet6ClientNonBlockingSocketChannel.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/util/Inet6ServerNonBlockingSocketChannel.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/util/NonBlockingSocketChannel.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/util/ServerNonBlockingSocketChannel.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/util/UDSClientNonBlockingSocketChannel.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/util/UDSServerNonBlockingSocketChannel.java diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/UDSClientTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/UDSClientTransportProvider.java new file mode 100644 index 000000000..bc364dea8 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/UDSClientTransportProvider.java @@ -0,0 +1,183 @@ +package io.modelcontextprotocol.client.transport; + +import java.io.IOException; +import java.net.UnixDomainSocketAddress; +import java.time.Duration; +import java.util.concurrent.Executors; +import java.util.function.Function; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; + +import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage; +import io.modelcontextprotocol.util.Assert; +import io.modelcontextprotocol.util.UDSClientNonBlockingSocketChannel; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; +import reactor.core.scheduler.Scheduler; +import reactor.core.scheduler.Schedulers; + +public class UDSClientTransportProvider implements McpClientTransport { + + private static final Logger logger = LoggerFactory.getLogger(UDSClientTransportProvider.class); + + private final Sinks.Many inboundSink; + + private final Sinks.Many outboundSink; + + private ObjectMapper objectMapper; + + private UDSClientNonBlockingSocketChannel clientChannel; + + private UnixDomainSocketAddress targetAddress; + + private Scheduler outboundScheduler; + + private volatile boolean isClosing = false; + + public UDSClientTransportProvider(ObjectMapper objectMapper, UnixDomainSocketAddress targetAddress) + throws IOException { + Assert.notNull(objectMapper, "The ObjectMapper can not be null"); + + this.inboundSink = Sinks.many().unicast().onBackpressureBuffer(); + this.outboundSink = Sinks.many().unicast().onBackpressureBuffer(); + + this.objectMapper = objectMapper; + + // Start threads + this.outboundScheduler = Schedulers.fromExecutorService(Executors.newSingleThreadExecutor(), "outbound"); + this.clientChannel = new UDSClientNonBlockingSocketChannel(); + this.targetAddress = targetAddress; + } + + @Override + public Mono connect(Function, Mono> handler) { + return Mono.fromRunnable(() -> { + handleIncomingMessages(handler); + try { + this.clientChannel.connectBlocking(targetAddress, (client) -> { + logger.info("CONNECTED to targetAddress=" + targetAddress); + }, (data) -> { + JSONRPCMessage json = McpSchema.deserializeJsonRpcMessage(this.objectMapper, data); + if (!this.inboundSink.tryEmitNext(json).isSuccess()) { + if (!isClosing) { + logger.error("Failed to enqueue inbound message: {}", json); + } + } + }); + } + catch (IOException e) { + this.clientChannel.close(); + throw new RuntimeException( + "Connect to address=" + targetAddress + " failed message: " + e.getMessage()); + } + startOutboundProcessing(); + }).subscribeOn(Schedulers.boundedElastic()); + } + + private void handleIncomingMessages(Function, Mono> inboundMessageHandler) { + this.inboundSink.asFlux() + .flatMap(message -> Mono.just(message) + .transform(inboundMessageHandler) + .contextWrite(ctx -> ctx.put("observation", "myObservation"))) + .subscribe(); + } + + @Override + public Mono sendMessage(JSONRPCMessage message) { + if (this.outboundSink.tryEmitNext(message).isSuccess()) { + // TODO: essentially we could reschedule ourselves in some time and make + // another attempt with the already read data but pause reading until + // success + // In this approach we delegate the retry and the backpressure onto the + // caller. This might be enough for most cases. + return Mono.empty(); + } + else { + return Mono.error(new RuntimeException("Failed to enqueue message")); + } + } + + private void startOutboundProcessing() { + this.handleOutbound(messages -> messages + // this bit is important since writes come from user threads, and we + // want to ensure that the actual writing happens on a dedicated thread + .publishOn(outboundScheduler) + .handle((message, s) -> { + if (message != null && !isClosing) { + try { + this.clientChannel.writeMessageBlocking(objectMapper.writeValueAsString(message)); + s.next(message); + } + catch (IOException e) { + s.error(new RuntimeException(e)); + } + } + })); + } + + protected void handleOutbound(Function, Flux> outboundConsumer) { + outboundConsumer.apply(outboundSink.asFlux()).doOnComplete(() -> { + isClosing = true; + outboundSink.tryEmitComplete(); + }).doOnError(e -> { + if (!isClosing) { + logger.error("Error in outbound processing", e); + isClosing = true; + outboundSink.tryEmitComplete(); + } + }).subscribe(); + } + + /** + * Gracefully closes the transport by destroying the process and disposing of the + * schedulers. This method sends a TERM signal to the process and waits for it to exit + * before cleaning up resources. + * @return A Mono that completes when the transport is closed + */ + @Override + public Mono closeGracefully() { + return Mono.fromRunnable(() -> { + isClosing = true; + logger.debug("Initiating graceful shutdown"); + }).then(Mono.defer(() -> { + // First complete all sinks to stop accepting new messages + inboundSink.tryEmitComplete(); + outboundSink.tryEmitComplete(); + // Give a short time for any pending messages to be processed + return Mono.delay(Duration.ofMillis(100)).then(); + })).then(Mono.defer(() -> { + // Close our clientChannel + if (this.clientChannel != null) { + this.clientChannel.close(); + this.clientChannel = null; + } + return Mono.empty(); + })).doOnNext(o -> { + logger.info("MCP server process stopped"); + }).then(Mono.fromRunnable(() -> { + try { + // The Threads are blocked on readLine so disposeGracefully would not + // interrupt them, therefore we issue an async hard dispose. + outboundScheduler.dispose(); + + logger.debug("Graceful shutdown completed"); + } + catch (Exception e) { + logger.error("Error during graceful shutdown", e); + } + })).then().subscribeOn(Schedulers.boundedElastic()); + } + + @Override + public T unmarshalFrom(Object data, TypeReference typeRef) { + return this.objectMapper.convertValue(data, typeRef); + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/UDSServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/UDSServerTransportProvider.java new file mode 100644 index 000000000..d6e76a56a --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/UDSServerTransportProvider.java @@ -0,0 +1,239 @@ +package io.modelcontextprotocol.server.transport; + +import java.io.IOException; +import java.net.UnixDomainSocketAddress; +import java.util.concurrent.Executors; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Function; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; + +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage; +import io.modelcontextprotocol.spec.McpServerSession; +import io.modelcontextprotocol.spec.McpServerTransport; +import io.modelcontextprotocol.spec.McpServerTransportProvider; +import io.modelcontextprotocol.util.Assert; +import io.modelcontextprotocol.util.UDSServerNonBlockingSocketChannel; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; +import reactor.core.scheduler.Scheduler; +import reactor.core.scheduler.Schedulers; + +public class UDSServerTransportProvider implements McpServerTransportProvider { + + private static final Logger logger = LoggerFactory.getLogger(UDSServerTransportProvider.class); + + private final ObjectMapper objectMapper; + + private McpServerSession session; + + private final AtomicBoolean isClosing = new AtomicBoolean(false); + + private final Sinks.One inboundReady = Sinks.one(); + + private UDSServerNonBlockingSocketChannel serverSocketChannel; + + private UnixDomainSocketAddress address; + + private UDSMcpSessionTransport transport; + + public UDSServerTransportProvider(ObjectMapper objectMapper, UnixDomainSocketAddress unixSocketAddress) + throws IOException { + Assert.notNull(objectMapper, "The ObjectMapper can not be null"); + this.objectMapper = objectMapper; + this.address = unixSocketAddress; + this.serverSocketChannel = new UDSServerNonBlockingSocketChannel(); + } + + @Override + public void setSessionFactory(McpServerSession.Factory sessionFactory) { + try { + this.serverSocketChannel.start(this.address, (clientChannel) -> { + this.transport = new UDSMcpSessionTransport(); + this.session = sessionFactory.create(transport); + this.transport.initProcessing(); + }, (dataLine) -> { + String message = (String) dataLine; + try { + this.transport + .handleMessage(McpSchema.deserializeJsonRpcMessage(this.objectMapper, message.trim())); + } + catch (IOException e) { + this.serverSocketChannel.close(); + } + }); + } + catch (IOException e) { + this.serverSocketChannel.close(); + throw new RuntimeException("accepterNonBlockSocketChannel could not be started"); + } + } + + @Override + public Mono notifyClients(String method, Object params) { + if (this.session == null) { + return Mono.error(new McpError("No session to close")); + } + return this.session.sendNotification(method, params) + .doOnError(e -> logger.error("Failed to send notification: {}", e.getMessage())); + } + + @Override + public Mono closeGracefully() { + if (this.session == null) { + return Mono.empty(); + } + return this.session.closeGracefully(); + } + + /** + * Implementation of McpServerTransport for the stdio session. + */ + private class UDSMcpSessionTransport implements McpServerTransport { + + private final Sinks.Many inboundSink; + + private final Sinks.Many outboundSink; + + private final AtomicBoolean isStarted = new AtomicBoolean(false); + + /** Scheduler for handling outbound messages */ + private Scheduler outboundScheduler; + + private final Sinks.One outboundReady = Sinks.one(); + + public UDSMcpSessionTransport() { + + this.inboundSink = Sinks.many().unicast().onBackpressureBuffer(); + this.outboundSink = Sinks.many().unicast().onBackpressureBuffer(); + + // Use bounded schedulers for better resource management + // this.inboundScheduler = + // Schedulers.fromExecutorService(Executors.newSingleThreadExecutor(), + // "uds-inbound"); + this.outboundScheduler = Schedulers.fromExecutorService(Executors.newSingleThreadExecutor(), + "uds-outbound"); + } + + public void handleMessage(McpSchema.JSONRPCMessage json) throws IOException { + try { + if (!this.inboundSink.tryEmitNext(json).isSuccess()) { + throw new Exception("Failed to enqueue message"); + } + } + catch (Exception e) { + logIfNotClosing("Error processing inbound message", e); + throw new IOException("Error in processing inbound message", e); + } + } + + @Override + public Mono sendMessage(McpSchema.JSONRPCMessage message) { + + return Mono.zip(inboundReady.asMono(), outboundReady.asMono()).then(Mono.defer(() -> { + if (outboundSink.tryEmitNext(message).isSuccess()) { + return Mono.empty(); + } + else { + return Mono.error(new RuntimeException("Failed to enqueue message")); + } + })); + } + + @Override + public T unmarshalFrom(Object data, TypeReference typeRef) { + return objectMapper.convertValue(data, typeRef); + } + + @Override + public Mono closeGracefully() { + return Mono.fromRunnable(() -> { + isClosing.set(true); + logger.debug("Session transport closing gracefully"); + inboundSink.tryEmitComplete(); + }); + } + + @Override + public void close() { + isClosing.set(true); + serverSocketChannel.close(); + logger.debug("Session transport closed"); + } + + private void initProcessing() { + handleIncomingMessages(); + if (isStarted.compareAndSet(false, true)) { + inboundReady.tryEmitValue(null); + } + startOutboundProcessing(); + } + + private void handleIncomingMessages() { + this.inboundSink.asFlux().flatMap(message -> session.handle(message)).doOnTerminate(() -> { + // The outbound processing will dispose its scheduler upon completion + this.outboundSink.tryEmitComplete(); + // this.inboundScheduler.dispose(); + }).subscribe(); + } + + /** + * Starts the outbound processing thread that writes JSON-RPC messages to stdout. + * Messages are serialized to JSON and written with a newline delimiter. + */ + private void startOutboundProcessing() { + Function, Flux> outboundConsumer = messages -> messages // @formatter:off + .doOnSubscribe(subscription -> outboundReady.tryEmitValue(null)) + .publishOn(outboundScheduler) + .handle((message, sink) -> { + if (message != null && !isClosing.get()) { + try { + serverSocketChannel.writeMessageBlocking(objectMapper.writeValueAsString(message)); + sink.next(message); + } + catch (IOException e) { + if (!isClosing.get()) { + logger.error("Error writing message", e); + sink.error(new RuntimeException(e)); + } + else { + logger.debug("Stream closed during shutdown", e); + } + } + } + else if (isClosing.get()) { + sink.complete(); + } + }) + .doOnComplete(() -> { + isClosing.set(true); + outboundScheduler.dispose(); + }) + .doOnError(e -> { + if (!isClosing.get()) { + logger.error("Error in outbound processing", e); + isClosing.set(true); + outboundScheduler.dispose(); + } + }) + .map(msg -> (JSONRPCMessage) msg); + + outboundConsumer.apply(outboundSink.asFlux()).subscribe(); + } // @formatter:on + + private void logIfNotClosing(String message, Exception e) { + if (!isClosing.get()) { + logger.error(message, e); + } + } + + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/ClientNonBlockingSocketChannel.java b/mcp/src/main/java/io/modelcontextprotocol/util/ClientNonBlockingSocketChannel.java new file mode 100644 index 000000000..65e976f53 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/util/ClientNonBlockingSocketChannel.java @@ -0,0 +1,74 @@ +package io.modelcontextprotocol.util; + +import java.io.IOException; +import java.net.SocketAddress; +import java.net.StandardProtocolFamily; +import java.nio.channels.SelectionKey; +import java.nio.channels.Selector; +import java.nio.channels.SocketChannel; +import java.util.concurrent.ExecutorService; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class ClientNonBlockingSocketChannel extends NonBlockingSocketChannel { + + private static final Logger logger = LoggerFactory.getLogger(ClientNonBlockingSocketChannel.class); + + private SocketChannel client; + + public ClientNonBlockingSocketChannel(Selector selector, int incomingBufferSize, ExecutorService executor) + throws IOException { + super(selector, incomingBufferSize, executor); + } + + public ClientNonBlockingSocketChannel() throws IOException { + super(); + } + + public ClientNonBlockingSocketChannel(Selector selector, int incomingBufferSize) { + super(selector, incomingBufferSize); + } + + public ClientNonBlockingSocketChannel(Selector selector) { + super(selector); + } + + public void connectBlocking(StandardProtocolFamily protocol, SocketAddress address, + IOConsumer connectHandler, IOConsumer readHandler) throws IOException { + if (this.client != null) { + throw new IOException("Already connected"); + } + this.client = connectBlocking(SocketChannel.open(protocol), address, connectHandler, readHandler); + } + + @Override + protected void handleException(SelectionKey key, Exception e) { + if (logger.isDebugEnabled()) { + logger.debug("handleException", e); + } + close(); + } + + @Override + public void close() { + try { + hardCloseClient(this.client, (client) -> { + this.client = null; + }); + } + catch (IOException e) { + if (logger.isDebugEnabled()) { + logger.debug("Exception in hardCloseClient", e); + } + } + } + + public void writeMessageBlocking(String message) throws IOException { + if (this.client == null) { + throw new IOException("Cannot write until client connected"); + } + writeBlocking(client, message); + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/Inet4ClientNonBlockingSocketChannel.java b/mcp/src/main/java/io/modelcontextprotocol/util/Inet4ClientNonBlockingSocketChannel.java new file mode 100644 index 000000000..650fd52b4 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/util/Inet4ClientNonBlockingSocketChannel.java @@ -0,0 +1,36 @@ +package io.modelcontextprotocol.util; + +import java.io.IOException; +import java.net.Inet4Address; +import java.net.InetSocketAddress; +import java.net.StandardProtocolFamily; +import java.nio.channels.Selector; +import java.nio.channels.SocketChannel; +import java.util.concurrent.ExecutorService; + +public class Inet4ClientNonBlockingSocketChannel extends ClientNonBlockingSocketChannel { + + public Inet4ClientNonBlockingSocketChannel() throws IOException { + super(); + } + + public Inet4ClientNonBlockingSocketChannel(Selector selector, int incomingBufferSize, ExecutorService executor) + throws IOException { + super(selector, incomingBufferSize, executor); + } + + public Inet4ClientNonBlockingSocketChannel(Selector selector, int incomingBufferSize) throws IOException { + super(selector, incomingBufferSize); + } + + public Inet4ClientNonBlockingSocketChannel(Selector selector) throws IOException { + super(selector); + } + + public void connectBlocking(Inet4Address address, int port, IOConsumer connectHandler, + IOConsumer readHandler) throws IOException { + super.connectBlocking(StandardProtocolFamily.INET, new InetSocketAddress(address, port), connectHandler, + readHandler); + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/Inet4ServerNonBlockingSocketChannel.java b/mcp/src/main/java/io/modelcontextprotocol/util/Inet4ServerNonBlockingSocketChannel.java new file mode 100644 index 000000000..a4b9c61f8 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/util/Inet4ServerNonBlockingSocketChannel.java @@ -0,0 +1,34 @@ +package io.modelcontextprotocol.util; + +import java.io.IOException; +import java.net.Inet4Address; +import java.net.InetSocketAddress; +import java.net.StandardProtocolFamily; +import java.nio.channels.Selector; +import java.nio.channels.SocketChannel; +import java.util.concurrent.ExecutorService; + +public class Inet4ServerNonBlockingSocketChannel extends ServerNonBlockingSocketChannel { + + public Inet4ServerNonBlockingSocketChannel() throws IOException { + super(); + } + + public Inet4ServerNonBlockingSocketChannel(Selector selector, int incomingBufferSize, ExecutorService executor) { + super(selector, incomingBufferSize, executor); + } + + public Inet4ServerNonBlockingSocketChannel(Selector selector, int incomingBufferSize) { + super(selector, incomingBufferSize); + } + + public Inet4ServerNonBlockingSocketChannel(Selector selector) { + super(selector); + } + + public void start(Inet4Address address, int port, IOConsumer acceptHandler, + IOConsumer readHandler) throws IOException { + super.start(StandardProtocolFamily.INET, new InetSocketAddress(address, port), acceptHandler, readHandler); + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/Inet6ClientNonBlockingSocketChannel.java b/mcp/src/main/java/io/modelcontextprotocol/util/Inet6ClientNonBlockingSocketChannel.java new file mode 100644 index 000000000..da8739758 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/util/Inet6ClientNonBlockingSocketChannel.java @@ -0,0 +1,36 @@ +package io.modelcontextprotocol.util; + +import java.io.IOException; +import java.net.Inet6Address; +import java.net.InetSocketAddress; +import java.net.StandardProtocolFamily; +import java.nio.channels.Selector; +import java.nio.channels.SocketChannel; +import java.util.concurrent.ExecutorService; + +public class Inet6ClientNonBlockingSocketChannel extends ClientNonBlockingSocketChannel { + + public Inet6ClientNonBlockingSocketChannel() throws IOException { + super(); + } + + public Inet6ClientNonBlockingSocketChannel(Selector selector, int incomingBufferSize, ExecutorService executor) + throws IOException { + super(selector, incomingBufferSize, executor); + } + + public Inet6ClientNonBlockingSocketChannel(Selector selector, int incomingBufferSize) throws IOException { + super(selector, incomingBufferSize); + } + + public Inet6ClientNonBlockingSocketChannel(Selector selector) throws IOException { + super(selector); + } + + public void connectBlocking(Inet6Address address, int port, IOConsumer connectHandler, + IOConsumer readHandler) throws IOException { + super.connectBlocking(StandardProtocolFamily.INET6, new InetSocketAddress(address, port), connectHandler, + readHandler); + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/Inet6ServerNonBlockingSocketChannel.java b/mcp/src/main/java/io/modelcontextprotocol/util/Inet6ServerNonBlockingSocketChannel.java new file mode 100644 index 000000000..8a1a95e27 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/util/Inet6ServerNonBlockingSocketChannel.java @@ -0,0 +1,34 @@ +package io.modelcontextprotocol.util; + +import java.io.IOException; +import java.net.Inet6Address; +import java.net.InetSocketAddress; +import java.net.StandardProtocolFamily; +import java.nio.channels.Selector; +import java.nio.channels.SocketChannel; +import java.util.concurrent.ExecutorService; + +public class Inet6ServerNonBlockingSocketChannel extends ServerNonBlockingSocketChannel { + + public Inet6ServerNonBlockingSocketChannel() throws IOException { + super(); + } + + public Inet6ServerNonBlockingSocketChannel(Selector selector, int incomingBufferSize, ExecutorService executor) { + super(selector, incomingBufferSize, executor); + } + + public Inet6ServerNonBlockingSocketChannel(Selector selector, int incomingBufferSize) { + super(selector, incomingBufferSize); + } + + public Inet6ServerNonBlockingSocketChannel(Selector selector) { + super(selector); + } + + public void start(Inet6Address address, int port, IOConsumer acceptHandler, + IOConsumer readHandler) throws IOException { + super.start(StandardProtocolFamily.INET6, new InetSocketAddress(address, port), acceptHandler, readHandler); + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/NonBlockingSocketChannel.java b/mcp/src/main/java/io/modelcontextprotocol/util/NonBlockingSocketChannel.java new file mode 100644 index 000000000..b3b46e13c --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/util/NonBlockingSocketChannel.java @@ -0,0 +1,387 @@ +package io.modelcontextprotocol.util; + +import java.io.IOException; +import java.io.InterruptedIOException; +import java.net.SocketAddress; +import java.nio.ByteBuffer; +import java.nio.channels.SelectionKey; +import java.nio.channels.Selector; +import java.nio.channels.ServerSocketChannel; +import java.nio.channels.SocketChannel; +import java.nio.charset.StandardCharsets; +import java.util.Iterator; +import java.util.Objects; +import java.util.Set; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public abstract class NonBlockingSocketChannel { + + private static final Logger logger = LoggerFactory.getLogger(NonBlockingSocketChannel.class); + + public static final int DEFAULT_INBUFFER_SIZE = 1024; + + protected static String MESSAGE_DELIMITER = "\n"; + + protected static int BLOCKING_WRITE_TIMEOUT = 5000; + + protected static int BLOCKING_CONNECT_TIMEOUT = 10000; + + protected final Selector selector; + + protected final ByteBuffer inBuffer; + + protected final ExecutorService executor; + + @FunctionalInterface + public interface IOConsumer { + + void apply(T t) throws IOException; + + } + + protected class AttachedIO { + + public ByteBuffer writing; + + public StringBuffer reading; + + } + + public NonBlockingSocketChannel(Selector selector, int incomingBufferSize, ExecutorService executor) { + Assert.notNull(selector, "Selector must not be null"); + this.selector = selector; + this.inBuffer = ByteBuffer.allocate(incomingBufferSize); + this.executor = (executor == null) ? Executors.newSingleThreadExecutor() : executor; + } + + public NonBlockingSocketChannel(Selector selector, int incomingBufferSize) { + this(selector, incomingBufferSize, null); + } + + public NonBlockingSocketChannel(Selector selector) { + this(selector, DEFAULT_INBUFFER_SIZE); + } + + public NonBlockingSocketChannel() throws IOException { + this(Selector.open()); + } + + protected Runnable getRunnableForProcessing(IOConsumer acceptHandler, + IOConsumer connectHandler, IOConsumer readHandler) { + return () -> { + SelectionKey key = null; + try { + while (true) { + this.selector.select(); + Set selectedKeys = selector.selectedKeys(); + Iterator iter = selectedKeys.iterator(); + while (iter.hasNext()) { + key = iter.next(); + if (key.isConnectable()) { + handleConnectable(key, connectHandler); + } + else if (key.isAcceptable()) { + handleAcceptable(key, acceptHandler); + } + else if (key.isReadable()) { + handleReadable(key, readHandler); + } + else if (key.isWritable()) { + handleWritable(key); + } + iter.remove(); + } + } + } + catch (Exception e) { + handleException(key, e); + } + }; + } + + public abstract void close(); + + protected abstract void handleException(SelectionKey key, Exception e); + + protected void start(IOConsumer acceptHandler, IOConsumer connectHandler, + IOConsumer readHandler) throws IOException { + this.executor.execute(getRunnableForProcessing(acceptHandler, connectHandler, readHandler)); + } + + // For client subclasses + protected void handleConnectable(SelectionKey key, IOConsumer connectHandler) throws IOException { + SocketChannel client = (SocketChannel) key.channel(); + Object lock = client.blockingLock(); + if (logger.isDebugEnabled()) { + logger.debug("handleConnectable client=" + client.getRemoteAddress()); + } + synchronized (lock) { + client.configureBlocking(false); + client.register(this.selector, SelectionKey.OP_READ, new AttachedIO()); + if (client.isConnectionPending()) { + client.finishConnect(); + if (logger.isDebugEnabled()) { + logger.debug("handleConnectable FINISHED"); + } + } + if (connectHandler != null) { + connectHandler.apply(client); + } + } + } + + protected void handleAcceptable(SelectionKey key, IOConsumer acceptHandler) throws IOException { + ServerSocketChannel serverSocket = (ServerSocketChannel) key.channel(); + SocketChannel client = serverSocket.accept(); + Object lock = client.blockingLock(); + if (logger.isDebugEnabled()) { + logger.debug("handleAcceptable client=" + client); + } + synchronized (lock) { + client.configureBlocking(false); + client.register(this.selector, SelectionKey.OP_READ, new AttachedIO()); + configureAcceptSocketChannel(client); + if (client.isConnectionPending()) { + client.finishConnect(); + if (logger.isDebugEnabled()) { + logger.debug("handleAcceptable FINISHED"); + } + } + if (acceptHandler != null) { + acceptHandler.apply(client); + } + } + } + + protected void configureAcceptSocketChannel(SocketChannel client) throws IOException { + // Subclasses may override + } + + protected AttachedIO getAttachedIO(SelectionKey key) throws IOException { + AttachedIO io = (AttachedIO) key.attachment(); + if (io == null) { + throw new IOException("No AttachedIO object found on key"); + } + return io; + } + + protected void handleReadable(SelectionKey key, IOConsumer readHandler) throws IOException { + SocketChannel client = (SocketChannel) key.channel(); + Object lock = client.blockingLock(); + AttachedIO io = getAttachedIO(key); + if (logger.isDebugEnabled()) { + logger.debug("handleReadable client=" + client); + } + synchronized (lock) { + // non-blocking read here + int r = client.read(this.inBuffer); + // Check if we should expect any more reads + if (r == -1) { + throw new IOException("Channel read reached end of stream"); + } + this.inBuffer.flip(); + String partial = new String(this.inBuffer.array(), 0, r, StandardCharsets.UTF_8); + // If there is are previous partial, then get the io.reading string Buffer + StringBuffer sb = (io.reading != null) ? (StringBuffer) io.reading : new StringBuffer(); + // And append the just read partial to the string buffer + sb.append(partial); + if (partial.endsWith(MESSAGE_DELIMITER)) { + // Get the entire message from the string buffer + String message = sb.toString(); + // Set the io.reading value to null as we are done with this message + io.reading = null; + if (logger.isDebugEnabled()) { + logger.debug("handleReadable COMPLETE msg=" + message); + } + if (readHandler != null) { + readHandler.apply(message); + } + } + else { + io.reading = sb; + if (logger.isDebugEnabled()) { + logger.debug("handleReadable PARTIAL msg=" + partial); + } + } + } + // Clear inbuffer for next read + this.inBuffer.clear(); + } + + protected void handleWritable(SelectionKey key) throws IOException { + ByteBuffer buf = getAttachedIO(key).writing; + SocketChannel client = (SocketChannel) key.channel(); + if (buf != null) { + doWrite(key, client, buf, (lock) -> { + synchronized (lock) { + if (logger.isDebugEnabled()) { + logger.debug("handleWritable NOTIFY client=" + client); + } + lock.notify(); + } + }); + } + } + + protected void doWrite(SocketChannel client, String message, IOConsumer writeHandler) throws IOException { + Assert.notNull(client, "Client must not be null"); + Assert.notNull(message, "Message must not be null"); + if (logger.isDebugEnabled()) { + logger.debug("doWrite msg=" + message); + } + doWrite(client.keyFor(this.selector), client, ByteBuffer.wrap(message.getBytes(StandardCharsets.UTF_8)), + writeHandler); + } + + protected void doWrite(SelectionKey key, SocketChannel client, ByteBuffer buf, IOConsumer writeHandler) + throws IOException { + AttachedIO io = (AttachedIO) key.attachment(); + Object lock = client.blockingLock(); + synchronized (lock) { + int written = client.write(buf); + if (buf.hasRemaining()) { + if (logger.isDebugEnabled()) { + logger.debug("doWrite PARTIAL written=" + written + " remaining=" + buf.remaining()); + } + io.writing = buf.slice(); + key.interestOpsOr(SelectionKey.OP_WRITE); + } + else { + if (logger.isDebugEnabled()) { + logger.debug("doWrite COMPLETED msg=" + new String(buf.array(), 0, written)); + } + io.writing = null; + key.interestOps(SelectionKey.OP_READ); + if (writeHandler != null) { + writeHandler.apply(lock); + } + } + } + } + + protected void executorShutdown() { + if (!this.executor.isShutdown()) { + if (logger.isDebugEnabled()) { + logger.debug("executorShutdown"); + } + try { + this.executor.awaitTermination(2000, TimeUnit.MILLISECONDS); + this.executor.shutdown(); + } + catch (InterruptedException e) { + if (logger.isDebugEnabled()) { + logger.debug("Exception in executor awaitTermination", e); + } + } + } + } + + protected void hardCloseClient(SocketChannel client, IOConsumer closeHandler) throws IOException { + if (client != null) { + Object lock = client.blockingLock(); + if (logger.isDebugEnabled()) { + logger.debug("hardCloseClient client=" + client); + } + synchronized (lock) { + if (closeHandler != null) { + closeHandler.apply(client); + } + client.close(); + } + executorShutdown(); + } + } + + protected void writeBlocking(SocketChannel client, String message) throws IOException { + Objects.requireNonNull(client, "Client must not be null"); + Objects.requireNonNull(message, "Message must not be null"); + // Escape any embedded newlines in the JSON message, and add newline + String outputMessage = message.replace("\r\n", "\\n") + .replace("\n", "\\n") + .replace("\r", "\\n") + .concat(MESSAGE_DELIMITER); + Object lock = client.blockingLock(); + if (logger.isDebugEnabled()) { + logger.debug("writeBlocking msg=" + outputMessage); + } + synchronized (lock) { + // do the non blocking write in thread while holding lock. + doWrite(client, outputMessage, null); + ByteBuffer bufRemaining = null; + long waitTime = System.currentTimeMillis() + BLOCKING_WRITE_TIMEOUT; + while (waitTime - System.currentTimeMillis() > 0) { + // Before releasing lock, check for writing buffer remaining + bufRemaining = getAttachedIO(client.keyFor(this.selector)).writing; + if (bufRemaining == null || bufRemaining.remaining() == 0) { + // It's done + break; + } + // If write is *not* completed, then wait timeout /10 + try { + if (logger.isDebugEnabled()) { + logger + .debug("writeBlocking WAITING=" + String.valueOf(waitTime / 10) + " msg=" + outputMessage); + } + lock.wait(waitTime / 10); + } + catch (InterruptedException e) { + throw new InterruptedIOException("write message wait interrupted"); + } + } + if (bufRemaining != null && bufRemaining.remaining() > 0) { + throw new IOException("Write not completed. Non empty buffer remaining after timeout"); + } + } + if (logger.isDebugEnabled()) { + logger.debug("writeBlocking COMPLETED msg=" + outputMessage); + } + } + + protected void configureConnectSocketChannel(SocketChannel client, SocketAddress connectAddress) + throws IOException { + // Subclasses may override + } + + protected SocketChannel connectBlocking(SocketChannel client, SocketAddress address, + IOConsumer connectHandler, IOConsumer readHandler) throws IOException { + Object lock = client.blockingLock(); + if (logger.isDebugEnabled()) { + logger.debug("connectBlocking CONNECTING targetAddress=" + address); + } + synchronized (lock) { + client.configureBlocking(false); + client.register(selector, SelectionKey.OP_CONNECT); + configureConnectSocketChannel(client, address); + // Start the read thread before connect + // No/null accept handler for clients + start(null, (c) -> { + if (connectHandler != null) { + connectHandler.apply(c); + } + lock.notify(); + }, readHandler); + + client.connect(address); + + try { + if (logger.isDebugEnabled()) { + logger.debug("connectBlocking WAITING targetAddress=" + address); + } + lock.wait(BLOCKING_CONNECT_TIMEOUT); + } + catch (InterruptedException e) { + throw new IOException("Connect to address=" + address + " timed out"); + } + if (logger.isDebugEnabled()) { + logger.debug("connectBlocking CONNECTED client=" + client.getLocalAddress() + " connecting=" + address); + } + return client; + } + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/ServerNonBlockingSocketChannel.java b/mcp/src/main/java/io/modelcontextprotocol/util/ServerNonBlockingSocketChannel.java new file mode 100644 index 000000000..918635012 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/util/ServerNonBlockingSocketChannel.java @@ -0,0 +1,96 @@ +package io.modelcontextprotocol.util; + +import java.io.IOException; +import java.net.SocketAddress; +import java.net.StandardProtocolFamily; +import java.nio.channels.SelectionKey; +import java.nio.channels.Selector; +import java.nio.channels.ServerSocketChannel; +import java.nio.channels.SocketChannel; +import java.util.concurrent.ExecutorService; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class ServerNonBlockingSocketChannel extends NonBlockingSocketChannel { + + private static final Logger logger = LoggerFactory.getLogger(ServerNonBlockingSocketChannel.class); + + protected SocketChannel acceptedClient; + + public ServerNonBlockingSocketChannel() throws IOException { + super(); + } + + public ServerNonBlockingSocketChannel(Selector selector, int incomingBufferSize, ExecutorService executor) { + super(selector, incomingBufferSize, executor); + } + + public ServerNonBlockingSocketChannel(Selector selector, int incomingBufferSize) { + super(selector, incomingBufferSize); + } + + public ServerNonBlockingSocketChannel(Selector selector) { + super(selector); + } + + protected void configureServerSocketChannel(ServerSocketChannel serverSocketChannel, SocketAddress acceptAddress) { + // Subclasses may override + } + + public void start(StandardProtocolFamily protocol, SocketAddress address, IOConsumer acceptHandler, + IOConsumer readHandler) throws IOException { + ServerSocketChannel serverChannel = ServerSocketChannel.open(protocol); + serverChannel.configureBlocking(false); + serverChannel.register(this.selector, SelectionKey.OP_ACCEPT); + configureServerSocketChannel(serverChannel, address); + serverChannel.bind(address); + // Start thread/processing of incoming accept, read + super.start((client) -> { + if (logger.isDebugEnabled()) { + logger.debug("Setting client=" + client); + } + this.acceptedClient = client; + if (acceptHandler != null) { + acceptHandler.apply(this.acceptedClient); + } + // No/null connect handler for Acceptors...only accepthandler + }, null, readHandler); + } + + @Override + protected void handleException(SelectionKey key, Exception e) { + if (logger.isDebugEnabled()) { + logger.debug("handleException", e); + } + close(); + } + + public void writeMessageBlocking(String message) throws IOException { + if (this.acceptedClient == null) { + throw new IOException("Cannot write until client connected"); + } + writeBlocking(acceptedClient, message); + } + + @Override + public void close() { + SocketChannel client = this.acceptedClient; + if (client != null) { + try { + hardCloseClient(client, (c) -> { + if (logger.isDebugEnabled()) { + logger.debug("Unsetting client=" + c); + } + this.acceptedClient = null; + }); + } + catch (IOException e) { + if (logger.isDebugEnabled()) { + logger.debug("Exception in hardCloseClient", e); + } + } + } + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/UDSClientNonBlockingSocketChannel.java b/mcp/src/main/java/io/modelcontextprotocol/util/UDSClientNonBlockingSocketChannel.java new file mode 100644 index 000000000..ef16590b5 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/util/UDSClientNonBlockingSocketChannel.java @@ -0,0 +1,34 @@ +package io.modelcontextprotocol.util; + +import java.io.IOException; +import java.net.StandardProtocolFamily; +import java.net.UnixDomainSocketAddress; +import java.nio.channels.Selector; +import java.nio.channels.SocketChannel; +import java.util.concurrent.ExecutorService; + +public class UDSClientNonBlockingSocketChannel extends ClientNonBlockingSocketChannel { + + public UDSClientNonBlockingSocketChannel(Selector selector, int incomingBufferSize, ExecutorService executor) + throws IOException { + super(selector, incomingBufferSize, executor); + } + + public UDSClientNonBlockingSocketChannel() throws IOException { + super(); + } + + public UDSClientNonBlockingSocketChannel(Selector selector, int incomingBufferSize) { + super(selector, incomingBufferSize); + } + + public UDSClientNonBlockingSocketChannel(Selector selector) { + super(selector); + } + + public void connectBlocking(UnixDomainSocketAddress address, IOConsumer connectHandler, + IOConsumer readHandler) throws IOException { + super.connectBlocking(StandardProtocolFamily.UNIX, address, connectHandler, readHandler); + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/UDSServerNonBlockingSocketChannel.java b/mcp/src/main/java/io/modelcontextprotocol/util/UDSServerNonBlockingSocketChannel.java new file mode 100644 index 000000000..259315712 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/util/UDSServerNonBlockingSocketChannel.java @@ -0,0 +1,33 @@ +package io.modelcontextprotocol.util; + +import java.io.IOException; +import java.net.StandardProtocolFamily; +import java.net.UnixDomainSocketAddress; +import java.nio.channels.Selector; +import java.nio.channels.SocketChannel; +import java.util.concurrent.ExecutorService; + +public class UDSServerNonBlockingSocketChannel extends ServerNonBlockingSocketChannel { + + public UDSServerNonBlockingSocketChannel() throws IOException { + super(); + } + + public UDSServerNonBlockingSocketChannel(Selector selector, int incomingBufferSize, ExecutorService executor) { + super(selector, incomingBufferSize, executor); + } + + public UDSServerNonBlockingSocketChannel(Selector selector, int incomingBufferSize) { + super(selector, incomingBufferSize); + } + + public UDSServerNonBlockingSocketChannel(Selector selector) { + super(selector); + } + + public void start(UnixDomainSocketAddress address, IOConsumer acceptHandler, + IOConsumer readHandler) throws IOException { + super.start(StandardProtocolFamily.UNIX, address, acceptHandler, readHandler); + } + +} From 40b683e05cfd0a3f213ba8a2a0b621f4e241877c Mon Sep 17 00:00:00 2001 From: Scott Lewis Date: Tue, 5 Aug 2025 11:53:27 -0700 Subject: [PATCH 02/25] fix for synchronization --- .../util/NonBlockingSocketChannel.java | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/NonBlockingSocketChannel.java b/mcp/src/main/java/io/modelcontextprotocol/util/NonBlockingSocketChannel.java index b3b46e13c..5fc0eaf9d 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/util/NonBlockingSocketChannel.java +++ b/mcp/src/main/java/io/modelcontextprotocol/util/NonBlockingSocketChannel.java @@ -222,7 +222,7 @@ protected void handleWritable(SelectionKey key) throws IOException { if (logger.isDebugEnabled()) { logger.debug("handleWritable NOTIFY client=" + client); } - lock.notify(); + lock.notifyAll(); } }); } @@ -360,10 +360,12 @@ protected SocketChannel connectBlocking(SocketChannel client, SocketAddress addr // Start the read thread before connect // No/null accept handler for clients start(null, (c) -> { - if (connectHandler != null) { - connectHandler.apply(c); + synchronized (lock) { + if (connectHandler != null) { + connectHandler.apply(c); + } + lock.notifyAll(); } - lock.notify(); }, readHandler); client.connect(address); From d5ebbdb591e15957c68e7bb72f41fddafc439536 Mon Sep 17 00:00:00 2001 From: Scott Lewis Date: Wed, 6 Aug 2025 15:35:12 -0700 Subject: [PATCH 03/25] Added async and sync server tests (UDSMcpAsyncServerTests, UDSMcpSyncServerTest). Also made simplifying changes to *socketchannel classes --- .../server/UDSMcpAsyncServerTests.java | 52 ++++++++++++++++++ .../server/UDSMcpSyncServerTests.java | 53 +++++++++++++++++++ 2 files changed, 105 insertions(+) create mode 100644 mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpAsyncServerTests.java create mode 100644 mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpSyncServerTests.java diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpAsyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpAsyncServerTests.java new file mode 100644 index 000000000..35be2a355 --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpAsyncServerTests.java @@ -0,0 +1,52 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import java.io.IOException; +import java.net.UnixDomainSocketAddress; +import java.nio.file.Files; + +import org.junit.jupiter.api.Timeout; + +import io.modelcontextprotocol.server.transport.UDSServerTransportProvider; +import io.modelcontextprotocol.spec.McpServerTransportProvider; + +/** + * Tests for {@link McpAsyncServer} using {@link StdioServerTransport}. + * + * @author Christian Tzolov + */ +@Timeout(15) // Giving extra time beyond the client timeout +class UDSMcpAsyncServerTests extends AbstractMcpAsyncServerTests { + + private UnixDomainSocketAddress address; + + @Override + protected void setUp() { + super.onStart(); + address = UnixDomainSocketAddress.of(getClass().getName() + ".unix.socket"); + } + + @Override + protected void tearDown() { + super.onClose(); + if (address != null) { + try { + Files.deleteIfExists(address.getPath()); + } catch (IOException e) { + } + } + } + + protected McpServerTransportProvider createMcpTransportProvider() { + return new UDSServerTransportProvider(address); + } + + @Override + protected McpServer.AsyncSpecification prepareAsyncServerBuilder() { + return McpServer.async(createMcpTransportProvider()); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpSyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpSyncServerTests.java new file mode 100644 index 000000000..aa3666fbf --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpSyncServerTests.java @@ -0,0 +1,53 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import java.io.IOException; +import java.net.UnixDomainSocketAddress; +import java.nio.file.Files; + +import org.junit.jupiter.api.Timeout; + +import io.modelcontextprotocol.server.transport.StdioServerTransportProvider; +import io.modelcontextprotocol.server.transport.UDSServerTransportProvider; +import io.modelcontextprotocol.spec.McpServerTransportProvider; + +/** + * Tests for {@link McpSyncServer} using {@link StdioServerTransportProvider}. + * + * @author Christian Tzolov and Scott Lewis + */ +@Timeout(15) // Giving extra time beyond the client timeout +class UDSMcpSyncServerTests extends AbstractMcpSyncServerTests { + + private UnixDomainSocketAddress address; + + @Override + protected void setUp() { + super.onStart(); + address = UnixDomainSocketAddress.of(getClass().getName()+".unix.socket"); + } + + @Override + protected void tearDown() { + super.onClose(); + if (address != null) { + try { + Files.deleteIfExists(address.getPath()); + } catch (IOException e) { + } + } + } + + protected McpServerTransportProvider createMcpTransportProvider() { + return new UDSServerTransportProvider(address); + } + + @Override + protected McpServer.SyncSpecification prepareSyncServerBuilder() { + return McpServer.sync(createMcpTransportProvider()); + } + +} From 317a78bfa64ebc40d0e02e12ae78a912c4fc11ad Mon Sep 17 00:00:00 2001 From: Scott Lewis Date: Wed, 6 Aug 2025 15:42:55 -0700 Subject: [PATCH 04/25] Fixes and simplification --- .../transport/UDSServerTransportProvider.java | 27 ++++++++++++------- .../util/ClientNonBlockingSocketChannel.java | 16 +++-------- .../Inet4ClientNonBlockingSocketChannel.java | 7 +++-- .../Inet6ClientNonBlockingSocketChannel.java | 7 +++-- .../util/NonBlockingSocketChannel.java | 16 ++++++++--- .../util/ServerNonBlockingSocketChannel.java | 15 +++-------- .../UDSClientNonBlockingSocketChannel.java | 3 +-- 7 files changed, 44 insertions(+), 47 deletions(-) diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/UDSServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/UDSServerTransportProvider.java index d6e76a56a..c1b677bb4 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/transport/UDSServerTransportProvider.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/UDSServerTransportProvider.java @@ -44,23 +44,33 @@ public class UDSServerTransportProvider implements McpServerTransportProvider { private UDSMcpSessionTransport transport; - public UDSServerTransportProvider(ObjectMapper objectMapper, UnixDomainSocketAddress unixSocketAddress) - throws IOException { + public UDSServerTransportProvider(UnixDomainSocketAddress unixSocketAddress) { + this(new ObjectMapper(), unixSocketAddress); + } + + public UDSServerTransportProvider(ObjectMapper objectMapper, UnixDomainSocketAddress unixSocketAddress) { Assert.notNull(objectMapper, "The ObjectMapper can not be null"); this.objectMapper = objectMapper; this.address = unixSocketAddress; - this.serverSocketChannel = new UDSServerNonBlockingSocketChannel(); } @Override public void setSessionFactory(McpServerSession.Factory sessionFactory) { + this.transport = new UDSMcpSessionTransport(); + this.session = sessionFactory.create(transport); + this.transport.initProcessing(); + // Also start listening for accept try { + this.serverSocketChannel = new UDSServerNonBlockingSocketChannel(); this.serverSocketChannel.start(this.address, (clientChannel) -> { - this.transport = new UDSMcpSessionTransport(); - this.session = sessionFactory.create(transport); - this.transport.initProcessing(); + if (logger.isDebugEnabled()) { + logger.debug("Accepted connect from clientChannel=" + clientChannel); + } }, (dataLine) -> { String message = (String) dataLine; + if (logger.isDebugEnabled()) { + logger.debug("Received message line=" + message); + } try { this.transport .handleMessage(McpSchema.deserializeJsonRpcMessage(this.objectMapper, message.trim())); @@ -71,6 +81,7 @@ public void setSessionFactory(McpServerSession.Factory sessionFactory) { }); } catch (IOException e) { + // If this happens then we are doomed this.serverSocketChannel.close(); throw new RuntimeException("accepterNonBlockSocketChannel could not be started"); } @@ -114,10 +125,6 @@ public UDSMcpSessionTransport() { this.inboundSink = Sinks.many().unicast().onBackpressureBuffer(); this.outboundSink = Sinks.many().unicast().onBackpressureBuffer(); - // Use bounded schedulers for better resource management - // this.inboundScheduler = - // Schedulers.fromExecutorService(Executors.newSingleThreadExecutor(), - // "uds-inbound"); this.outboundScheduler = Schedulers.fromExecutorService(Executors.newSingleThreadExecutor(), "uds-outbound"); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/ClientNonBlockingSocketChannel.java b/mcp/src/main/java/io/modelcontextprotocol/util/ClientNonBlockingSocketChannel.java index 65e976f53..33c7f5f7d 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/util/ClientNonBlockingSocketChannel.java +++ b/mcp/src/main/java/io/modelcontextprotocol/util/ClientNonBlockingSocketChannel.java @@ -17,8 +17,7 @@ public class ClientNonBlockingSocketChannel extends NonBlockingSocketChannel { private SocketChannel client; - public ClientNonBlockingSocketChannel(Selector selector, int incomingBufferSize, ExecutorService executor) - throws IOException { + public ClientNonBlockingSocketChannel(Selector selector, int incomingBufferSize, ExecutorService executor) { super(selector, incomingBufferSize, executor); } @@ -52,16 +51,9 @@ protected void handleException(SelectionKey key, Exception e) { @Override public void close() { - try { - hardCloseClient(this.client, (client) -> { - this.client = null; - }); - } - catch (IOException e) { - if (logger.isDebugEnabled()) { - logger.debug("Exception in hardCloseClient", e); - } - } + hardCloseClient(this.client, (client) -> { + this.client = null; + }); } public void writeMessageBlocking(String message) throws IOException { diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/Inet4ClientNonBlockingSocketChannel.java b/mcp/src/main/java/io/modelcontextprotocol/util/Inet4ClientNonBlockingSocketChannel.java index 650fd52b4..b1186e3cd 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/util/Inet4ClientNonBlockingSocketChannel.java +++ b/mcp/src/main/java/io/modelcontextprotocol/util/Inet4ClientNonBlockingSocketChannel.java @@ -14,16 +14,15 @@ public Inet4ClientNonBlockingSocketChannel() throws IOException { super(); } - public Inet4ClientNonBlockingSocketChannel(Selector selector, int incomingBufferSize, ExecutorService executor) - throws IOException { + public Inet4ClientNonBlockingSocketChannel(Selector selector, int incomingBufferSize, ExecutorService executor) { super(selector, incomingBufferSize, executor); } - public Inet4ClientNonBlockingSocketChannel(Selector selector, int incomingBufferSize) throws IOException { + public Inet4ClientNonBlockingSocketChannel(Selector selector, int incomingBufferSize) { super(selector, incomingBufferSize); } - public Inet4ClientNonBlockingSocketChannel(Selector selector) throws IOException { + public Inet4ClientNonBlockingSocketChannel(Selector selector) { super(selector); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/Inet6ClientNonBlockingSocketChannel.java b/mcp/src/main/java/io/modelcontextprotocol/util/Inet6ClientNonBlockingSocketChannel.java index da8739758..9af484858 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/util/Inet6ClientNonBlockingSocketChannel.java +++ b/mcp/src/main/java/io/modelcontextprotocol/util/Inet6ClientNonBlockingSocketChannel.java @@ -14,16 +14,15 @@ public Inet6ClientNonBlockingSocketChannel() throws IOException { super(); } - public Inet6ClientNonBlockingSocketChannel(Selector selector, int incomingBufferSize, ExecutorService executor) - throws IOException { + public Inet6ClientNonBlockingSocketChannel(Selector selector, int incomingBufferSize, ExecutorService executor) { super(selector, incomingBufferSize, executor); } - public Inet6ClientNonBlockingSocketChannel(Selector selector, int incomingBufferSize) throws IOException { + public Inet6ClientNonBlockingSocketChannel(Selector selector, int incomingBufferSize) { super(selector, incomingBufferSize); } - public Inet6ClientNonBlockingSocketChannel(Selector selector) throws IOException { + public Inet6ClientNonBlockingSocketChannel(Selector selector) { super(selector); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/NonBlockingSocketChannel.java b/mcp/src/main/java/io/modelcontextprotocol/util/NonBlockingSocketChannel.java index 5fc0eaf9d..de3fe8ba7 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/util/NonBlockingSocketChannel.java +++ b/mcp/src/main/java/io/modelcontextprotocol/util/NonBlockingSocketChannel.java @@ -281,17 +281,25 @@ protected void executorShutdown() { } } - protected void hardCloseClient(SocketChannel client, IOConsumer closeHandler) throws IOException { + protected void hardCloseClient(SocketChannel client, IOConsumer closeHandler) { if (client != null) { Object lock = client.blockingLock(); if (logger.isDebugEnabled()) { logger.debug("hardCloseClient client=" + client); } synchronized (lock) { - if (closeHandler != null) { - closeHandler.apply(client); + try { + if (closeHandler != null) { + closeHandler.apply(client); + } + client.close(); + client = null; + } + catch (IOException e) { + if (logger.isDebugEnabled()) { + logger.debug("hardClose client socketchannel.close exception", e); + } } - client.close(); } executorShutdown(); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/ServerNonBlockingSocketChannel.java b/mcp/src/main/java/io/modelcontextprotocol/util/ServerNonBlockingSocketChannel.java index 918635012..4c64e3d18 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/util/ServerNonBlockingSocketChannel.java +++ b/mcp/src/main/java/io/modelcontextprotocol/util/ServerNonBlockingSocketChannel.java @@ -77,19 +77,12 @@ public void writeMessageBlocking(String message) throws IOException { public void close() { SocketChannel client = this.acceptedClient; if (client != null) { - try { - hardCloseClient(client, (c) -> { - if (logger.isDebugEnabled()) { - logger.debug("Unsetting client=" + c); - } - this.acceptedClient = null; - }); - } - catch (IOException e) { + hardCloseClient(client, (c) -> { if (logger.isDebugEnabled()) { - logger.debug("Exception in hardCloseClient", e); + logger.debug("Unsetting client=" + c); } - } + this.acceptedClient = null; + }); } } diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/UDSClientNonBlockingSocketChannel.java b/mcp/src/main/java/io/modelcontextprotocol/util/UDSClientNonBlockingSocketChannel.java index ef16590b5..2e279c2b9 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/util/UDSClientNonBlockingSocketChannel.java +++ b/mcp/src/main/java/io/modelcontextprotocol/util/UDSClientNonBlockingSocketChannel.java @@ -9,8 +9,7 @@ public class UDSClientNonBlockingSocketChannel extends ClientNonBlockingSocketChannel { - public UDSClientNonBlockingSocketChannel(Selector selector, int incomingBufferSize, ExecutorService executor) - throws IOException { + public UDSClientNonBlockingSocketChannel(Selector selector, int incomingBufferSize, ExecutorService executor) { super(selector, incomingBufferSize, executor); } From dd99b865841e1b5ffb2247b156d3b3d1bdc54955 Mon Sep 17 00:00:00 2001 From: Scott Lewis Date: Wed, 6 Aug 2025 17:33:32 -0700 Subject: [PATCH 05/25] Layout fixes --- .../server/UDSMcpAsyncServerTests.java | 3 ++- .../server/UDSMcpSyncServerTests.java | 13 +++++++------ 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpAsyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpAsyncServerTests.java index 35be2a355..b5773e0fb 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpAsyncServerTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpAsyncServerTests.java @@ -35,7 +35,8 @@ protected void tearDown() { if (address != null) { try { Files.deleteIfExists(address.getPath()); - } catch (IOException e) { + } + catch (IOException e) { } } } diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpSyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpSyncServerTests.java index aa3666fbf..d2d2581d1 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpSyncServerTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpSyncServerTests.java @@ -23,26 +23,27 @@ class UDSMcpSyncServerTests extends AbstractMcpSyncServerTests { private UnixDomainSocketAddress address; - + @Override protected void setUp() { super.onStart(); - address = UnixDomainSocketAddress.of(getClass().getName()+".unix.socket"); + address = UnixDomainSocketAddress.of(getClass().getName() + ".unix.socket"); } - + @Override protected void tearDown() { super.onClose(); if (address != null) { try { Files.deleteIfExists(address.getPath()); - } catch (IOException e) { + } + catch (IOException e) { } } } - + protected McpServerTransportProvider createMcpTransportProvider() { - return new UDSServerTransportProvider(address); + return new UDSServerTransportProvider(address); } @Override From 60ea90f069b94234293f1833387f84019f561f9e Mon Sep 17 00:00:00 2001 From: Scott Lewis Date: Wed, 6 Aug 2025 18:55:13 -0700 Subject: [PATCH 06/25] Update for tests --- .../modelcontextprotocol/server/UDSMcpAsyncServerTests.java | 3 ++- .../modelcontextprotocol/server/UDSMcpSyncServerTests.java | 5 +++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpAsyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpAsyncServerTests.java index b5773e0fb..cad1eae5b 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpAsyncServerTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpAsyncServerTests.java @@ -14,9 +14,10 @@ import io.modelcontextprotocol.spec.McpServerTransportProvider; /** - * Tests for {@link McpAsyncServer} using {@link StdioServerTransport}. + * Tests for {@link McpAsyncServer} using {@link UDSServerTransport}. * * @author Christian Tzolov + * @author Scott Lewis */ @Timeout(15) // Giving extra time beyond the client timeout class UDSMcpAsyncServerTests extends AbstractMcpAsyncServerTests { diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpSyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpSyncServerTests.java index d2d2581d1..6e896d478 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpSyncServerTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpSyncServerTests.java @@ -15,9 +15,10 @@ import io.modelcontextprotocol.spec.McpServerTransportProvider; /** - * Tests for {@link McpSyncServer} using {@link StdioServerTransportProvider}. + * Tests for {@link McpSyncServer} using {@link UDSServerTransportProvider}. * - * @author Christian Tzolov and Scott Lewis + * @author Christian Tzolov + * @author Scott Lewis */ @Timeout(15) // Giving extra time beyond the client timeout class UDSMcpSyncServerTests extends AbstractMcpSyncServerTests { From 506715bd3484255348e09e23b45afb4ecc67c17f Mon Sep 17 00:00:00 2001 From: Scott Lewis Date: Wed, 6 Aug 2025 18:56:09 -0700 Subject: [PATCH 07/25] Removed unnecessary import --- .../io/modelcontextprotocol/server/UDSMcpSyncServerTests.java | 1 - 1 file changed, 1 deletion(-) diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpSyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpSyncServerTests.java index 6e896d478..57ec7b766 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpSyncServerTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpSyncServerTests.java @@ -10,7 +10,6 @@ import org.junit.jupiter.api.Timeout; -import io.modelcontextprotocol.server.transport.StdioServerTransportProvider; import io.modelcontextprotocol.server.transport.UDSServerTransportProvider; import io.modelcontextprotocol.spec.McpServerTransportProvider; From 697d8f34b01b50ab106bbf38e2189899a0eaa9b1 Mon Sep 17 00:00:00 2001 From: Scott Lewis Date: Thu, 7 Aug 2025 15:44:42 -0700 Subject: [PATCH 08/25] Added UDSMcpAsyncClientTests and UDSMcpSyncClientTests. Also added 'EverythingServer' to allow Java mcp server to provide support for client tests...that currently use the JavaScript 'everything' server. --- .../transport/UDSClientTransportProvider.java | 8 +- .../transport/UDSServerTransportProvider.java | 16 +-- .../client/UDSMcpAsyncClientTests.java | 69 +++++++++ .../client/UDSMcpSyncClientTests.java | 69 +++++++++ .../server/EverythingServer.java | 131 ++++++++++++++++++ 5 files changed, 280 insertions(+), 13 deletions(-) create mode 100644 mcp/src/test/java/io/modelcontextprotocol/client/UDSMcpAsyncClientTests.java create mode 100644 mcp/src/test/java/io/modelcontextprotocol/client/UDSMcpSyncClientTests.java create mode 100644 mcp/src/test/java/io/modelcontextprotocol/server/EverythingServer.java diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/UDSClientTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/UDSClientTransportProvider.java index bc364dea8..28bb1fe8a 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/UDSClientTransportProvider.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/UDSClientTransportProvider.java @@ -41,6 +41,10 @@ public class UDSClientTransportProvider implements McpClientTransport { private volatile boolean isClosing = false; + public UDSClientTransportProvider(UnixDomainSocketAddress targetAddress) throws IOException { + this(new ObjectMapper(), targetAddress); + } + public UDSClientTransportProvider(ObjectMapper objectMapper, UnixDomainSocketAddress targetAddress) throws IOException { Assert.notNull(objectMapper, "The ObjectMapper can not be null"); @@ -159,9 +163,7 @@ public Mono closeGracefully() { this.clientChannel = null; } return Mono.empty(); - })).doOnNext(o -> { - logger.info("MCP server process stopped"); - }).then(Mono.fromRunnable(() -> { + })).then(Mono.fromRunnable(() -> { try { // The Threads are blocked on readLine so disposeGracefully would not // interrupt them, therefore we issue an async hard dispose. diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/UDSServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/UDSServerTransportProvider.java index c1b677bb4..977f5b90e 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/transport/UDSServerTransportProvider.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/UDSServerTransportProvider.java @@ -58,7 +58,10 @@ public UDSServerTransportProvider(ObjectMapper objectMapper, UnixDomainSocketAdd public void setSessionFactory(McpServerSession.Factory sessionFactory) { this.transport = new UDSMcpSessionTransport(); this.session = sessionFactory.create(transport); - this.transport.initProcessing(); + this.transport.handleIncomingMessages(); + if (this.transport.isStarted.compareAndSet(false, true)) { + inboundReady.tryEmitValue(null); + } // Also start listening for accept try { this.serverSocketChannel = new UDSServerNonBlockingSocketChannel(); @@ -66,6 +69,8 @@ public void setSessionFactory(McpServerSession.Factory sessionFactory) { if (logger.isDebugEnabled()) { logger.debug("Accepted connect from clientChannel=" + clientChannel); } + // Start outbound processing now that the clientChannel has been accepted + this.transport.startOutboundProcessing(); }, (dataLine) -> { String message = (String) dataLine; if (logger.isDebugEnabled()) { @@ -171,18 +176,9 @@ public Mono closeGracefully() { @Override public void close() { isClosing.set(true); - serverSocketChannel.close(); logger.debug("Session transport closed"); } - private void initProcessing() { - handleIncomingMessages(); - if (isStarted.compareAndSet(false, true)) { - inboundReady.tryEmitValue(null); - } - startOutboundProcessing(); - } - private void handleIncomingMessages() { this.inboundSink.asFlux().flatMap(message -> session.handle(message)).doOnTerminate(() -> { // The outbound processing will dispose its scheduler upon completion diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/UDSMcpAsyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/UDSMcpAsyncClientTests.java new file mode 100644 index 000000000..99121e01c --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/client/UDSMcpAsyncClientTests.java @@ -0,0 +1,69 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.client; + +import java.io.IOException; +import java.net.UnixDomainSocketAddress; +import java.nio.file.Files; +import java.time.Duration; + +import org.junit.jupiter.api.Timeout; + +import io.modelcontextprotocol.client.transport.UDSClientTransportProvider; +import io.modelcontextprotocol.server.EverythingServer; +import io.modelcontextprotocol.server.transport.UDSServerTransportProvider; +import io.modelcontextprotocol.spec.McpClientTransport; + +/** + * Tests for the {@link McpAyncClient} with {@link UDSClientTransport}. + * + * @author Christian Tzolov + * @author Dariusz Jędrzejczyk + * @author Scott Lewis + */ +@Timeout(15) // Giving extra time beyond the client timeout +class UDSMcpAsyncClientTests extends AbstractMcpAsyncClientTests { + + UnixDomainSocketAddress address; + EverythingServer server; + + @Override + protected void onStart() { + this.address = UnixDomainSocketAddress.of(getClass().getName() + ".socket"); + try { + // Delete this file if exists from previous run + Files.deleteIfExists(this.address.getPath()); + } catch (IOException e) { + throw new RuntimeException(e); + } + this.server = new EverythingServer(new UDSServerTransportProvider(address)); + } + + @Override + protected void onClose() { + server.closeGracefully(); + server = null; + try { + Files.deleteIfExists(address.getPath()); + } catch (IOException e) { + throw new RuntimeException(e); + } + address = null; + } + + @Override + protected McpClientTransport createMcpTransport() { + try { + return new UDSClientTransportProvider(address); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + protected Duration getInitializationTimeout() { + return Duration.ofSeconds(2); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/UDSMcpSyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/UDSMcpSyncClientTests.java new file mode 100644 index 000000000..c52d98a97 --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/client/UDSMcpSyncClientTests.java @@ -0,0 +1,69 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.client; + +import java.io.IOException; +import java.net.UnixDomainSocketAddress; +import java.nio.file.Files; +import java.time.Duration; + +import org.junit.jupiter.api.Timeout; + +import io.modelcontextprotocol.client.transport.UDSClientTransportProvider; +import io.modelcontextprotocol.server.EverythingServer; +import io.modelcontextprotocol.server.transport.UDSServerTransportProvider; +import io.modelcontextprotocol.spec.McpClientTransport; + +/** + * Tests for the {@link McpSyncClient} with {@link UDSClientTransport}. + * + * @author Christian Tzolov + * @author Dariusz Jędrzejczyk + * @author Scott Lewis + */ +@Timeout(15) // Giving extra time beyond the client timeout +class UDSMcpSyncClientTests extends AbstractMcpSyncClientTests { + + UnixDomainSocketAddress address; + EverythingServer server; + + @Override + protected void onStart() { + this.address = UnixDomainSocketAddress.of(getClass().getName() + ".socket"); + try { + // Delete this file if exists from previous run + Files.deleteIfExists(this.address.getPath()); + } catch (IOException e) { + throw new RuntimeException(e); + } + this.server = new EverythingServer(new UDSServerTransportProvider(address)); + } + + @Override + protected void onClose() { + server.closeGracefully(); + server = null; + try { + Files.deleteIfExists(address.getPath()); + } catch (IOException e) { + throw new RuntimeException(e); + } + address = null; + } + + @Override + protected McpClientTransport createMcpTransport() { + try { + return new UDSClientTransportProvider(address); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + protected Duration getInitializationTimeout() { + return Duration.ofSeconds(2); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/EverythingServer.java b/mcp/src/test/java/io/modelcontextprotocol/server/EverythingServer.java new file mode 100644 index 000000000..a158ab2fb --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/server/EverythingServer.java @@ -0,0 +1,131 @@ +package io.modelcontextprotocol.server; + +import java.util.List; + +import io.modelcontextprotocol.server.McpServerFeatures.SyncPromptSpecification; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpServerTransportProvider; +import io.modelcontextprotocol.spec.McpSchema.Annotations; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; +import io.modelcontextprotocol.spec.McpSchema.GetPromptResult; +import io.modelcontextprotocol.spec.McpSchema.ProgressNotification; +import io.modelcontextprotocol.spec.McpSchema.Prompt; +import io.modelcontextprotocol.spec.McpSchema.PromptMessage; +import io.modelcontextprotocol.spec.McpSchema.ReadResourceResult; +import io.modelcontextprotocol.spec.McpSchema.Resource; +import io.modelcontextprotocol.spec.McpSchema.Role; +import io.modelcontextprotocol.spec.McpSchema.SamplingMessage; +import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; +import io.modelcontextprotocol.spec.McpSchema.TextContent; +import io.modelcontextprotocol.spec.McpSchema.TextResourceContents; +import io.modelcontextprotocol.spec.McpSchema.Tool; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest.ContextInclusionStrategy; + +public class EverythingServer { + + private static final String TEST_RESOURCE_URI = "test://resources/"; + + private static final String emptyJsonSchema = """ + { + "$schema": "http://json-schema.org/draft-07/schema#", + "type": "object", + "properties": {} + } + """; + + private McpSyncServer server; + + public EverythingServer(McpServerTransportProvider transport) { + McpServerFeatures.SyncResourceSpecification[] specs = new McpServerFeatures.SyncResourceSpecification[10]; + for (int i = 0; i < 10; i++) { + String istr = String.valueOf(i); + String uri = TEST_RESOURCE_URI + istr; + specs[i] = new McpServerFeatures.SyncResourceSpecification( + Resource.builder().uri(uri).name("Test Resource").mimeType("text/plain") + .description("Test resource description").build(), + (exchange, + req) -> new ReadResourceResult(List.of(new TextResourceContents(uri, "text/plain", istr)))); + } + + this.server = McpServer.sync(transport).serverInfo(getClass().getName() + "-server", "1.0.0") + .capabilities( + ServerCapabilities.builder().logging().tools(true).prompts(true).resources(true, true).build()) + .toolCall(Tool.builder().name("echo").description("echo tool description").inputSchema(emptyJsonSchema) + .build(), (exchange, request) -> { + return CallToolResult.builder().addTextContent((String) request.arguments().get("message")) + .build(); + }) + .toolCall( + Tool.builder().name("add").description("add two integers").inputSchema(emptyJsonSchema).build(), + (exchange, request) -> { + Integer a = (Integer) request.arguments().get("a"); + Integer b = (Integer) request.arguments().get("b"); + + return CallToolResult.builder().addTextContent(String.valueOf(a + b)).build(); + }) + .toolCall(Tool.builder().name("sampleLLM").description("sampleLLM tool").inputSchema(emptyJsonSchema) + .build(), (exchange, request) -> { + String prompt = (String) request.arguments().get("prompt"); + Integer maxTokens = (Integer) request.arguments().get("maxTokens"); + SamplingMessage sm = new SamplingMessage(McpSchema.Role.USER, + new TextContent("Resource sampleLLM context: " + prompt)); + CreateMessageRequest cmRequest = CreateMessageRequest.builder().messages(List.of(sm)) + .systemPrompt("You are a helpful test server.").maxTokens(maxTokens) + .temperature(0.7).includeContext(ContextInclusionStrategy.THIS_SERVER).build(); + CreateMessageResult result = exchange.createMessage(cmRequest); + + return CallToolResult.builder() + .addTextContent("LLM sampling result: " + ((TextContent) result.content()).text()) + .build(); + }) + .toolCall(Tool.builder().name("longRunningOperation") + .description("Demonstrates a long running operation with progress updates") + .inputSchema(emptyJsonSchema).build(), (exchange, request) -> { + String progressToken = (String) request.progressToken(); + int steps = (Integer) request.arguments().get("steps"); + for (int i = 0; i < steps; i++) { + try { + Thread.sleep(1000); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + if (progressToken != null) { + exchange.progressNotification( + new ProgressNotification(progressToken, (double) i + 1, (double) steps, + "progress message " + String.valueOf(i + 1))); + } + } + return CallToolResult.builder().content(List.of(new TextContent("done"))).build(); + }) + .toolCall(Tool.builder().name("annotatedMessage").description("annotated message").build(), + (exchange, request) -> { + String messageType = (String) request.arguments().get("messageType"); + Annotations annotations = null; + if (messageType.equals("success")) { + annotations = new Annotations(List.of(McpSchema.Role.USER), 0.7); + } else if (messageType.equals("error")) { + annotations = new Annotations(List.of(McpSchema.Role.USER, McpSchema.Role.ASSISTANT), + 1.0); + } else if (messageType.equals("debug")) { + annotations = new Annotations(List.of(McpSchema.Role.ASSISTANT), 0.3); + } + return CallToolResult.builder().addContent(new TextContent(annotations, "some response")) + .build(); + }) + .prompts(List.of(new SyncPromptSpecification( + new Prompt("simple_prompt", "Simple prompt description", null), (exchange, request) -> { + return new GetPromptResult("description", + List.of(new PromptMessage(Role.USER, new TextContent("hello")))); + }))) + .resources(specs).build(); + } + + public void closeGracefully() { + if (this.server != null) { + this.server.closeGracefully(); + this.server = null; + } + } +} From 398462ebeeedb0350d837d6ced2031a8b0242b15 Mon Sep 17 00:00:00 2001 From: Scott Lewis Date: Fri, 1 Aug 2025 15:27:16 -0700 Subject: [PATCH 09/25] Initial checkin to address feature request https://github.com/modelcontextprotocol/java-sdk/issues/415 i.mcp.client/server.transport packages contains new transport providers classes : UDSClient/ServerTransportProvider. The name UDS refers to UnixDomainSocket as that's the SocketChannel type being used. These transport providers use the new classses in util: UDSClientNonBlockingSocketChannel and UDSServerNonBlockingSocketChannel. These further depend upon super classes ClientNonBlockingSocketChannel and ServerNonBlockingSocketChannel which both depend upon superclass NonBlockSocketChannel, which has most of the actual implementation of the single-threaded/Selector based non-blocking read and write. This subclass/superclass structure means that Inet4 and Inet6 client/server SocketChannel classes are also present. These will work just the same as the UDSClient/ServerSocketChannel classses but rather will use Inet4 and Inet6 connections rather than UnixDomainSockets. It will be very easy to create server/client transport providers that use inet4 or inet6 tcp stacks for localhost or non localhost connections. But for the moment, I've only created UDSServer/ClientTransportProviders for testing and review. Signed-off-by: Scott Lewis --- .../transport/UDSClientTransportProvider.java | 183 +++++++++ .../transport/UDSServerTransportProvider.java | 239 +++++++++++ .../util/ClientNonBlockingSocketChannel.java | 74 ++++ .../Inet4ClientNonBlockingSocketChannel.java | 36 ++ .../Inet4ServerNonBlockingSocketChannel.java | 34 ++ .../Inet6ClientNonBlockingSocketChannel.java | 36 ++ .../Inet6ServerNonBlockingSocketChannel.java | 34 ++ .../util/NonBlockingSocketChannel.java | 387 ++++++++++++++++++ .../util/ServerNonBlockingSocketChannel.java | 96 +++++ .../UDSClientNonBlockingSocketChannel.java | 34 ++ .../UDSServerNonBlockingSocketChannel.java | 33 ++ 11 files changed, 1186 insertions(+) create mode 100644 mcp/src/main/java/io/modelcontextprotocol/client/transport/UDSClientTransportProvider.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/server/transport/UDSServerTransportProvider.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/util/ClientNonBlockingSocketChannel.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/util/Inet4ClientNonBlockingSocketChannel.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/util/Inet4ServerNonBlockingSocketChannel.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/util/Inet6ClientNonBlockingSocketChannel.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/util/Inet6ServerNonBlockingSocketChannel.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/util/NonBlockingSocketChannel.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/util/ServerNonBlockingSocketChannel.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/util/UDSClientNonBlockingSocketChannel.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/util/UDSServerNonBlockingSocketChannel.java diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/UDSClientTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/UDSClientTransportProvider.java new file mode 100644 index 000000000..bc364dea8 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/UDSClientTransportProvider.java @@ -0,0 +1,183 @@ +package io.modelcontextprotocol.client.transport; + +import java.io.IOException; +import java.net.UnixDomainSocketAddress; +import java.time.Duration; +import java.util.concurrent.Executors; +import java.util.function.Function; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; + +import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage; +import io.modelcontextprotocol.util.Assert; +import io.modelcontextprotocol.util.UDSClientNonBlockingSocketChannel; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; +import reactor.core.scheduler.Scheduler; +import reactor.core.scheduler.Schedulers; + +public class UDSClientTransportProvider implements McpClientTransport { + + private static final Logger logger = LoggerFactory.getLogger(UDSClientTransportProvider.class); + + private final Sinks.Many inboundSink; + + private final Sinks.Many outboundSink; + + private ObjectMapper objectMapper; + + private UDSClientNonBlockingSocketChannel clientChannel; + + private UnixDomainSocketAddress targetAddress; + + private Scheduler outboundScheduler; + + private volatile boolean isClosing = false; + + public UDSClientTransportProvider(ObjectMapper objectMapper, UnixDomainSocketAddress targetAddress) + throws IOException { + Assert.notNull(objectMapper, "The ObjectMapper can not be null"); + + this.inboundSink = Sinks.many().unicast().onBackpressureBuffer(); + this.outboundSink = Sinks.many().unicast().onBackpressureBuffer(); + + this.objectMapper = objectMapper; + + // Start threads + this.outboundScheduler = Schedulers.fromExecutorService(Executors.newSingleThreadExecutor(), "outbound"); + this.clientChannel = new UDSClientNonBlockingSocketChannel(); + this.targetAddress = targetAddress; + } + + @Override + public Mono connect(Function, Mono> handler) { + return Mono.fromRunnable(() -> { + handleIncomingMessages(handler); + try { + this.clientChannel.connectBlocking(targetAddress, (client) -> { + logger.info("CONNECTED to targetAddress=" + targetAddress); + }, (data) -> { + JSONRPCMessage json = McpSchema.deserializeJsonRpcMessage(this.objectMapper, data); + if (!this.inboundSink.tryEmitNext(json).isSuccess()) { + if (!isClosing) { + logger.error("Failed to enqueue inbound message: {}", json); + } + } + }); + } + catch (IOException e) { + this.clientChannel.close(); + throw new RuntimeException( + "Connect to address=" + targetAddress + " failed message: " + e.getMessage()); + } + startOutboundProcessing(); + }).subscribeOn(Schedulers.boundedElastic()); + } + + private void handleIncomingMessages(Function, Mono> inboundMessageHandler) { + this.inboundSink.asFlux() + .flatMap(message -> Mono.just(message) + .transform(inboundMessageHandler) + .contextWrite(ctx -> ctx.put("observation", "myObservation"))) + .subscribe(); + } + + @Override + public Mono sendMessage(JSONRPCMessage message) { + if (this.outboundSink.tryEmitNext(message).isSuccess()) { + // TODO: essentially we could reschedule ourselves in some time and make + // another attempt with the already read data but pause reading until + // success + // In this approach we delegate the retry and the backpressure onto the + // caller. This might be enough for most cases. + return Mono.empty(); + } + else { + return Mono.error(new RuntimeException("Failed to enqueue message")); + } + } + + private void startOutboundProcessing() { + this.handleOutbound(messages -> messages + // this bit is important since writes come from user threads, and we + // want to ensure that the actual writing happens on a dedicated thread + .publishOn(outboundScheduler) + .handle((message, s) -> { + if (message != null && !isClosing) { + try { + this.clientChannel.writeMessageBlocking(objectMapper.writeValueAsString(message)); + s.next(message); + } + catch (IOException e) { + s.error(new RuntimeException(e)); + } + } + })); + } + + protected void handleOutbound(Function, Flux> outboundConsumer) { + outboundConsumer.apply(outboundSink.asFlux()).doOnComplete(() -> { + isClosing = true; + outboundSink.tryEmitComplete(); + }).doOnError(e -> { + if (!isClosing) { + logger.error("Error in outbound processing", e); + isClosing = true; + outboundSink.tryEmitComplete(); + } + }).subscribe(); + } + + /** + * Gracefully closes the transport by destroying the process and disposing of the + * schedulers. This method sends a TERM signal to the process and waits for it to exit + * before cleaning up resources. + * @return A Mono that completes when the transport is closed + */ + @Override + public Mono closeGracefully() { + return Mono.fromRunnable(() -> { + isClosing = true; + logger.debug("Initiating graceful shutdown"); + }).then(Mono.defer(() -> { + // First complete all sinks to stop accepting new messages + inboundSink.tryEmitComplete(); + outboundSink.tryEmitComplete(); + // Give a short time for any pending messages to be processed + return Mono.delay(Duration.ofMillis(100)).then(); + })).then(Mono.defer(() -> { + // Close our clientChannel + if (this.clientChannel != null) { + this.clientChannel.close(); + this.clientChannel = null; + } + return Mono.empty(); + })).doOnNext(o -> { + logger.info("MCP server process stopped"); + }).then(Mono.fromRunnable(() -> { + try { + // The Threads are blocked on readLine so disposeGracefully would not + // interrupt them, therefore we issue an async hard dispose. + outboundScheduler.dispose(); + + logger.debug("Graceful shutdown completed"); + } + catch (Exception e) { + logger.error("Error during graceful shutdown", e); + } + })).then().subscribeOn(Schedulers.boundedElastic()); + } + + @Override + public T unmarshalFrom(Object data, TypeReference typeRef) { + return this.objectMapper.convertValue(data, typeRef); + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/UDSServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/UDSServerTransportProvider.java new file mode 100644 index 000000000..d6e76a56a --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/UDSServerTransportProvider.java @@ -0,0 +1,239 @@ +package io.modelcontextprotocol.server.transport; + +import java.io.IOException; +import java.net.UnixDomainSocketAddress; +import java.util.concurrent.Executors; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Function; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; + +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage; +import io.modelcontextprotocol.spec.McpServerSession; +import io.modelcontextprotocol.spec.McpServerTransport; +import io.modelcontextprotocol.spec.McpServerTransportProvider; +import io.modelcontextprotocol.util.Assert; +import io.modelcontextprotocol.util.UDSServerNonBlockingSocketChannel; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; +import reactor.core.scheduler.Scheduler; +import reactor.core.scheduler.Schedulers; + +public class UDSServerTransportProvider implements McpServerTransportProvider { + + private static final Logger logger = LoggerFactory.getLogger(UDSServerTransportProvider.class); + + private final ObjectMapper objectMapper; + + private McpServerSession session; + + private final AtomicBoolean isClosing = new AtomicBoolean(false); + + private final Sinks.One inboundReady = Sinks.one(); + + private UDSServerNonBlockingSocketChannel serverSocketChannel; + + private UnixDomainSocketAddress address; + + private UDSMcpSessionTransport transport; + + public UDSServerTransportProvider(ObjectMapper objectMapper, UnixDomainSocketAddress unixSocketAddress) + throws IOException { + Assert.notNull(objectMapper, "The ObjectMapper can not be null"); + this.objectMapper = objectMapper; + this.address = unixSocketAddress; + this.serverSocketChannel = new UDSServerNonBlockingSocketChannel(); + } + + @Override + public void setSessionFactory(McpServerSession.Factory sessionFactory) { + try { + this.serverSocketChannel.start(this.address, (clientChannel) -> { + this.transport = new UDSMcpSessionTransport(); + this.session = sessionFactory.create(transport); + this.transport.initProcessing(); + }, (dataLine) -> { + String message = (String) dataLine; + try { + this.transport + .handleMessage(McpSchema.deserializeJsonRpcMessage(this.objectMapper, message.trim())); + } + catch (IOException e) { + this.serverSocketChannel.close(); + } + }); + } + catch (IOException e) { + this.serverSocketChannel.close(); + throw new RuntimeException("accepterNonBlockSocketChannel could not be started"); + } + } + + @Override + public Mono notifyClients(String method, Object params) { + if (this.session == null) { + return Mono.error(new McpError("No session to close")); + } + return this.session.sendNotification(method, params) + .doOnError(e -> logger.error("Failed to send notification: {}", e.getMessage())); + } + + @Override + public Mono closeGracefully() { + if (this.session == null) { + return Mono.empty(); + } + return this.session.closeGracefully(); + } + + /** + * Implementation of McpServerTransport for the stdio session. + */ + private class UDSMcpSessionTransport implements McpServerTransport { + + private final Sinks.Many inboundSink; + + private final Sinks.Many outboundSink; + + private final AtomicBoolean isStarted = new AtomicBoolean(false); + + /** Scheduler for handling outbound messages */ + private Scheduler outboundScheduler; + + private final Sinks.One outboundReady = Sinks.one(); + + public UDSMcpSessionTransport() { + + this.inboundSink = Sinks.many().unicast().onBackpressureBuffer(); + this.outboundSink = Sinks.many().unicast().onBackpressureBuffer(); + + // Use bounded schedulers for better resource management + // this.inboundScheduler = + // Schedulers.fromExecutorService(Executors.newSingleThreadExecutor(), + // "uds-inbound"); + this.outboundScheduler = Schedulers.fromExecutorService(Executors.newSingleThreadExecutor(), + "uds-outbound"); + } + + public void handleMessage(McpSchema.JSONRPCMessage json) throws IOException { + try { + if (!this.inboundSink.tryEmitNext(json).isSuccess()) { + throw new Exception("Failed to enqueue message"); + } + } + catch (Exception e) { + logIfNotClosing("Error processing inbound message", e); + throw new IOException("Error in processing inbound message", e); + } + } + + @Override + public Mono sendMessage(McpSchema.JSONRPCMessage message) { + + return Mono.zip(inboundReady.asMono(), outboundReady.asMono()).then(Mono.defer(() -> { + if (outboundSink.tryEmitNext(message).isSuccess()) { + return Mono.empty(); + } + else { + return Mono.error(new RuntimeException("Failed to enqueue message")); + } + })); + } + + @Override + public T unmarshalFrom(Object data, TypeReference typeRef) { + return objectMapper.convertValue(data, typeRef); + } + + @Override + public Mono closeGracefully() { + return Mono.fromRunnable(() -> { + isClosing.set(true); + logger.debug("Session transport closing gracefully"); + inboundSink.tryEmitComplete(); + }); + } + + @Override + public void close() { + isClosing.set(true); + serverSocketChannel.close(); + logger.debug("Session transport closed"); + } + + private void initProcessing() { + handleIncomingMessages(); + if (isStarted.compareAndSet(false, true)) { + inboundReady.tryEmitValue(null); + } + startOutboundProcessing(); + } + + private void handleIncomingMessages() { + this.inboundSink.asFlux().flatMap(message -> session.handle(message)).doOnTerminate(() -> { + // The outbound processing will dispose its scheduler upon completion + this.outboundSink.tryEmitComplete(); + // this.inboundScheduler.dispose(); + }).subscribe(); + } + + /** + * Starts the outbound processing thread that writes JSON-RPC messages to stdout. + * Messages are serialized to JSON and written with a newline delimiter. + */ + private void startOutboundProcessing() { + Function, Flux> outboundConsumer = messages -> messages // @formatter:off + .doOnSubscribe(subscription -> outboundReady.tryEmitValue(null)) + .publishOn(outboundScheduler) + .handle((message, sink) -> { + if (message != null && !isClosing.get()) { + try { + serverSocketChannel.writeMessageBlocking(objectMapper.writeValueAsString(message)); + sink.next(message); + } + catch (IOException e) { + if (!isClosing.get()) { + logger.error("Error writing message", e); + sink.error(new RuntimeException(e)); + } + else { + logger.debug("Stream closed during shutdown", e); + } + } + } + else if (isClosing.get()) { + sink.complete(); + } + }) + .doOnComplete(() -> { + isClosing.set(true); + outboundScheduler.dispose(); + }) + .doOnError(e -> { + if (!isClosing.get()) { + logger.error("Error in outbound processing", e); + isClosing.set(true); + outboundScheduler.dispose(); + } + }) + .map(msg -> (JSONRPCMessage) msg); + + outboundConsumer.apply(outboundSink.asFlux()).subscribe(); + } // @formatter:on + + private void logIfNotClosing(String message, Exception e) { + if (!isClosing.get()) { + logger.error(message, e); + } + } + + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/ClientNonBlockingSocketChannel.java b/mcp/src/main/java/io/modelcontextprotocol/util/ClientNonBlockingSocketChannel.java new file mode 100644 index 000000000..65e976f53 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/util/ClientNonBlockingSocketChannel.java @@ -0,0 +1,74 @@ +package io.modelcontextprotocol.util; + +import java.io.IOException; +import java.net.SocketAddress; +import java.net.StandardProtocolFamily; +import java.nio.channels.SelectionKey; +import java.nio.channels.Selector; +import java.nio.channels.SocketChannel; +import java.util.concurrent.ExecutorService; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class ClientNonBlockingSocketChannel extends NonBlockingSocketChannel { + + private static final Logger logger = LoggerFactory.getLogger(ClientNonBlockingSocketChannel.class); + + private SocketChannel client; + + public ClientNonBlockingSocketChannel(Selector selector, int incomingBufferSize, ExecutorService executor) + throws IOException { + super(selector, incomingBufferSize, executor); + } + + public ClientNonBlockingSocketChannel() throws IOException { + super(); + } + + public ClientNonBlockingSocketChannel(Selector selector, int incomingBufferSize) { + super(selector, incomingBufferSize); + } + + public ClientNonBlockingSocketChannel(Selector selector) { + super(selector); + } + + public void connectBlocking(StandardProtocolFamily protocol, SocketAddress address, + IOConsumer connectHandler, IOConsumer readHandler) throws IOException { + if (this.client != null) { + throw new IOException("Already connected"); + } + this.client = connectBlocking(SocketChannel.open(protocol), address, connectHandler, readHandler); + } + + @Override + protected void handleException(SelectionKey key, Exception e) { + if (logger.isDebugEnabled()) { + logger.debug("handleException", e); + } + close(); + } + + @Override + public void close() { + try { + hardCloseClient(this.client, (client) -> { + this.client = null; + }); + } + catch (IOException e) { + if (logger.isDebugEnabled()) { + logger.debug("Exception in hardCloseClient", e); + } + } + } + + public void writeMessageBlocking(String message) throws IOException { + if (this.client == null) { + throw new IOException("Cannot write until client connected"); + } + writeBlocking(client, message); + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/Inet4ClientNonBlockingSocketChannel.java b/mcp/src/main/java/io/modelcontextprotocol/util/Inet4ClientNonBlockingSocketChannel.java new file mode 100644 index 000000000..650fd52b4 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/util/Inet4ClientNonBlockingSocketChannel.java @@ -0,0 +1,36 @@ +package io.modelcontextprotocol.util; + +import java.io.IOException; +import java.net.Inet4Address; +import java.net.InetSocketAddress; +import java.net.StandardProtocolFamily; +import java.nio.channels.Selector; +import java.nio.channels.SocketChannel; +import java.util.concurrent.ExecutorService; + +public class Inet4ClientNonBlockingSocketChannel extends ClientNonBlockingSocketChannel { + + public Inet4ClientNonBlockingSocketChannel() throws IOException { + super(); + } + + public Inet4ClientNonBlockingSocketChannel(Selector selector, int incomingBufferSize, ExecutorService executor) + throws IOException { + super(selector, incomingBufferSize, executor); + } + + public Inet4ClientNonBlockingSocketChannel(Selector selector, int incomingBufferSize) throws IOException { + super(selector, incomingBufferSize); + } + + public Inet4ClientNonBlockingSocketChannel(Selector selector) throws IOException { + super(selector); + } + + public void connectBlocking(Inet4Address address, int port, IOConsumer connectHandler, + IOConsumer readHandler) throws IOException { + super.connectBlocking(StandardProtocolFamily.INET, new InetSocketAddress(address, port), connectHandler, + readHandler); + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/Inet4ServerNonBlockingSocketChannel.java b/mcp/src/main/java/io/modelcontextprotocol/util/Inet4ServerNonBlockingSocketChannel.java new file mode 100644 index 000000000..a4b9c61f8 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/util/Inet4ServerNonBlockingSocketChannel.java @@ -0,0 +1,34 @@ +package io.modelcontextprotocol.util; + +import java.io.IOException; +import java.net.Inet4Address; +import java.net.InetSocketAddress; +import java.net.StandardProtocolFamily; +import java.nio.channels.Selector; +import java.nio.channels.SocketChannel; +import java.util.concurrent.ExecutorService; + +public class Inet4ServerNonBlockingSocketChannel extends ServerNonBlockingSocketChannel { + + public Inet4ServerNonBlockingSocketChannel() throws IOException { + super(); + } + + public Inet4ServerNonBlockingSocketChannel(Selector selector, int incomingBufferSize, ExecutorService executor) { + super(selector, incomingBufferSize, executor); + } + + public Inet4ServerNonBlockingSocketChannel(Selector selector, int incomingBufferSize) { + super(selector, incomingBufferSize); + } + + public Inet4ServerNonBlockingSocketChannel(Selector selector) { + super(selector); + } + + public void start(Inet4Address address, int port, IOConsumer acceptHandler, + IOConsumer readHandler) throws IOException { + super.start(StandardProtocolFamily.INET, new InetSocketAddress(address, port), acceptHandler, readHandler); + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/Inet6ClientNonBlockingSocketChannel.java b/mcp/src/main/java/io/modelcontextprotocol/util/Inet6ClientNonBlockingSocketChannel.java new file mode 100644 index 000000000..da8739758 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/util/Inet6ClientNonBlockingSocketChannel.java @@ -0,0 +1,36 @@ +package io.modelcontextprotocol.util; + +import java.io.IOException; +import java.net.Inet6Address; +import java.net.InetSocketAddress; +import java.net.StandardProtocolFamily; +import java.nio.channels.Selector; +import java.nio.channels.SocketChannel; +import java.util.concurrent.ExecutorService; + +public class Inet6ClientNonBlockingSocketChannel extends ClientNonBlockingSocketChannel { + + public Inet6ClientNonBlockingSocketChannel() throws IOException { + super(); + } + + public Inet6ClientNonBlockingSocketChannel(Selector selector, int incomingBufferSize, ExecutorService executor) + throws IOException { + super(selector, incomingBufferSize, executor); + } + + public Inet6ClientNonBlockingSocketChannel(Selector selector, int incomingBufferSize) throws IOException { + super(selector, incomingBufferSize); + } + + public Inet6ClientNonBlockingSocketChannel(Selector selector) throws IOException { + super(selector); + } + + public void connectBlocking(Inet6Address address, int port, IOConsumer connectHandler, + IOConsumer readHandler) throws IOException { + super.connectBlocking(StandardProtocolFamily.INET6, new InetSocketAddress(address, port), connectHandler, + readHandler); + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/Inet6ServerNonBlockingSocketChannel.java b/mcp/src/main/java/io/modelcontextprotocol/util/Inet6ServerNonBlockingSocketChannel.java new file mode 100644 index 000000000..8a1a95e27 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/util/Inet6ServerNonBlockingSocketChannel.java @@ -0,0 +1,34 @@ +package io.modelcontextprotocol.util; + +import java.io.IOException; +import java.net.Inet6Address; +import java.net.InetSocketAddress; +import java.net.StandardProtocolFamily; +import java.nio.channels.Selector; +import java.nio.channels.SocketChannel; +import java.util.concurrent.ExecutorService; + +public class Inet6ServerNonBlockingSocketChannel extends ServerNonBlockingSocketChannel { + + public Inet6ServerNonBlockingSocketChannel() throws IOException { + super(); + } + + public Inet6ServerNonBlockingSocketChannel(Selector selector, int incomingBufferSize, ExecutorService executor) { + super(selector, incomingBufferSize, executor); + } + + public Inet6ServerNonBlockingSocketChannel(Selector selector, int incomingBufferSize) { + super(selector, incomingBufferSize); + } + + public Inet6ServerNonBlockingSocketChannel(Selector selector) { + super(selector); + } + + public void start(Inet6Address address, int port, IOConsumer acceptHandler, + IOConsumer readHandler) throws IOException { + super.start(StandardProtocolFamily.INET6, new InetSocketAddress(address, port), acceptHandler, readHandler); + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/NonBlockingSocketChannel.java b/mcp/src/main/java/io/modelcontextprotocol/util/NonBlockingSocketChannel.java new file mode 100644 index 000000000..b3b46e13c --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/util/NonBlockingSocketChannel.java @@ -0,0 +1,387 @@ +package io.modelcontextprotocol.util; + +import java.io.IOException; +import java.io.InterruptedIOException; +import java.net.SocketAddress; +import java.nio.ByteBuffer; +import java.nio.channels.SelectionKey; +import java.nio.channels.Selector; +import java.nio.channels.ServerSocketChannel; +import java.nio.channels.SocketChannel; +import java.nio.charset.StandardCharsets; +import java.util.Iterator; +import java.util.Objects; +import java.util.Set; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public abstract class NonBlockingSocketChannel { + + private static final Logger logger = LoggerFactory.getLogger(NonBlockingSocketChannel.class); + + public static final int DEFAULT_INBUFFER_SIZE = 1024; + + protected static String MESSAGE_DELIMITER = "\n"; + + protected static int BLOCKING_WRITE_TIMEOUT = 5000; + + protected static int BLOCKING_CONNECT_TIMEOUT = 10000; + + protected final Selector selector; + + protected final ByteBuffer inBuffer; + + protected final ExecutorService executor; + + @FunctionalInterface + public interface IOConsumer { + + void apply(T t) throws IOException; + + } + + protected class AttachedIO { + + public ByteBuffer writing; + + public StringBuffer reading; + + } + + public NonBlockingSocketChannel(Selector selector, int incomingBufferSize, ExecutorService executor) { + Assert.notNull(selector, "Selector must not be null"); + this.selector = selector; + this.inBuffer = ByteBuffer.allocate(incomingBufferSize); + this.executor = (executor == null) ? Executors.newSingleThreadExecutor() : executor; + } + + public NonBlockingSocketChannel(Selector selector, int incomingBufferSize) { + this(selector, incomingBufferSize, null); + } + + public NonBlockingSocketChannel(Selector selector) { + this(selector, DEFAULT_INBUFFER_SIZE); + } + + public NonBlockingSocketChannel() throws IOException { + this(Selector.open()); + } + + protected Runnable getRunnableForProcessing(IOConsumer acceptHandler, + IOConsumer connectHandler, IOConsumer readHandler) { + return () -> { + SelectionKey key = null; + try { + while (true) { + this.selector.select(); + Set selectedKeys = selector.selectedKeys(); + Iterator iter = selectedKeys.iterator(); + while (iter.hasNext()) { + key = iter.next(); + if (key.isConnectable()) { + handleConnectable(key, connectHandler); + } + else if (key.isAcceptable()) { + handleAcceptable(key, acceptHandler); + } + else if (key.isReadable()) { + handleReadable(key, readHandler); + } + else if (key.isWritable()) { + handleWritable(key); + } + iter.remove(); + } + } + } + catch (Exception e) { + handleException(key, e); + } + }; + } + + public abstract void close(); + + protected abstract void handleException(SelectionKey key, Exception e); + + protected void start(IOConsumer acceptHandler, IOConsumer connectHandler, + IOConsumer readHandler) throws IOException { + this.executor.execute(getRunnableForProcessing(acceptHandler, connectHandler, readHandler)); + } + + // For client subclasses + protected void handleConnectable(SelectionKey key, IOConsumer connectHandler) throws IOException { + SocketChannel client = (SocketChannel) key.channel(); + Object lock = client.blockingLock(); + if (logger.isDebugEnabled()) { + logger.debug("handleConnectable client=" + client.getRemoteAddress()); + } + synchronized (lock) { + client.configureBlocking(false); + client.register(this.selector, SelectionKey.OP_READ, new AttachedIO()); + if (client.isConnectionPending()) { + client.finishConnect(); + if (logger.isDebugEnabled()) { + logger.debug("handleConnectable FINISHED"); + } + } + if (connectHandler != null) { + connectHandler.apply(client); + } + } + } + + protected void handleAcceptable(SelectionKey key, IOConsumer acceptHandler) throws IOException { + ServerSocketChannel serverSocket = (ServerSocketChannel) key.channel(); + SocketChannel client = serverSocket.accept(); + Object lock = client.blockingLock(); + if (logger.isDebugEnabled()) { + logger.debug("handleAcceptable client=" + client); + } + synchronized (lock) { + client.configureBlocking(false); + client.register(this.selector, SelectionKey.OP_READ, new AttachedIO()); + configureAcceptSocketChannel(client); + if (client.isConnectionPending()) { + client.finishConnect(); + if (logger.isDebugEnabled()) { + logger.debug("handleAcceptable FINISHED"); + } + } + if (acceptHandler != null) { + acceptHandler.apply(client); + } + } + } + + protected void configureAcceptSocketChannel(SocketChannel client) throws IOException { + // Subclasses may override + } + + protected AttachedIO getAttachedIO(SelectionKey key) throws IOException { + AttachedIO io = (AttachedIO) key.attachment(); + if (io == null) { + throw new IOException("No AttachedIO object found on key"); + } + return io; + } + + protected void handleReadable(SelectionKey key, IOConsumer readHandler) throws IOException { + SocketChannel client = (SocketChannel) key.channel(); + Object lock = client.blockingLock(); + AttachedIO io = getAttachedIO(key); + if (logger.isDebugEnabled()) { + logger.debug("handleReadable client=" + client); + } + synchronized (lock) { + // non-blocking read here + int r = client.read(this.inBuffer); + // Check if we should expect any more reads + if (r == -1) { + throw new IOException("Channel read reached end of stream"); + } + this.inBuffer.flip(); + String partial = new String(this.inBuffer.array(), 0, r, StandardCharsets.UTF_8); + // If there is are previous partial, then get the io.reading string Buffer + StringBuffer sb = (io.reading != null) ? (StringBuffer) io.reading : new StringBuffer(); + // And append the just read partial to the string buffer + sb.append(partial); + if (partial.endsWith(MESSAGE_DELIMITER)) { + // Get the entire message from the string buffer + String message = sb.toString(); + // Set the io.reading value to null as we are done with this message + io.reading = null; + if (logger.isDebugEnabled()) { + logger.debug("handleReadable COMPLETE msg=" + message); + } + if (readHandler != null) { + readHandler.apply(message); + } + } + else { + io.reading = sb; + if (logger.isDebugEnabled()) { + logger.debug("handleReadable PARTIAL msg=" + partial); + } + } + } + // Clear inbuffer for next read + this.inBuffer.clear(); + } + + protected void handleWritable(SelectionKey key) throws IOException { + ByteBuffer buf = getAttachedIO(key).writing; + SocketChannel client = (SocketChannel) key.channel(); + if (buf != null) { + doWrite(key, client, buf, (lock) -> { + synchronized (lock) { + if (logger.isDebugEnabled()) { + logger.debug("handleWritable NOTIFY client=" + client); + } + lock.notify(); + } + }); + } + } + + protected void doWrite(SocketChannel client, String message, IOConsumer writeHandler) throws IOException { + Assert.notNull(client, "Client must not be null"); + Assert.notNull(message, "Message must not be null"); + if (logger.isDebugEnabled()) { + logger.debug("doWrite msg=" + message); + } + doWrite(client.keyFor(this.selector), client, ByteBuffer.wrap(message.getBytes(StandardCharsets.UTF_8)), + writeHandler); + } + + protected void doWrite(SelectionKey key, SocketChannel client, ByteBuffer buf, IOConsumer writeHandler) + throws IOException { + AttachedIO io = (AttachedIO) key.attachment(); + Object lock = client.blockingLock(); + synchronized (lock) { + int written = client.write(buf); + if (buf.hasRemaining()) { + if (logger.isDebugEnabled()) { + logger.debug("doWrite PARTIAL written=" + written + " remaining=" + buf.remaining()); + } + io.writing = buf.slice(); + key.interestOpsOr(SelectionKey.OP_WRITE); + } + else { + if (logger.isDebugEnabled()) { + logger.debug("doWrite COMPLETED msg=" + new String(buf.array(), 0, written)); + } + io.writing = null; + key.interestOps(SelectionKey.OP_READ); + if (writeHandler != null) { + writeHandler.apply(lock); + } + } + } + } + + protected void executorShutdown() { + if (!this.executor.isShutdown()) { + if (logger.isDebugEnabled()) { + logger.debug("executorShutdown"); + } + try { + this.executor.awaitTermination(2000, TimeUnit.MILLISECONDS); + this.executor.shutdown(); + } + catch (InterruptedException e) { + if (logger.isDebugEnabled()) { + logger.debug("Exception in executor awaitTermination", e); + } + } + } + } + + protected void hardCloseClient(SocketChannel client, IOConsumer closeHandler) throws IOException { + if (client != null) { + Object lock = client.blockingLock(); + if (logger.isDebugEnabled()) { + logger.debug("hardCloseClient client=" + client); + } + synchronized (lock) { + if (closeHandler != null) { + closeHandler.apply(client); + } + client.close(); + } + executorShutdown(); + } + } + + protected void writeBlocking(SocketChannel client, String message) throws IOException { + Objects.requireNonNull(client, "Client must not be null"); + Objects.requireNonNull(message, "Message must not be null"); + // Escape any embedded newlines in the JSON message, and add newline + String outputMessage = message.replace("\r\n", "\\n") + .replace("\n", "\\n") + .replace("\r", "\\n") + .concat(MESSAGE_DELIMITER); + Object lock = client.blockingLock(); + if (logger.isDebugEnabled()) { + logger.debug("writeBlocking msg=" + outputMessage); + } + synchronized (lock) { + // do the non blocking write in thread while holding lock. + doWrite(client, outputMessage, null); + ByteBuffer bufRemaining = null; + long waitTime = System.currentTimeMillis() + BLOCKING_WRITE_TIMEOUT; + while (waitTime - System.currentTimeMillis() > 0) { + // Before releasing lock, check for writing buffer remaining + bufRemaining = getAttachedIO(client.keyFor(this.selector)).writing; + if (bufRemaining == null || bufRemaining.remaining() == 0) { + // It's done + break; + } + // If write is *not* completed, then wait timeout /10 + try { + if (logger.isDebugEnabled()) { + logger + .debug("writeBlocking WAITING=" + String.valueOf(waitTime / 10) + " msg=" + outputMessage); + } + lock.wait(waitTime / 10); + } + catch (InterruptedException e) { + throw new InterruptedIOException("write message wait interrupted"); + } + } + if (bufRemaining != null && bufRemaining.remaining() > 0) { + throw new IOException("Write not completed. Non empty buffer remaining after timeout"); + } + } + if (logger.isDebugEnabled()) { + logger.debug("writeBlocking COMPLETED msg=" + outputMessage); + } + } + + protected void configureConnectSocketChannel(SocketChannel client, SocketAddress connectAddress) + throws IOException { + // Subclasses may override + } + + protected SocketChannel connectBlocking(SocketChannel client, SocketAddress address, + IOConsumer connectHandler, IOConsumer readHandler) throws IOException { + Object lock = client.blockingLock(); + if (logger.isDebugEnabled()) { + logger.debug("connectBlocking CONNECTING targetAddress=" + address); + } + synchronized (lock) { + client.configureBlocking(false); + client.register(selector, SelectionKey.OP_CONNECT); + configureConnectSocketChannel(client, address); + // Start the read thread before connect + // No/null accept handler for clients + start(null, (c) -> { + if (connectHandler != null) { + connectHandler.apply(c); + } + lock.notify(); + }, readHandler); + + client.connect(address); + + try { + if (logger.isDebugEnabled()) { + logger.debug("connectBlocking WAITING targetAddress=" + address); + } + lock.wait(BLOCKING_CONNECT_TIMEOUT); + } + catch (InterruptedException e) { + throw new IOException("Connect to address=" + address + " timed out"); + } + if (logger.isDebugEnabled()) { + logger.debug("connectBlocking CONNECTED client=" + client.getLocalAddress() + " connecting=" + address); + } + return client; + } + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/ServerNonBlockingSocketChannel.java b/mcp/src/main/java/io/modelcontextprotocol/util/ServerNonBlockingSocketChannel.java new file mode 100644 index 000000000..918635012 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/util/ServerNonBlockingSocketChannel.java @@ -0,0 +1,96 @@ +package io.modelcontextprotocol.util; + +import java.io.IOException; +import java.net.SocketAddress; +import java.net.StandardProtocolFamily; +import java.nio.channels.SelectionKey; +import java.nio.channels.Selector; +import java.nio.channels.ServerSocketChannel; +import java.nio.channels.SocketChannel; +import java.util.concurrent.ExecutorService; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class ServerNonBlockingSocketChannel extends NonBlockingSocketChannel { + + private static final Logger logger = LoggerFactory.getLogger(ServerNonBlockingSocketChannel.class); + + protected SocketChannel acceptedClient; + + public ServerNonBlockingSocketChannel() throws IOException { + super(); + } + + public ServerNonBlockingSocketChannel(Selector selector, int incomingBufferSize, ExecutorService executor) { + super(selector, incomingBufferSize, executor); + } + + public ServerNonBlockingSocketChannel(Selector selector, int incomingBufferSize) { + super(selector, incomingBufferSize); + } + + public ServerNonBlockingSocketChannel(Selector selector) { + super(selector); + } + + protected void configureServerSocketChannel(ServerSocketChannel serverSocketChannel, SocketAddress acceptAddress) { + // Subclasses may override + } + + public void start(StandardProtocolFamily protocol, SocketAddress address, IOConsumer acceptHandler, + IOConsumer readHandler) throws IOException { + ServerSocketChannel serverChannel = ServerSocketChannel.open(protocol); + serverChannel.configureBlocking(false); + serverChannel.register(this.selector, SelectionKey.OP_ACCEPT); + configureServerSocketChannel(serverChannel, address); + serverChannel.bind(address); + // Start thread/processing of incoming accept, read + super.start((client) -> { + if (logger.isDebugEnabled()) { + logger.debug("Setting client=" + client); + } + this.acceptedClient = client; + if (acceptHandler != null) { + acceptHandler.apply(this.acceptedClient); + } + // No/null connect handler for Acceptors...only accepthandler + }, null, readHandler); + } + + @Override + protected void handleException(SelectionKey key, Exception e) { + if (logger.isDebugEnabled()) { + logger.debug("handleException", e); + } + close(); + } + + public void writeMessageBlocking(String message) throws IOException { + if (this.acceptedClient == null) { + throw new IOException("Cannot write until client connected"); + } + writeBlocking(acceptedClient, message); + } + + @Override + public void close() { + SocketChannel client = this.acceptedClient; + if (client != null) { + try { + hardCloseClient(client, (c) -> { + if (logger.isDebugEnabled()) { + logger.debug("Unsetting client=" + c); + } + this.acceptedClient = null; + }); + } + catch (IOException e) { + if (logger.isDebugEnabled()) { + logger.debug("Exception in hardCloseClient", e); + } + } + } + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/UDSClientNonBlockingSocketChannel.java b/mcp/src/main/java/io/modelcontextprotocol/util/UDSClientNonBlockingSocketChannel.java new file mode 100644 index 000000000..ef16590b5 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/util/UDSClientNonBlockingSocketChannel.java @@ -0,0 +1,34 @@ +package io.modelcontextprotocol.util; + +import java.io.IOException; +import java.net.StandardProtocolFamily; +import java.net.UnixDomainSocketAddress; +import java.nio.channels.Selector; +import java.nio.channels.SocketChannel; +import java.util.concurrent.ExecutorService; + +public class UDSClientNonBlockingSocketChannel extends ClientNonBlockingSocketChannel { + + public UDSClientNonBlockingSocketChannel(Selector selector, int incomingBufferSize, ExecutorService executor) + throws IOException { + super(selector, incomingBufferSize, executor); + } + + public UDSClientNonBlockingSocketChannel() throws IOException { + super(); + } + + public UDSClientNonBlockingSocketChannel(Selector selector, int incomingBufferSize) { + super(selector, incomingBufferSize); + } + + public UDSClientNonBlockingSocketChannel(Selector selector) { + super(selector); + } + + public void connectBlocking(UnixDomainSocketAddress address, IOConsumer connectHandler, + IOConsumer readHandler) throws IOException { + super.connectBlocking(StandardProtocolFamily.UNIX, address, connectHandler, readHandler); + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/UDSServerNonBlockingSocketChannel.java b/mcp/src/main/java/io/modelcontextprotocol/util/UDSServerNonBlockingSocketChannel.java new file mode 100644 index 000000000..259315712 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/util/UDSServerNonBlockingSocketChannel.java @@ -0,0 +1,33 @@ +package io.modelcontextprotocol.util; + +import java.io.IOException; +import java.net.StandardProtocolFamily; +import java.net.UnixDomainSocketAddress; +import java.nio.channels.Selector; +import java.nio.channels.SocketChannel; +import java.util.concurrent.ExecutorService; + +public class UDSServerNonBlockingSocketChannel extends ServerNonBlockingSocketChannel { + + public UDSServerNonBlockingSocketChannel() throws IOException { + super(); + } + + public UDSServerNonBlockingSocketChannel(Selector selector, int incomingBufferSize, ExecutorService executor) { + super(selector, incomingBufferSize, executor); + } + + public UDSServerNonBlockingSocketChannel(Selector selector, int incomingBufferSize) { + super(selector, incomingBufferSize); + } + + public UDSServerNonBlockingSocketChannel(Selector selector) { + super(selector); + } + + public void start(UnixDomainSocketAddress address, IOConsumer acceptHandler, + IOConsumer readHandler) throws IOException { + super.start(StandardProtocolFamily.UNIX, address, acceptHandler, readHandler); + } + +} From f6d706373a9579156e09adc3530ac85e2ea5148c Mon Sep 17 00:00:00 2001 From: Scott Lewis Date: Tue, 5 Aug 2025 11:53:27 -0700 Subject: [PATCH 10/25] fix for synchronization --- .../util/NonBlockingSocketChannel.java | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/NonBlockingSocketChannel.java b/mcp/src/main/java/io/modelcontextprotocol/util/NonBlockingSocketChannel.java index b3b46e13c..5fc0eaf9d 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/util/NonBlockingSocketChannel.java +++ b/mcp/src/main/java/io/modelcontextprotocol/util/NonBlockingSocketChannel.java @@ -222,7 +222,7 @@ protected void handleWritable(SelectionKey key) throws IOException { if (logger.isDebugEnabled()) { logger.debug("handleWritable NOTIFY client=" + client); } - lock.notify(); + lock.notifyAll(); } }); } @@ -360,10 +360,12 @@ protected SocketChannel connectBlocking(SocketChannel client, SocketAddress addr // Start the read thread before connect // No/null accept handler for clients start(null, (c) -> { - if (connectHandler != null) { - connectHandler.apply(c); + synchronized (lock) { + if (connectHandler != null) { + connectHandler.apply(c); + } + lock.notifyAll(); } - lock.notify(); }, readHandler); client.connect(address); From 2e14fd06cc3a4f917b521856dd59e227cf1932a4 Mon Sep 17 00:00:00 2001 From: Scott Lewis Date: Wed, 6 Aug 2025 15:35:12 -0700 Subject: [PATCH 11/25] Added async and sync server tests (UDSMcpAsyncServerTests, UDSMcpSyncServerTest). Also made simplifying changes to *socketchannel classes --- .../server/UDSMcpAsyncServerTests.java | 52 ++++++++++++++++++ .../server/UDSMcpSyncServerTests.java | 53 +++++++++++++++++++ 2 files changed, 105 insertions(+) create mode 100644 mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpAsyncServerTests.java create mode 100644 mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpSyncServerTests.java diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpAsyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpAsyncServerTests.java new file mode 100644 index 000000000..35be2a355 --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpAsyncServerTests.java @@ -0,0 +1,52 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import java.io.IOException; +import java.net.UnixDomainSocketAddress; +import java.nio.file.Files; + +import org.junit.jupiter.api.Timeout; + +import io.modelcontextprotocol.server.transport.UDSServerTransportProvider; +import io.modelcontextprotocol.spec.McpServerTransportProvider; + +/** + * Tests for {@link McpAsyncServer} using {@link StdioServerTransport}. + * + * @author Christian Tzolov + */ +@Timeout(15) // Giving extra time beyond the client timeout +class UDSMcpAsyncServerTests extends AbstractMcpAsyncServerTests { + + private UnixDomainSocketAddress address; + + @Override + protected void setUp() { + super.onStart(); + address = UnixDomainSocketAddress.of(getClass().getName() + ".unix.socket"); + } + + @Override + protected void tearDown() { + super.onClose(); + if (address != null) { + try { + Files.deleteIfExists(address.getPath()); + } catch (IOException e) { + } + } + } + + protected McpServerTransportProvider createMcpTransportProvider() { + return new UDSServerTransportProvider(address); + } + + @Override + protected McpServer.AsyncSpecification prepareAsyncServerBuilder() { + return McpServer.async(createMcpTransportProvider()); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpSyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpSyncServerTests.java new file mode 100644 index 000000000..aa3666fbf --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpSyncServerTests.java @@ -0,0 +1,53 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import java.io.IOException; +import java.net.UnixDomainSocketAddress; +import java.nio.file.Files; + +import org.junit.jupiter.api.Timeout; + +import io.modelcontextprotocol.server.transport.StdioServerTransportProvider; +import io.modelcontextprotocol.server.transport.UDSServerTransportProvider; +import io.modelcontextprotocol.spec.McpServerTransportProvider; + +/** + * Tests for {@link McpSyncServer} using {@link StdioServerTransportProvider}. + * + * @author Christian Tzolov and Scott Lewis + */ +@Timeout(15) // Giving extra time beyond the client timeout +class UDSMcpSyncServerTests extends AbstractMcpSyncServerTests { + + private UnixDomainSocketAddress address; + + @Override + protected void setUp() { + super.onStart(); + address = UnixDomainSocketAddress.of(getClass().getName()+".unix.socket"); + } + + @Override + protected void tearDown() { + super.onClose(); + if (address != null) { + try { + Files.deleteIfExists(address.getPath()); + } catch (IOException e) { + } + } + } + + protected McpServerTransportProvider createMcpTransportProvider() { + return new UDSServerTransportProvider(address); + } + + @Override + protected McpServer.SyncSpecification prepareSyncServerBuilder() { + return McpServer.sync(createMcpTransportProvider()); + } + +} From 8e9c09bc2e3823a194e42936ec736f064ed4d4c0 Mon Sep 17 00:00:00 2001 From: Scott Lewis Date: Wed, 6 Aug 2025 15:42:55 -0700 Subject: [PATCH 12/25] Fixes and simplification --- .../transport/UDSServerTransportProvider.java | 27 ++++++++++++------- .../util/ClientNonBlockingSocketChannel.java | 16 +++-------- .../Inet4ClientNonBlockingSocketChannel.java | 7 +++-- .../Inet6ClientNonBlockingSocketChannel.java | 7 +++-- .../util/NonBlockingSocketChannel.java | 16 ++++++++--- .../util/ServerNonBlockingSocketChannel.java | 15 +++-------- .../UDSClientNonBlockingSocketChannel.java | 3 +-- 7 files changed, 44 insertions(+), 47 deletions(-) diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/UDSServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/UDSServerTransportProvider.java index d6e76a56a..c1b677bb4 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/transport/UDSServerTransportProvider.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/UDSServerTransportProvider.java @@ -44,23 +44,33 @@ public class UDSServerTransportProvider implements McpServerTransportProvider { private UDSMcpSessionTransport transport; - public UDSServerTransportProvider(ObjectMapper objectMapper, UnixDomainSocketAddress unixSocketAddress) - throws IOException { + public UDSServerTransportProvider(UnixDomainSocketAddress unixSocketAddress) { + this(new ObjectMapper(), unixSocketAddress); + } + + public UDSServerTransportProvider(ObjectMapper objectMapper, UnixDomainSocketAddress unixSocketAddress) { Assert.notNull(objectMapper, "The ObjectMapper can not be null"); this.objectMapper = objectMapper; this.address = unixSocketAddress; - this.serverSocketChannel = new UDSServerNonBlockingSocketChannel(); } @Override public void setSessionFactory(McpServerSession.Factory sessionFactory) { + this.transport = new UDSMcpSessionTransport(); + this.session = sessionFactory.create(transport); + this.transport.initProcessing(); + // Also start listening for accept try { + this.serverSocketChannel = new UDSServerNonBlockingSocketChannel(); this.serverSocketChannel.start(this.address, (clientChannel) -> { - this.transport = new UDSMcpSessionTransport(); - this.session = sessionFactory.create(transport); - this.transport.initProcessing(); + if (logger.isDebugEnabled()) { + logger.debug("Accepted connect from clientChannel=" + clientChannel); + } }, (dataLine) -> { String message = (String) dataLine; + if (logger.isDebugEnabled()) { + logger.debug("Received message line=" + message); + } try { this.transport .handleMessage(McpSchema.deserializeJsonRpcMessage(this.objectMapper, message.trim())); @@ -71,6 +81,7 @@ public void setSessionFactory(McpServerSession.Factory sessionFactory) { }); } catch (IOException e) { + // If this happens then we are doomed this.serverSocketChannel.close(); throw new RuntimeException("accepterNonBlockSocketChannel could not be started"); } @@ -114,10 +125,6 @@ public UDSMcpSessionTransport() { this.inboundSink = Sinks.many().unicast().onBackpressureBuffer(); this.outboundSink = Sinks.many().unicast().onBackpressureBuffer(); - // Use bounded schedulers for better resource management - // this.inboundScheduler = - // Schedulers.fromExecutorService(Executors.newSingleThreadExecutor(), - // "uds-inbound"); this.outboundScheduler = Schedulers.fromExecutorService(Executors.newSingleThreadExecutor(), "uds-outbound"); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/ClientNonBlockingSocketChannel.java b/mcp/src/main/java/io/modelcontextprotocol/util/ClientNonBlockingSocketChannel.java index 65e976f53..33c7f5f7d 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/util/ClientNonBlockingSocketChannel.java +++ b/mcp/src/main/java/io/modelcontextprotocol/util/ClientNonBlockingSocketChannel.java @@ -17,8 +17,7 @@ public class ClientNonBlockingSocketChannel extends NonBlockingSocketChannel { private SocketChannel client; - public ClientNonBlockingSocketChannel(Selector selector, int incomingBufferSize, ExecutorService executor) - throws IOException { + public ClientNonBlockingSocketChannel(Selector selector, int incomingBufferSize, ExecutorService executor) { super(selector, incomingBufferSize, executor); } @@ -52,16 +51,9 @@ protected void handleException(SelectionKey key, Exception e) { @Override public void close() { - try { - hardCloseClient(this.client, (client) -> { - this.client = null; - }); - } - catch (IOException e) { - if (logger.isDebugEnabled()) { - logger.debug("Exception in hardCloseClient", e); - } - } + hardCloseClient(this.client, (client) -> { + this.client = null; + }); } public void writeMessageBlocking(String message) throws IOException { diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/Inet4ClientNonBlockingSocketChannel.java b/mcp/src/main/java/io/modelcontextprotocol/util/Inet4ClientNonBlockingSocketChannel.java index 650fd52b4..b1186e3cd 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/util/Inet4ClientNonBlockingSocketChannel.java +++ b/mcp/src/main/java/io/modelcontextprotocol/util/Inet4ClientNonBlockingSocketChannel.java @@ -14,16 +14,15 @@ public Inet4ClientNonBlockingSocketChannel() throws IOException { super(); } - public Inet4ClientNonBlockingSocketChannel(Selector selector, int incomingBufferSize, ExecutorService executor) - throws IOException { + public Inet4ClientNonBlockingSocketChannel(Selector selector, int incomingBufferSize, ExecutorService executor) { super(selector, incomingBufferSize, executor); } - public Inet4ClientNonBlockingSocketChannel(Selector selector, int incomingBufferSize) throws IOException { + public Inet4ClientNonBlockingSocketChannel(Selector selector, int incomingBufferSize) { super(selector, incomingBufferSize); } - public Inet4ClientNonBlockingSocketChannel(Selector selector) throws IOException { + public Inet4ClientNonBlockingSocketChannel(Selector selector) { super(selector); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/Inet6ClientNonBlockingSocketChannel.java b/mcp/src/main/java/io/modelcontextprotocol/util/Inet6ClientNonBlockingSocketChannel.java index da8739758..9af484858 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/util/Inet6ClientNonBlockingSocketChannel.java +++ b/mcp/src/main/java/io/modelcontextprotocol/util/Inet6ClientNonBlockingSocketChannel.java @@ -14,16 +14,15 @@ public Inet6ClientNonBlockingSocketChannel() throws IOException { super(); } - public Inet6ClientNonBlockingSocketChannel(Selector selector, int incomingBufferSize, ExecutorService executor) - throws IOException { + public Inet6ClientNonBlockingSocketChannel(Selector selector, int incomingBufferSize, ExecutorService executor) { super(selector, incomingBufferSize, executor); } - public Inet6ClientNonBlockingSocketChannel(Selector selector, int incomingBufferSize) throws IOException { + public Inet6ClientNonBlockingSocketChannel(Selector selector, int incomingBufferSize) { super(selector, incomingBufferSize); } - public Inet6ClientNonBlockingSocketChannel(Selector selector) throws IOException { + public Inet6ClientNonBlockingSocketChannel(Selector selector) { super(selector); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/NonBlockingSocketChannel.java b/mcp/src/main/java/io/modelcontextprotocol/util/NonBlockingSocketChannel.java index 5fc0eaf9d..de3fe8ba7 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/util/NonBlockingSocketChannel.java +++ b/mcp/src/main/java/io/modelcontextprotocol/util/NonBlockingSocketChannel.java @@ -281,17 +281,25 @@ protected void executorShutdown() { } } - protected void hardCloseClient(SocketChannel client, IOConsumer closeHandler) throws IOException { + protected void hardCloseClient(SocketChannel client, IOConsumer closeHandler) { if (client != null) { Object lock = client.blockingLock(); if (logger.isDebugEnabled()) { logger.debug("hardCloseClient client=" + client); } synchronized (lock) { - if (closeHandler != null) { - closeHandler.apply(client); + try { + if (closeHandler != null) { + closeHandler.apply(client); + } + client.close(); + client = null; + } + catch (IOException e) { + if (logger.isDebugEnabled()) { + logger.debug("hardClose client socketchannel.close exception", e); + } } - client.close(); } executorShutdown(); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/ServerNonBlockingSocketChannel.java b/mcp/src/main/java/io/modelcontextprotocol/util/ServerNonBlockingSocketChannel.java index 918635012..4c64e3d18 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/util/ServerNonBlockingSocketChannel.java +++ b/mcp/src/main/java/io/modelcontextprotocol/util/ServerNonBlockingSocketChannel.java @@ -77,19 +77,12 @@ public void writeMessageBlocking(String message) throws IOException { public void close() { SocketChannel client = this.acceptedClient; if (client != null) { - try { - hardCloseClient(client, (c) -> { - if (logger.isDebugEnabled()) { - logger.debug("Unsetting client=" + c); - } - this.acceptedClient = null; - }); - } - catch (IOException e) { + hardCloseClient(client, (c) -> { if (logger.isDebugEnabled()) { - logger.debug("Exception in hardCloseClient", e); + logger.debug("Unsetting client=" + c); } - } + this.acceptedClient = null; + }); } } diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/UDSClientNonBlockingSocketChannel.java b/mcp/src/main/java/io/modelcontextprotocol/util/UDSClientNonBlockingSocketChannel.java index ef16590b5..2e279c2b9 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/util/UDSClientNonBlockingSocketChannel.java +++ b/mcp/src/main/java/io/modelcontextprotocol/util/UDSClientNonBlockingSocketChannel.java @@ -9,8 +9,7 @@ public class UDSClientNonBlockingSocketChannel extends ClientNonBlockingSocketChannel { - public UDSClientNonBlockingSocketChannel(Selector selector, int incomingBufferSize, ExecutorService executor) - throws IOException { + public UDSClientNonBlockingSocketChannel(Selector selector, int incomingBufferSize, ExecutorService executor) { super(selector, incomingBufferSize, executor); } From b7a533f969793049f30748af44d31b4a6d75af14 Mon Sep 17 00:00:00 2001 From: Scott Lewis Date: Wed, 6 Aug 2025 17:33:32 -0700 Subject: [PATCH 13/25] Layout fixes --- .../server/UDSMcpAsyncServerTests.java | 3 ++- .../server/UDSMcpSyncServerTests.java | 13 +++++++------ 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpAsyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpAsyncServerTests.java index 35be2a355..b5773e0fb 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpAsyncServerTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpAsyncServerTests.java @@ -35,7 +35,8 @@ protected void tearDown() { if (address != null) { try { Files.deleteIfExists(address.getPath()); - } catch (IOException e) { + } + catch (IOException e) { } } } diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpSyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpSyncServerTests.java index aa3666fbf..d2d2581d1 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpSyncServerTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpSyncServerTests.java @@ -23,26 +23,27 @@ class UDSMcpSyncServerTests extends AbstractMcpSyncServerTests { private UnixDomainSocketAddress address; - + @Override protected void setUp() { super.onStart(); - address = UnixDomainSocketAddress.of(getClass().getName()+".unix.socket"); + address = UnixDomainSocketAddress.of(getClass().getName() + ".unix.socket"); } - + @Override protected void tearDown() { super.onClose(); if (address != null) { try { Files.deleteIfExists(address.getPath()); - } catch (IOException e) { + } + catch (IOException e) { } } } - + protected McpServerTransportProvider createMcpTransportProvider() { - return new UDSServerTransportProvider(address); + return new UDSServerTransportProvider(address); } @Override From 6e41a22d41c947b4bf00a67b9bb7e5c1f3e6582f Mon Sep 17 00:00:00 2001 From: Scott Lewis Date: Wed, 6 Aug 2025 18:55:13 -0700 Subject: [PATCH 14/25] Update for tests --- .../modelcontextprotocol/server/UDSMcpAsyncServerTests.java | 3 ++- .../modelcontextprotocol/server/UDSMcpSyncServerTests.java | 5 +++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpAsyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpAsyncServerTests.java index b5773e0fb..cad1eae5b 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpAsyncServerTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpAsyncServerTests.java @@ -14,9 +14,10 @@ import io.modelcontextprotocol.spec.McpServerTransportProvider; /** - * Tests for {@link McpAsyncServer} using {@link StdioServerTransport}. + * Tests for {@link McpAsyncServer} using {@link UDSServerTransport}. * * @author Christian Tzolov + * @author Scott Lewis */ @Timeout(15) // Giving extra time beyond the client timeout class UDSMcpAsyncServerTests extends AbstractMcpAsyncServerTests { diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpSyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpSyncServerTests.java index d2d2581d1..6e896d478 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpSyncServerTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpSyncServerTests.java @@ -15,9 +15,10 @@ import io.modelcontextprotocol.spec.McpServerTransportProvider; /** - * Tests for {@link McpSyncServer} using {@link StdioServerTransportProvider}. + * Tests for {@link McpSyncServer} using {@link UDSServerTransportProvider}. * - * @author Christian Tzolov and Scott Lewis + * @author Christian Tzolov + * @author Scott Lewis */ @Timeout(15) // Giving extra time beyond the client timeout class UDSMcpSyncServerTests extends AbstractMcpSyncServerTests { From 558bf2a3094901049ecf8e374700320b16038877 Mon Sep 17 00:00:00 2001 From: Scott Lewis Date: Wed, 6 Aug 2025 18:56:09 -0700 Subject: [PATCH 15/25] Removed unnecessary import --- .../io/modelcontextprotocol/server/UDSMcpSyncServerTests.java | 1 - 1 file changed, 1 deletion(-) diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpSyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpSyncServerTests.java index 6e896d478..57ec7b766 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpSyncServerTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpSyncServerTests.java @@ -10,7 +10,6 @@ import org.junit.jupiter.api.Timeout; -import io.modelcontextprotocol.server.transport.StdioServerTransportProvider; import io.modelcontextprotocol.server.transport.UDSServerTransportProvider; import io.modelcontextprotocol.spec.McpServerTransportProvider; From d338800b36f800531529b181aceb2ba690582fd4 Mon Sep 17 00:00:00 2001 From: Scott Lewis Date: Thu, 7 Aug 2025 15:44:42 -0700 Subject: [PATCH 16/25] Added UDSMcpAsyncClientTests and UDSMcpSyncClientTests. Also added 'EverythingServer' to allow Java mcp server to provide support for client tests...that currently use the JavaScript 'everything' server. --- .../transport/UDSClientTransportProvider.java | 8 +- .../transport/UDSServerTransportProvider.java | 16 +-- .../client/UDSMcpAsyncClientTests.java | 69 +++++++++ .../client/UDSMcpSyncClientTests.java | 69 +++++++++ .../server/EverythingServer.java | 131 ++++++++++++++++++ 5 files changed, 280 insertions(+), 13 deletions(-) create mode 100644 mcp/src/test/java/io/modelcontextprotocol/client/UDSMcpAsyncClientTests.java create mode 100644 mcp/src/test/java/io/modelcontextprotocol/client/UDSMcpSyncClientTests.java create mode 100644 mcp/src/test/java/io/modelcontextprotocol/server/EverythingServer.java diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/UDSClientTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/UDSClientTransportProvider.java index bc364dea8..28bb1fe8a 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/UDSClientTransportProvider.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/UDSClientTransportProvider.java @@ -41,6 +41,10 @@ public class UDSClientTransportProvider implements McpClientTransport { private volatile boolean isClosing = false; + public UDSClientTransportProvider(UnixDomainSocketAddress targetAddress) throws IOException { + this(new ObjectMapper(), targetAddress); + } + public UDSClientTransportProvider(ObjectMapper objectMapper, UnixDomainSocketAddress targetAddress) throws IOException { Assert.notNull(objectMapper, "The ObjectMapper can not be null"); @@ -159,9 +163,7 @@ public Mono closeGracefully() { this.clientChannel = null; } return Mono.empty(); - })).doOnNext(o -> { - logger.info("MCP server process stopped"); - }).then(Mono.fromRunnable(() -> { + })).then(Mono.fromRunnable(() -> { try { // The Threads are blocked on readLine so disposeGracefully would not // interrupt them, therefore we issue an async hard dispose. diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/UDSServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/UDSServerTransportProvider.java index c1b677bb4..977f5b90e 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/transport/UDSServerTransportProvider.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/UDSServerTransportProvider.java @@ -58,7 +58,10 @@ public UDSServerTransportProvider(ObjectMapper objectMapper, UnixDomainSocketAdd public void setSessionFactory(McpServerSession.Factory sessionFactory) { this.transport = new UDSMcpSessionTransport(); this.session = sessionFactory.create(transport); - this.transport.initProcessing(); + this.transport.handleIncomingMessages(); + if (this.transport.isStarted.compareAndSet(false, true)) { + inboundReady.tryEmitValue(null); + } // Also start listening for accept try { this.serverSocketChannel = new UDSServerNonBlockingSocketChannel(); @@ -66,6 +69,8 @@ public void setSessionFactory(McpServerSession.Factory sessionFactory) { if (logger.isDebugEnabled()) { logger.debug("Accepted connect from clientChannel=" + clientChannel); } + // Start outbound processing now that the clientChannel has been accepted + this.transport.startOutboundProcessing(); }, (dataLine) -> { String message = (String) dataLine; if (logger.isDebugEnabled()) { @@ -171,18 +176,9 @@ public Mono closeGracefully() { @Override public void close() { isClosing.set(true); - serverSocketChannel.close(); logger.debug("Session transport closed"); } - private void initProcessing() { - handleIncomingMessages(); - if (isStarted.compareAndSet(false, true)) { - inboundReady.tryEmitValue(null); - } - startOutboundProcessing(); - } - private void handleIncomingMessages() { this.inboundSink.asFlux().flatMap(message -> session.handle(message)).doOnTerminate(() -> { // The outbound processing will dispose its scheduler upon completion diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/UDSMcpAsyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/UDSMcpAsyncClientTests.java new file mode 100644 index 000000000..99121e01c --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/client/UDSMcpAsyncClientTests.java @@ -0,0 +1,69 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.client; + +import java.io.IOException; +import java.net.UnixDomainSocketAddress; +import java.nio.file.Files; +import java.time.Duration; + +import org.junit.jupiter.api.Timeout; + +import io.modelcontextprotocol.client.transport.UDSClientTransportProvider; +import io.modelcontextprotocol.server.EverythingServer; +import io.modelcontextprotocol.server.transport.UDSServerTransportProvider; +import io.modelcontextprotocol.spec.McpClientTransport; + +/** + * Tests for the {@link McpAyncClient} with {@link UDSClientTransport}. + * + * @author Christian Tzolov + * @author Dariusz Jędrzejczyk + * @author Scott Lewis + */ +@Timeout(15) // Giving extra time beyond the client timeout +class UDSMcpAsyncClientTests extends AbstractMcpAsyncClientTests { + + UnixDomainSocketAddress address; + EverythingServer server; + + @Override + protected void onStart() { + this.address = UnixDomainSocketAddress.of(getClass().getName() + ".socket"); + try { + // Delete this file if exists from previous run + Files.deleteIfExists(this.address.getPath()); + } catch (IOException e) { + throw new RuntimeException(e); + } + this.server = new EverythingServer(new UDSServerTransportProvider(address)); + } + + @Override + protected void onClose() { + server.closeGracefully(); + server = null; + try { + Files.deleteIfExists(address.getPath()); + } catch (IOException e) { + throw new RuntimeException(e); + } + address = null; + } + + @Override + protected McpClientTransport createMcpTransport() { + try { + return new UDSClientTransportProvider(address); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + protected Duration getInitializationTimeout() { + return Duration.ofSeconds(2); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/UDSMcpSyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/UDSMcpSyncClientTests.java new file mode 100644 index 000000000..c52d98a97 --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/client/UDSMcpSyncClientTests.java @@ -0,0 +1,69 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.client; + +import java.io.IOException; +import java.net.UnixDomainSocketAddress; +import java.nio.file.Files; +import java.time.Duration; + +import org.junit.jupiter.api.Timeout; + +import io.modelcontextprotocol.client.transport.UDSClientTransportProvider; +import io.modelcontextprotocol.server.EverythingServer; +import io.modelcontextprotocol.server.transport.UDSServerTransportProvider; +import io.modelcontextprotocol.spec.McpClientTransport; + +/** + * Tests for the {@link McpSyncClient} with {@link UDSClientTransport}. + * + * @author Christian Tzolov + * @author Dariusz Jędrzejczyk + * @author Scott Lewis + */ +@Timeout(15) // Giving extra time beyond the client timeout +class UDSMcpSyncClientTests extends AbstractMcpSyncClientTests { + + UnixDomainSocketAddress address; + EverythingServer server; + + @Override + protected void onStart() { + this.address = UnixDomainSocketAddress.of(getClass().getName() + ".socket"); + try { + // Delete this file if exists from previous run + Files.deleteIfExists(this.address.getPath()); + } catch (IOException e) { + throw new RuntimeException(e); + } + this.server = new EverythingServer(new UDSServerTransportProvider(address)); + } + + @Override + protected void onClose() { + server.closeGracefully(); + server = null; + try { + Files.deleteIfExists(address.getPath()); + } catch (IOException e) { + throw new RuntimeException(e); + } + address = null; + } + + @Override + protected McpClientTransport createMcpTransport() { + try { + return new UDSClientTransportProvider(address); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + protected Duration getInitializationTimeout() { + return Duration.ofSeconds(2); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/EverythingServer.java b/mcp/src/test/java/io/modelcontextprotocol/server/EverythingServer.java new file mode 100644 index 000000000..a158ab2fb --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/server/EverythingServer.java @@ -0,0 +1,131 @@ +package io.modelcontextprotocol.server; + +import java.util.List; + +import io.modelcontextprotocol.server.McpServerFeatures.SyncPromptSpecification; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpServerTransportProvider; +import io.modelcontextprotocol.spec.McpSchema.Annotations; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; +import io.modelcontextprotocol.spec.McpSchema.GetPromptResult; +import io.modelcontextprotocol.spec.McpSchema.ProgressNotification; +import io.modelcontextprotocol.spec.McpSchema.Prompt; +import io.modelcontextprotocol.spec.McpSchema.PromptMessage; +import io.modelcontextprotocol.spec.McpSchema.ReadResourceResult; +import io.modelcontextprotocol.spec.McpSchema.Resource; +import io.modelcontextprotocol.spec.McpSchema.Role; +import io.modelcontextprotocol.spec.McpSchema.SamplingMessage; +import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; +import io.modelcontextprotocol.spec.McpSchema.TextContent; +import io.modelcontextprotocol.spec.McpSchema.TextResourceContents; +import io.modelcontextprotocol.spec.McpSchema.Tool; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest.ContextInclusionStrategy; + +public class EverythingServer { + + private static final String TEST_RESOURCE_URI = "test://resources/"; + + private static final String emptyJsonSchema = """ + { + "$schema": "http://json-schema.org/draft-07/schema#", + "type": "object", + "properties": {} + } + """; + + private McpSyncServer server; + + public EverythingServer(McpServerTransportProvider transport) { + McpServerFeatures.SyncResourceSpecification[] specs = new McpServerFeatures.SyncResourceSpecification[10]; + for (int i = 0; i < 10; i++) { + String istr = String.valueOf(i); + String uri = TEST_RESOURCE_URI + istr; + specs[i] = new McpServerFeatures.SyncResourceSpecification( + Resource.builder().uri(uri).name("Test Resource").mimeType("text/plain") + .description("Test resource description").build(), + (exchange, + req) -> new ReadResourceResult(List.of(new TextResourceContents(uri, "text/plain", istr)))); + } + + this.server = McpServer.sync(transport).serverInfo(getClass().getName() + "-server", "1.0.0") + .capabilities( + ServerCapabilities.builder().logging().tools(true).prompts(true).resources(true, true).build()) + .toolCall(Tool.builder().name("echo").description("echo tool description").inputSchema(emptyJsonSchema) + .build(), (exchange, request) -> { + return CallToolResult.builder().addTextContent((String) request.arguments().get("message")) + .build(); + }) + .toolCall( + Tool.builder().name("add").description("add two integers").inputSchema(emptyJsonSchema).build(), + (exchange, request) -> { + Integer a = (Integer) request.arguments().get("a"); + Integer b = (Integer) request.arguments().get("b"); + + return CallToolResult.builder().addTextContent(String.valueOf(a + b)).build(); + }) + .toolCall(Tool.builder().name("sampleLLM").description("sampleLLM tool").inputSchema(emptyJsonSchema) + .build(), (exchange, request) -> { + String prompt = (String) request.arguments().get("prompt"); + Integer maxTokens = (Integer) request.arguments().get("maxTokens"); + SamplingMessage sm = new SamplingMessage(McpSchema.Role.USER, + new TextContent("Resource sampleLLM context: " + prompt)); + CreateMessageRequest cmRequest = CreateMessageRequest.builder().messages(List.of(sm)) + .systemPrompt("You are a helpful test server.").maxTokens(maxTokens) + .temperature(0.7).includeContext(ContextInclusionStrategy.THIS_SERVER).build(); + CreateMessageResult result = exchange.createMessage(cmRequest); + + return CallToolResult.builder() + .addTextContent("LLM sampling result: " + ((TextContent) result.content()).text()) + .build(); + }) + .toolCall(Tool.builder().name("longRunningOperation") + .description("Demonstrates a long running operation with progress updates") + .inputSchema(emptyJsonSchema).build(), (exchange, request) -> { + String progressToken = (String) request.progressToken(); + int steps = (Integer) request.arguments().get("steps"); + for (int i = 0; i < steps; i++) { + try { + Thread.sleep(1000); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + if (progressToken != null) { + exchange.progressNotification( + new ProgressNotification(progressToken, (double) i + 1, (double) steps, + "progress message " + String.valueOf(i + 1))); + } + } + return CallToolResult.builder().content(List.of(new TextContent("done"))).build(); + }) + .toolCall(Tool.builder().name("annotatedMessage").description("annotated message").build(), + (exchange, request) -> { + String messageType = (String) request.arguments().get("messageType"); + Annotations annotations = null; + if (messageType.equals("success")) { + annotations = new Annotations(List.of(McpSchema.Role.USER), 0.7); + } else if (messageType.equals("error")) { + annotations = new Annotations(List.of(McpSchema.Role.USER, McpSchema.Role.ASSISTANT), + 1.0); + } else if (messageType.equals("debug")) { + annotations = new Annotations(List.of(McpSchema.Role.ASSISTANT), 0.3); + } + return CallToolResult.builder().addContent(new TextContent(annotations, "some response")) + .build(); + }) + .prompts(List.of(new SyncPromptSpecification( + new Prompt("simple_prompt", "Simple prompt description", null), (exchange, request) -> { + return new GetPromptResult("description", + List.of(new PromptMessage(Role.USER, new TextContent("hello")))); + }))) + .resources(specs).build(); + } + + public void closeGracefully() { + if (this.server != null) { + this.server.closeGracefully(); + this.server = null; + } + } +} From 29a49316c1aa2b45979e70ec5a8d16d9323f96ce Mon Sep 17 00:00:00 2001 From: Scott Lewis Date: Wed, 13 Aug 2025 17:44:48 -0700 Subject: [PATCH 17/25] Refactoring for simplification and reliability. Also fixed failing tests in UDSMcpAsync/Sync Server/Client Tests. Now the only failing tests are those that are expecting specific behavior of javascript everything server for tests. io.modelcontextprotocol.client.AbstractMcpSyncClientTests.testCallTool() fails because of this assertion assertThat(result.isError()).isNull(); This is asserting that the isError is null...which it is in the successful case for the javascript server. The java TestEverythingServer, however the isError is false (not null). I believe that false is correct (the java builder will not allow me to set isError to null). io.modelcontextprotocol.client.AbstractMcpAsyncClientTests.testCallToolWithInvalidTool() line 247 has e -> assertThat(e).isInstanceOf(McpError.class).hasMessage("Unknown tool: nonexistent_tool")) And the assertion fails because the java TestEverythingServer (included in this pr) provides the message "Tool not found: nonexistent_tool" io.modelcontextprotocol.client.AbstractMcpAsyncClientTests.testCallTool() fails on line 232 with the same assertion error that io.modelcontextprotocol.client.AbstractMcpSyncClientTests.testCallTool() expecting isError to be null rather than false error. Signed-off-by: Scott Lewis --- .../transport/UDSClientTransportProvider.java | 201 ++++++++---- .../transport/UDSServerTransportProvider.java | 167 +++++----- ...hannel.java => AbstractSocketChannel.java} | 291 +++++++----------- .../util/ClientNonBlockingSocketChannel.java | 66 ---- .../util/ClientSocketChannel.java | 100 ++++++ .../Inet4ClientNonBlockingSocketChannel.java | 35 --- .../Inet4ServerNonBlockingSocketChannel.java | 34 -- .../Inet6ClientNonBlockingSocketChannel.java | 35 --- .../Inet6ServerNonBlockingSocketChannel.java | 34 -- ...ketChannel.java => ServSocketChannel.java} | 22 +- .../UDSClientNonBlockingSocketChannel.java | 33 -- .../util/UDSClientSocketChannel.java | 33 ++ ...annel.java => UDSServerSocketChannel.java} | 10 +- .../client/UDSMcpAsyncClientTests.java | 48 ++- .../client/UDSMcpSyncClientTests.java | 41 ++- .../server/EverythingServer.java | 131 -------- .../server/TestEverythingServer.java | 149 +++++++++ .../server/UDSMcpAsyncServerTests.java | 29 +- .../server/UDSMcpSyncServerTests.java | 29 +- 19 files changed, 720 insertions(+), 768 deletions(-) rename mcp/src/main/java/io/modelcontextprotocol/util/{NonBlockingSocketChannel.java => AbstractSocketChannel.java} (50%) delete mode 100644 mcp/src/main/java/io/modelcontextprotocol/util/ClientNonBlockingSocketChannel.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/util/ClientSocketChannel.java delete mode 100644 mcp/src/main/java/io/modelcontextprotocol/util/Inet4ClientNonBlockingSocketChannel.java delete mode 100644 mcp/src/main/java/io/modelcontextprotocol/util/Inet4ServerNonBlockingSocketChannel.java delete mode 100644 mcp/src/main/java/io/modelcontextprotocol/util/Inet6ClientNonBlockingSocketChannel.java delete mode 100644 mcp/src/main/java/io/modelcontextprotocol/util/Inet6ServerNonBlockingSocketChannel.java rename mcp/src/main/java/io/modelcontextprotocol/util/{ServerNonBlockingSocketChannel.java => ServSocketChannel.java} (76%) delete mode 100644 mcp/src/main/java/io/modelcontextprotocol/util/UDSClientNonBlockingSocketChannel.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/util/UDSClientSocketChannel.java rename mcp/src/main/java/io/modelcontextprotocol/util/{UDSServerNonBlockingSocketChannel.java => UDSServerSocketChannel.java} (60%) delete mode 100644 mcp/src/test/java/io/modelcontextprotocol/server/EverythingServer.java create mode 100644 mcp/src/test/java/io/modelcontextprotocol/server/TestEverythingServer.java diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/UDSClientTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/UDSClientTransportProvider.java index 28bb1fe8a..7d9dc3f80 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/UDSClientTransportProvider.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/UDSClientTransportProvider.java @@ -2,8 +2,10 @@ import java.io.IOException; import java.net.UnixDomainSocketAddress; +import java.nio.channels.SelectionKey; import java.time.Duration; import java.util.concurrent.Executors; +import java.util.function.Consumer; import java.util.function.Function; import org.slf4j.Logger; @@ -16,7 +18,7 @@ import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage; import io.modelcontextprotocol.util.Assert; -import io.modelcontextprotocol.util.UDSClientNonBlockingSocketChannel; +import io.modelcontextprotocol.util.UDSClientSocketChannel; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.core.publisher.Sinks; @@ -25,7 +27,7 @@ public class UDSClientTransportProvider implements McpClientTransport { - private static final Logger logger = LoggerFactory.getLogger(UDSClientTransportProvider.class); + private static final Logger logger = LoggerFactory.getLogger(StdioClientTransport.class); private final Sinks.Many inboundSink; @@ -33,20 +35,56 @@ public class UDSClientTransportProvider implements McpClientTransport { private ObjectMapper objectMapper; - private UDSClientNonBlockingSocketChannel clientChannel; + /** Scheduler for handling outbound messages to the server process */ + private Scheduler outboundScheduler; - private UnixDomainSocketAddress targetAddress; + private final Sinks.Many errorSink; - private Scheduler outboundScheduler; + private UDSClientSocketChannel clientChannel; + + private UnixDomainSocketAddress targetAddress; private volatile boolean isClosing = false; - public UDSClientTransportProvider(UnixDomainSocketAddress targetAddress) throws IOException { + // visible for tests + private Consumer stdErrorHandler = error -> logger.info("STDERR Message received: {}", error); + + public UDSClientTransportProvider(UnixDomainSocketAddress targetAddress) { this(new ObjectMapper(), targetAddress); } - - public UDSClientTransportProvider(ObjectMapper objectMapper, UnixDomainSocketAddress targetAddress) - throws IOException { + + public UDSClientTransportProvider(ObjectMapper objectMapper, UnixDomainSocketAddress targetAddress) { + Assert.notNull(objectMapper, "objectMapper cannot be null"); + this.objectMapper = objectMapper; + Assert.notNull(objectMapper, "targetAddress cannot be null"); + this.targetAddress = targetAddress; + this.inboundSink = Sinks.many().unicast().onBackpressureBuffer(); + this.outboundSink = Sinks.many().unicast().onBackpressureBuffer(); + this.errorSink = Sinks.many().unicast().onBackpressureBuffer(); + try { + this.clientChannel = new UDSClientSocketChannel() { + @Override + protected void handleException(SelectionKey key, Exception e) { + isClosing = true; + super.handleException(key, e); + } + }; + } catch (IOException e) { + throw new RuntimeException(e); + } + this.outboundScheduler = Schedulers.fromExecutorService(Executors.newSingleThreadExecutor(), "outbound"); + } + + /** + * Creates a new StdioClientTransport with the specified parameters and + * ObjectMapper. + * + * @param params The parameters for configuring the server process + * @param objectMapper The ObjectMapper to use for JSON + * serialization/deserialization + */ + public UDSClientTransportProvider(ServerParameters params, ObjectMapper objectMapper) { + Assert.notNull(params, "The params can not be null"); Assert.notNull(objectMapper, "The ObjectMapper can not be null"); this.inboundSink = Sinks.many().unicast().onBackpressureBuffer(); @@ -54,76 +92,115 @@ public UDSClientTransportProvider(ObjectMapper objectMapper, UnixDomainSocketAdd this.objectMapper = objectMapper; - // Start threads + this.errorSink = Sinks.many().unicast().onBackpressureBuffer(); + + // Start thread for outbound this.outboundScheduler = Schedulers.fromExecutorService(Executors.newSingleThreadExecutor(), "outbound"); - this.clientChannel = new UDSClientNonBlockingSocketChannel(); - this.targetAddress = targetAddress; } + /** + * Starts the server process and initializes the message processing streams. + * This method sets up the process with the configured command, arguments, and + * environment, then starts the inbound, outbound, and error processing threads. + * + * @throws RuntimeException if the process fails to start or if the process + * streams are null + */ @Override public Mono connect(Function, Mono> handler) { return Mono.fromRunnable(() -> { handleIncomingMessages(handler); + handleIncomingErrors(); + + // Connect client channel try { - this.clientChannel.connectBlocking(targetAddress, (client) -> { - logger.info("CONNECTED to targetAddress=" + targetAddress); - }, (data) -> { - JSONRPCMessage json = McpSchema.deserializeJsonRpcMessage(this.objectMapper, data); - if (!this.inboundSink.tryEmitNext(json).isSuccess()) { + this.clientChannel.connect(targetAddress, (client) -> { + if (logger.isInfoEnabled()) { + logger.info("UDSClientTransportProvider CONNECTED to targetAddress=" + targetAddress); + } + }, (message) -> { + if (logger.isDebugEnabled()) { + logger.debug("received message=" + message); + } + // Incoming messages processed right here + McpSchema.JSONRPCMessage jsonMessage = McpSchema.deserializeJsonRpcMessage(objectMapper, message); + if (!this.inboundSink.tryEmitNext(jsonMessage).isSuccess()) { if (!isClosing) { - logger.error("Failed to enqueue inbound message: {}", json); + if (logger.isDebugEnabled()) { + logger.error("Failed to enqueue inbound json message: {}", jsonMessage); + } } } }); - } - catch (IOException e) { + } catch (IOException e) { this.clientChannel.close(); throw new RuntimeException( "Connect to address=" + targetAddress + " failed message: " + e.getMessage()); } + startOutboundProcessing(); + }).subscribeOn(Schedulers.boundedElastic()); } + /** + * Sets the handler for processing transport-level errors. + * + *

+ * The provided handler will be called when errors occur during transport + * operations, such as connection failures or protocol violations. + *

+ * + * @param errorHandler a consumer that processes error messages + */ + public void setStdErrorHandler(Consumer errorHandler) { + this.stdErrorHandler = errorHandler; + } + private void handleIncomingMessages(Function, Mono> inboundMessageHandler) { - this.inboundSink.asFlux() - .flatMap(message -> Mono.just(message) - .transform(inboundMessageHandler) - .contextWrite(ctx -> ctx.put("observation", "myObservation"))) - .subscribe(); + this.inboundSink.asFlux().flatMap(message -> Mono.just(message).transform(inboundMessageHandler) + .contextWrite(ctx -> ctx.put("observation", "myObservation"))).subscribe(); + } + + private void handleIncomingErrors() { + this.errorSink.asFlux().subscribe(e -> { + this.stdErrorHandler.accept(e); + }); } @Override public Mono sendMessage(JSONRPCMessage message) { - if (this.outboundSink.tryEmitNext(message).isSuccess()) { - // TODO: essentially we could reschedule ourselves in some time and make - // another attempt with the already read data but pause reading until - // success - // In this approach we delegate the retry and the backpressure onto the - // caller. This might be enough for most cases. - return Mono.empty(); - } - else { - return Mono.error(new RuntimeException("Failed to enqueue message")); - } + outboundSink.emitNext(message, (signalType, emitResult) -> { + // Allow retry + return true; + }); + return Mono.empty(); } + /** + * Starts the outbound processing thread that writes JSON-RPC messages to the + * process's output stream. Messages are serialized to JSON and written with a + * newline delimiter. + */ private void startOutboundProcessing() { this.handleOutbound(messages -> messages - // this bit is important since writes come from user threads, and we - // want to ensure that the actual writing happens on a dedicated thread - .publishOn(outboundScheduler) - .handle((message, s) -> { - if (message != null && !isClosing) { - try { - this.clientChannel.writeMessageBlocking(objectMapper.writeValueAsString(message)); - s.next(message); - } - catch (IOException e) { - s.error(new RuntimeException(e)); + // this bit is important since writes come from user threads, and we + // want to ensure that the actual writing happens on a dedicated thread + .publishOn(outboundScheduler).handle((message, sink) -> { + if (message != null && !isClosing) { + try { + clientChannel.writeMessage(objectMapper.writeValueAsString(message)); + sink.next(message); + } catch (IOException e) { + if (!isClosing) { + logger.error("Error writing message", e); + sink.error(new RuntimeException(e)); + } else { + logger.debug("Stream closed during shutdown", e); + } + } } - } - })); + })); } protected void handleOutbound(Function, Flux> outboundConsumer) { @@ -140,9 +217,10 @@ protected void handleOutbound(Function, Flux closeGracefully() { // First complete all sinks to stop accepting new messages inboundSink.tryEmitComplete(); outboundSink.tryEmitComplete(); + errorSink.tryEmitComplete(); + // Give a short time for any pending messages to be processed return Mono.delay(Duration.ofMillis(100)).then(); - })).then(Mono.defer(() -> { - // Close our clientChannel - if (this.clientChannel != null) { - this.clientChannel.close(); - this.clientChannel = null; - } - return Mono.empty(); })).then(Mono.fromRunnable(() -> { try { - // The Threads are blocked on readLine so disposeGracefully would not - // interrupt them, therefore we issue an async hard dispose. outboundScheduler.dispose(); - logger.debug("Graceful shutdown completed"); - } - catch (Exception e) { + } catch (Exception e) { logger.error("Error during graceful shutdown", e); } })).then().subscribeOn(Schedulers.boundedElastic()); } + public Sinks.Many getErrorSink() { + return this.errorSink; + } + @Override public T unmarshalFrom(Object data, TypeReference typeRef) { return this.objectMapper.convertValue(data, typeRef); diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/UDSServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/UDSServerTransportProvider.java index 977f5b90e..2d4908cb3 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/transport/UDSServerTransportProvider.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/UDSServerTransportProvider.java @@ -2,6 +2,8 @@ import java.io.IOException; import java.net.UnixDomainSocketAddress; +import java.nio.channels.SelectionKey; +import java.util.List; import java.util.concurrent.Executors; import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Function; @@ -18,8 +20,9 @@ import io.modelcontextprotocol.spec.McpServerSession; import io.modelcontextprotocol.spec.McpServerTransport; import io.modelcontextprotocol.spec.McpServerTransportProvider; +import io.modelcontextprotocol.spec.ProtocolVersions; import io.modelcontextprotocol.util.Assert; -import io.modelcontextprotocol.util.UDSServerNonBlockingSocketChannel; +import io.modelcontextprotocol.util.UDSServerSocketChannel; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.core.publisher.Sinks; @@ -28,68 +31,54 @@ public class UDSServerTransportProvider implements McpServerTransportProvider { - private static final Logger logger = LoggerFactory.getLogger(UDSServerTransportProvider.class); + private static final Logger logger = LoggerFactory.getLogger(StdioServerTransportProvider.class); private final ObjectMapper objectMapper; + private UDSMcpSessionTransport transport; + private McpServerSession session; private final AtomicBoolean isClosing = new AtomicBoolean(false); private final Sinks.One inboundReady = Sinks.one(); - private UDSServerNonBlockingSocketChannel serverSocketChannel; - - private UnixDomainSocketAddress address; + private final Sinks.One outboundReady = Sinks.one(); - private UDSMcpSessionTransport transport; + private UnixDomainSocketAddress targetAddress; + /** + * Creates a new UDSServerTransportProvider with a default ObjectMapper + * + * @param unixSocketAddress the UDS socket address to bind to. Must not be null. + */ public UDSServerTransportProvider(UnixDomainSocketAddress unixSocketAddress) { this(new ObjectMapper(), unixSocketAddress); } + /** + * Creates a new UDSServerTransportProvider with the specified ObjectMapper + * + * @param objectMapper The ObjectMapper to use for JSON + * serialization/deserialization + */ public UDSServerTransportProvider(ObjectMapper objectMapper, UnixDomainSocketAddress unixSocketAddress) { - Assert.notNull(objectMapper, "The ObjectMapper can not be null"); + Assert.notNull(objectMapper, "objectMapper cannot be null"); this.objectMapper = objectMapper; - this.address = unixSocketAddress; + Assert.notNull(unixSocketAddress, "unixSocketAddress cannot be null"); + this.targetAddress = unixSocketAddress; + } + + @Override + public List protocolVersions() { + return List.of(ProtocolVersions.MCP_2024_11_05); } @Override public void setSessionFactory(McpServerSession.Factory sessionFactory) { this.transport = new UDSMcpSessionTransport(); this.session = sessionFactory.create(transport); - this.transport.handleIncomingMessages(); - if (this.transport.isStarted.compareAndSet(false, true)) { - inboundReady.tryEmitValue(null); - } - // Also start listening for accept - try { - this.serverSocketChannel = new UDSServerNonBlockingSocketChannel(); - this.serverSocketChannel.start(this.address, (clientChannel) -> { - if (logger.isDebugEnabled()) { - logger.debug("Accepted connect from clientChannel=" + clientChannel); - } - // Start outbound processing now that the clientChannel has been accepted - this.transport.startOutboundProcessing(); - }, (dataLine) -> { - String message = (String) dataLine; - if (logger.isDebugEnabled()) { - logger.debug("Received message line=" + message); - } - try { - this.transport - .handleMessage(McpSchema.deserializeJsonRpcMessage(this.objectMapper, message.trim())); - } - catch (IOException e) { - this.serverSocketChannel.close(); - } - }); - } - catch (IOException e) { - // If this happens then we are doomed - this.serverSocketChannel.close(); - throw new RuntimeException("accepterNonBlockSocketChannel could not be started"); - } + this.transport.initProcessing(); } @Override @@ -98,7 +87,7 @@ public Mono notifyClients(String method, Object params) { return Mono.error(new McpError("No session to close")); } return this.session.sendNotification(method, params) - .doOnError(e -> logger.error("Failed to send notification: {}", e.getMessage())); + .doOnError(e -> logger.error("Failed to send notification: {}", e.getMessage())); } @Override @@ -110,7 +99,7 @@ public Mono closeGracefully() { } /** - * Implementation of McpServerTransport for the stdio session. + * Implementation of McpServerTransport for the uds session. */ private class UDSMcpSessionTransport implements McpServerTransport { @@ -118,44 +107,43 @@ private class UDSMcpSessionTransport implements McpServerTransport { private final Sinks.Many outboundSink; - private final AtomicBoolean isStarted = new AtomicBoolean(false); - /** Scheduler for handling outbound messages */ private Scheduler outboundScheduler; - private final Sinks.One outboundReady = Sinks.one(); + private final AtomicBoolean isStarted = new AtomicBoolean(false); + + private final UDSServerSocketChannel serverSocketChannel; public UDSMcpSessionTransport() { - this.inboundSink = Sinks.many().unicast().onBackpressureBuffer(); this.outboundSink = Sinks.many().unicast().onBackpressureBuffer(); - this.outboundScheduler = Schedulers.fromExecutorService(Executors.newSingleThreadExecutor(), "uds-outbound"); - } - - public void handleMessage(McpSchema.JSONRPCMessage json) throws IOException { try { - if (!this.inboundSink.tryEmitNext(json).isSuccess()) { - throw new Exception("Failed to enqueue message"); - } - } - catch (Exception e) { - logIfNotClosing("Error processing inbound message", e); - throw new IOException("Error in processing inbound message", e); + this.serverSocketChannel = new UDSServerSocketChannel() { + @Override + protected void handleException(SelectionKey key, Exception e) { + isClosing.set(true); + if (session != null) { + session.close(); + session = null; + } + inboundSink.tryEmitComplete(); + } + }; + } catch (IOException e) { + throw new RuntimeException(e); } } @Override public Mono sendMessage(McpSchema.JSONRPCMessage message) { - return Mono.zip(inboundReady.asMono(), outboundReady.asMono()).then(Mono.defer(() -> { - if (outboundSink.tryEmitNext(message).isSuccess()) { - return Mono.empty(); - } - else { - return Mono.error(new RuntimeException("Failed to enqueue message")); - } + outboundSink.emitNext(message, (signalType, emitResult) -> { + // Allow retry + return true; + }); + return Mono.empty(); })); } @@ -179,17 +167,52 @@ public void close() { logger.debug("Session transport closed"); } + private void initProcessing() { + handleIncomingMessages(); + startInboundProcessing(); + startOutboundProcessing(); + + inboundReady.tryEmitValue(null); + outboundReady.tryEmitValue(null); + } + private void handleIncomingMessages() { this.inboundSink.asFlux().flatMap(message -> session.handle(message)).doOnTerminate(() -> { - // The outbound processing will dispose its scheduler upon completion this.outboundSink.tryEmitComplete(); - // this.inboundScheduler.dispose(); }).subscribe(); } /** - * Starts the outbound processing thread that writes JSON-RPC messages to stdout. - * Messages are serialized to JSON and written with a newline delimiter. + * Starts the inbound processing thread that reads JSON-RPC messages from stdin. + * Messages are deserialized and passed to the session for handling. + */ + private void startInboundProcessing() { + if (isStarted.compareAndSet(false, true)) { + try { + this.serverSocketChannel.start(targetAddress, (clientChannel) -> { + if (logger.isDebugEnabled()) { + logger.debug("Accepted connect from clientChannel=" + clientChannel); + } + }, (message) -> { + if (logger.isDebugEnabled()) { + logger.debug("Received message=" + message); + } + // Incoming messages processed right here + McpSchema.JSONRPCMessage jsonMessage = McpSchema.deserializeJsonRpcMessage(objectMapper, + message); + if (!this.inboundSink.tryEmitNext(jsonMessage).isSuccess()) { + throw new IOException("Error adding jsonMessge to inboundSink"); + } + }); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + } + + /** + * Starts the outbound processing thread that writes JSON-RPC messages to + * stdout. Messages are serialized to JSON and written with a newline delimiter. */ private void startOutboundProcessing() { Function, Flux> outboundConsumer = messages -> messages // @formatter:off @@ -198,7 +221,7 @@ private void startOutboundProcessing() { .handle((message, sink) -> { if (message != null && !isClosing.get()) { try { - serverSocketChannel.writeMessageBlocking(objectMapper.writeValueAsString(message)); + serverSocketChannel.writeMessage(objectMapper.writeValueAsString(message)); sink.next(message); } catch (IOException e) { @@ -231,12 +254,6 @@ else if (isClosing.get()) { outboundConsumer.apply(outboundSink.asFlux()).subscribe(); } // @formatter:on - private void logIfNotClosing(String message, Exception e) { - if (!isClosing.get()) { - logger.error(message, e); - } - } - } } diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/NonBlockingSocketChannel.java b/mcp/src/main/java/io/modelcontextprotocol/util/AbstractSocketChannel.java similarity index 50% rename from mcp/src/main/java/io/modelcontextprotocol/util/NonBlockingSocketChannel.java rename to mcp/src/main/java/io/modelcontextprotocol/util/AbstractSocketChannel.java index de3fe8ba7..1b25dff48 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/util/NonBlockingSocketChannel.java +++ b/mcp/src/main/java/io/modelcontextprotocol/util/AbstractSocketChannel.java @@ -19,17 +19,43 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public abstract class NonBlockingSocketChannel { +public abstract class AbstractSocketChannel { - private static final Logger logger = LoggerFactory.getLogger(NonBlockingSocketChannel.class); + private static final Logger logger = LoggerFactory.getLogger(AbstractSocketChannel.class); public static final int DEFAULT_INBUFFER_SIZE = 1024; - protected static String MESSAGE_DELIMITER = "\n"; + public static String DEFAULT_MESSAGE_DELIMITER = "\n"; - protected static int BLOCKING_WRITE_TIMEOUT = 5000; + protected String messageDelimiter = DEFAULT_MESSAGE_DELIMITER; - protected static int BLOCKING_CONNECT_TIMEOUT = 10000; + protected void setMessageDelimiter(String delim) { + this.messageDelimiter = delim; + } + + public static int DEFAULT_WRITE_TIMEOUT = 5000; // ms + + protected int writeTimeout = DEFAULT_WRITE_TIMEOUT; + + protected void setWriteTimeout(int timeout) { + this.writeTimeout = timeout; + } + + public static int DEFAULT_CONNECT_TIMEOUT = 10000; // ms + + protected int connectTimeout = DEFAULT_CONNECT_TIMEOUT; + + protected void setConnectTimeout(int timeout) { + this.connectTimeout = timeout; + } + + public static int DEFAULT_TERMINATION_TIMEOUT = 2000; // ms + + protected int terminationTimeout = DEFAULT_TERMINATION_TIMEOUT; + + protected void setTerminationTimeout(int timeout) { + this.terminationTimeout = timeout; + } protected final Selector selector; @@ -37,6 +63,8 @@ public abstract class NonBlockingSocketChannel { protected final ExecutorService executor; + private final Object writeLock = new Object(); + @FunctionalInterface public interface IOConsumer { @@ -52,22 +80,22 @@ protected class AttachedIO { } - public NonBlockingSocketChannel(Selector selector, int incomingBufferSize, ExecutorService executor) { + public AbstractSocketChannel(Selector selector, int incomingBufferSize, ExecutorService executor) { Assert.notNull(selector, "Selector must not be null"); this.selector = selector; this.inBuffer = ByteBuffer.allocate(incomingBufferSize); this.executor = (executor == null) ? Executors.newSingleThreadExecutor() : executor; } - public NonBlockingSocketChannel(Selector selector, int incomingBufferSize) { + public AbstractSocketChannel(Selector selector, int incomingBufferSize) { this(selector, incomingBufferSize, null); } - public NonBlockingSocketChannel(Selector selector) { + public AbstractSocketChannel(Selector selector) { this(selector, DEFAULT_INBUFFER_SIZE); } - public NonBlockingSocketChannel() throws IOException { + public AbstractSocketChannel() throws IOException { this(Selector.open()); } @@ -77,28 +105,25 @@ protected Runnable getRunnableForProcessing(IOConsumer acceptHand SelectionKey key = null; try { while (true) { - this.selector.select(); + int count = this.selector.select(); + debug("Select returned count=%s", count); Set selectedKeys = selector.selectedKeys(); Iterator iter = selectedKeys.iterator(); while (iter.hasNext()) { key = iter.next(); if (key.isConnectable()) { handleConnectable(key, connectHandler); - } - else if (key.isAcceptable()) { + } else if (key.isAcceptable()) { handleAcceptable(key, acceptHandler); - } - else if (key.isReadable()) { + } else if (key.isReadable()) { handleReadable(key, readHandler); - } - else if (key.isWritable()) { + } else if (key.isWritable()) { handleWritable(key); } iter.remove(); } } - } - catch (Exception e) { + } catch (Exception e) { handleException(key, e); } }; @@ -113,48 +138,40 @@ protected void start(IOConsumer acceptHandler, IOConsumer connectHandler) throws IOException { SocketChannel client = (SocketChannel) key.channel(); - Object lock = client.blockingLock(); - if (logger.isDebugEnabled()) { - logger.debug("handleConnectable client=" + client.getRemoteAddress()); + debug("client=%s", client); + client.configureBlocking(false); + client.register(this.selector, SelectionKey.OP_READ, new AttachedIO()); + if (client.isConnectionPending()) { + client.finishConnect(); + debug("connected client=%s", client); } - synchronized (lock) { - client.configureBlocking(false); - client.register(this.selector, SelectionKey.OP_READ, new AttachedIO()); - if (client.isConnectionPending()) { - client.finishConnect(); - if (logger.isDebugEnabled()) { - logger.debug("handleConnectable FINISHED"); - } - } - if (connectHandler != null) { - connectHandler.apply(client); - } + if (connectHandler != null) { + connectHandler.apply(client); } } protected void handleAcceptable(SelectionKey key, IOConsumer acceptHandler) throws IOException { ServerSocketChannel serverSocket = (ServerSocketChannel) key.channel(); SocketChannel client = serverSocket.accept(); - Object lock = client.blockingLock(); - if (logger.isDebugEnabled()) { - logger.debug("handleAcceptable client=" + client); + debug("client=%s", client); + client.configureBlocking(false); + client.register(this.selector, SelectionKey.OP_READ, new AttachedIO()); + configureAcceptSocketChannel(client); + if (client.isConnectionPending()) { + client.finishConnect(); + debug("accepted client=%s", client); } - synchronized (lock) { - client.configureBlocking(false); - client.register(this.selector, SelectionKey.OP_READ, new AttachedIO()); - configureAcceptSocketChannel(client); - if (client.isConnectionPending()) { - client.finishConnect(); - if (logger.isDebugEnabled()) { - logger.debug("handleAcceptable FINISHED"); - } - } - if (acceptHandler != null) { - acceptHandler.apply(client); - } + if (acceptHandler != null) { + acceptHandler.apply(client); } } @@ -172,42 +189,35 @@ protected AttachedIO getAttachedIO(SelectionKey key) throws IOException { protected void handleReadable(SelectionKey key, IOConsumer readHandler) throws IOException { SocketChannel client = (SocketChannel) key.channel(); - Object lock = client.blockingLock(); AttachedIO io = getAttachedIO(key); - if (logger.isDebugEnabled()) { - logger.debug("handleReadable client=" + client); + debug("read client=%s", client); + // read + int r = client.read(this.inBuffer); + // Check if we should expect any more reads + if (r == -1) { + throw new IOException("Channel read reached end of stream"); } - synchronized (lock) { - // non-blocking read here - int r = client.read(this.inBuffer); - // Check if we should expect any more reads - if (r == -1) { - throw new IOException("Channel read reached end of stream"); - } - this.inBuffer.flip(); - String partial = new String(this.inBuffer.array(), 0, r, StandardCharsets.UTF_8); - // If there is are previous partial, then get the io.reading string Buffer - StringBuffer sb = (io.reading != null) ? (StringBuffer) io.reading : new StringBuffer(); - // And append the just read partial to the string buffer - sb.append(partial); - if (partial.endsWith(MESSAGE_DELIMITER)) { - // Get the entire message from the string buffer - String message = sb.toString(); - // Set the io.reading value to null as we are done with this message - io.reading = null; - if (logger.isDebugEnabled()) { - logger.debug("handleReadable COMPLETE msg=" + message); - } - if (readHandler != null) { - readHandler.apply(message); - } - } - else { - io.reading = sb; - if (logger.isDebugEnabled()) { - logger.debug("handleReadable PARTIAL msg=" + partial); + this.inBuffer.flip(); + String partial = new String(this.inBuffer.array(), 0, r, StandardCharsets.UTF_8); + // If there is previous partial, get the io.reading string Buffer + StringBuffer sb = (io.reading != null) ? (StringBuffer) io.reading : new StringBuffer(); + // append the just read partial to the existing or new string buffer + sb.append(partial); + if (partial.endsWith(messageDelimiter)) { + // Get the entire message from the string buffer + String message = sb.toString(); + // Set the io.reading value to null as we are done with this message + io.reading = null; + debug("read client=%s msg=", client, message); + if (readHandler != null) { + String[] messages = splitMessage(message); + for (int i = 0; i < messages.length; i++) { + readHandler.apply(messages[i]); } } + } else { + io.reading = sb; + debug("read partial=%s", partial); } // Clear inbuffer for next read this.inBuffer.clear(); @@ -217,12 +227,9 @@ protected void handleWritable(SelectionKey key) throws IOException { ByteBuffer buf = getAttachedIO(key).writing; SocketChannel client = (SocketChannel) key.channel(); if (buf != null) { - doWrite(key, client, buf, (lock) -> { - synchronized (lock) { - if (logger.isDebugEnabled()) { - logger.debug("handleWritable NOTIFY client=" + client); - } - lock.notifyAll(); + doWrite(key, client, buf, (o) -> { + synchronized (writeLock) { + writeLock.notifyAll(); } }); } @@ -231,9 +238,6 @@ protected void handleWritable(SelectionKey key) throws IOException { protected void doWrite(SocketChannel client, String message, IOConsumer writeHandler) throws IOException { Assert.notNull(client, "Client must not be null"); Assert.notNull(message, "Message must not be null"); - if (logger.isDebugEnabled()) { - logger.debug("doWrite msg=" + message); - } doWrite(client.keyFor(this.selector), client, ByteBuffer.wrap(message.getBytes(StandardCharsets.UTF_8)), writeHandler); } @@ -241,24 +245,20 @@ protected void doWrite(SocketChannel client, String message, IOConsumer protected void doWrite(SelectionKey key, SocketChannel client, ByteBuffer buf, IOConsumer writeHandler) throws IOException { AttachedIO io = (AttachedIO) key.attachment(); - Object lock = client.blockingLock(); - synchronized (lock) { + synchronized (writeLock) { int written = client.write(buf); if (buf.hasRemaining()) { - if (logger.isDebugEnabled()) { - logger.debug("doWrite PARTIAL written=" + written + " remaining=" + buf.remaining()); - } + debug("doWrite written=%s, remaining=%s", written, buf.remaining()); io.writing = buf.slice(); key.interestOpsOr(SelectionKey.OP_WRITE); - } - else { + } else { if (logger.isDebugEnabled()) { - logger.debug("doWrite COMPLETED msg=" + new String(buf.array(), 0, written)); + logger.debug("doWrite message=%s", new String(buf.array(), 0, written)); } io.writing = null; key.interestOps(SelectionKey.OP_READ); if (writeHandler != null) { - writeHandler.apply(lock); + writeHandler.apply(null); } } } @@ -266,14 +266,11 @@ protected void doWrite(SelectionKey key, SocketChannel client, ByteBuffer buf, I protected void executorShutdown() { if (!this.executor.isShutdown()) { - if (logger.isDebugEnabled()) { - logger.debug("executorShutdown"); - } + debug("shutdown"); try { - this.executor.awaitTermination(2000, TimeUnit.MILLISECONDS); + this.executor.awaitTermination(this.terminationTimeout, TimeUnit.MILLISECONDS); this.executor.shutdown(); - } - catch (InterruptedException e) { + } catch (InterruptedException e) { if (logger.isDebugEnabled()) { logger.debug("Exception in executor awaitTermination", e); } @@ -283,19 +280,14 @@ protected void executorShutdown() { protected void hardCloseClient(SocketChannel client, IOConsumer closeHandler) { if (client != null) { - Object lock = client.blockingLock(); - if (logger.isDebugEnabled()) { - logger.debug("hardCloseClient client=" + client); - } - synchronized (lock) { + debug("hardClose client=%s", client); + synchronized (writeLock) { try { if (closeHandler != null) { closeHandler.apply(client); } client.close(); - client = null; - } - catch (IOException e) { + } catch (IOException e) { if (logger.isDebugEnabled()) { logger.debug("hardClose client socketchannel.close exception", e); } @@ -305,23 +297,19 @@ protected void hardCloseClient(SocketChannel client, IOConsumer c } } - protected void writeBlocking(SocketChannel client, String message) throws IOException { + protected void writeMessageToChannel(SocketChannel client, String message) throws IOException { Objects.requireNonNull(client, "Client must not be null"); Objects.requireNonNull(message, "Message must not be null"); - // Escape any embedded newlines in the JSON message, and add newline - String outputMessage = message.replace("\r\n", "\\n") - .replace("\n", "\\n") - .replace("\r", "\\n") - .concat(MESSAGE_DELIMITER); - Object lock = client.blockingLock(); - if (logger.isDebugEnabled()) { - logger.debug("writeBlocking msg=" + outputMessage); - } - synchronized (lock) { + // Escape any embedded newlines in the JSON message + String outputMessage = message.replace("\r\n", "\\n").replace("\n", "\\n").replace("\r", "\\n") + // add message delimiter + .concat(DEFAULT_MESSAGE_DELIMITER); + debug("writing msg=%s", outputMessage); + synchronized (writeLock) { // do the non blocking write in thread while holding lock. doWrite(client, outputMessage, null); ByteBuffer bufRemaining = null; - long waitTime = System.currentTimeMillis() + BLOCKING_WRITE_TIMEOUT; + long waitTime = System.currentTimeMillis() + this.writeTimeout; while (waitTime - System.currentTimeMillis() > 0) { // Before releasing lock, check for writing buffer remaining bufRemaining = getAttachedIO(client.keyFor(this.selector)).writing; @@ -331,13 +319,9 @@ protected void writeBlocking(SocketChannel client, String message) throws IOExce } // If write is *not* completed, then wait timeout /10 try { - if (logger.isDebugEnabled()) { - logger - .debug("writeBlocking WAITING=" + String.valueOf(waitTime / 10) + " msg=" + outputMessage); - } - lock.wait(waitTime / 10); - } - catch (InterruptedException e) { + debug("writeBlocking WAITING(ms)=%s msg=%s", String.valueOf(waitTime / 10), outputMessage); + writeLock.wait(waitTime / 10); + } catch (InterruptedException e) { throw new InterruptedIOException("write message wait interrupted"); } } @@ -345,9 +329,7 @@ protected void writeBlocking(SocketChannel client, String message) throws IOExce throw new IOException("Write not completed. Non empty buffer remaining after timeout"); } } - if (logger.isDebugEnabled()) { - logger.debug("writeBlocking COMPLETED msg=" + outputMessage); - } + debug("writing done msg=%s", outputMessage); } protected void configureConnectSocketChannel(SocketChannel client, SocketAddress connectAddress) @@ -355,43 +337,8 @@ protected void configureConnectSocketChannel(SocketChannel client, SocketAddress // Subclasses may override } - protected SocketChannel connectBlocking(SocketChannel client, SocketAddress address, - IOConsumer connectHandler, IOConsumer readHandler) throws IOException { - Object lock = client.blockingLock(); - if (logger.isDebugEnabled()) { - logger.debug("connectBlocking CONNECTING targetAddress=" + address); - } - synchronized (lock) { - client.configureBlocking(false); - client.register(selector, SelectionKey.OP_CONNECT); - configureConnectSocketChannel(client, address); - // Start the read thread before connect - // No/null accept handler for clients - start(null, (c) -> { - synchronized (lock) { - if (connectHandler != null) { - connectHandler.apply(c); - } - lock.notifyAll(); - } - }, readHandler); - - client.connect(address); - - try { - if (logger.isDebugEnabled()) { - logger.debug("connectBlocking WAITING targetAddress=" + address); - } - lock.wait(BLOCKING_CONNECT_TIMEOUT); - } - catch (InterruptedException e) { - throw new IOException("Connect to address=" + address + " timed out"); - } - if (logger.isDebugEnabled()) { - logger.debug("connectBlocking CONNECTED client=" + client.getLocalAddress() + " connecting=" + address); - } - return client; - } + protected String[] splitMessage(String message) { + return (message == null) ? new String[0] : message.split(messageDelimiter); } } \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/ClientNonBlockingSocketChannel.java b/mcp/src/main/java/io/modelcontextprotocol/util/ClientNonBlockingSocketChannel.java deleted file mode 100644 index 33c7f5f7d..000000000 --- a/mcp/src/main/java/io/modelcontextprotocol/util/ClientNonBlockingSocketChannel.java +++ /dev/null @@ -1,66 +0,0 @@ -package io.modelcontextprotocol.util; - -import java.io.IOException; -import java.net.SocketAddress; -import java.net.StandardProtocolFamily; -import java.nio.channels.SelectionKey; -import java.nio.channels.Selector; -import java.nio.channels.SocketChannel; -import java.util.concurrent.ExecutorService; - -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -public class ClientNonBlockingSocketChannel extends NonBlockingSocketChannel { - - private static final Logger logger = LoggerFactory.getLogger(ClientNonBlockingSocketChannel.class); - - private SocketChannel client; - - public ClientNonBlockingSocketChannel(Selector selector, int incomingBufferSize, ExecutorService executor) { - super(selector, incomingBufferSize, executor); - } - - public ClientNonBlockingSocketChannel() throws IOException { - super(); - } - - public ClientNonBlockingSocketChannel(Selector selector, int incomingBufferSize) { - super(selector, incomingBufferSize); - } - - public ClientNonBlockingSocketChannel(Selector selector) { - super(selector); - } - - public void connectBlocking(StandardProtocolFamily protocol, SocketAddress address, - IOConsumer connectHandler, IOConsumer readHandler) throws IOException { - if (this.client != null) { - throw new IOException("Already connected"); - } - this.client = connectBlocking(SocketChannel.open(protocol), address, connectHandler, readHandler); - } - - @Override - protected void handleException(SelectionKey key, Exception e) { - if (logger.isDebugEnabled()) { - logger.debug("handleException", e); - } - close(); - } - - @Override - public void close() { - hardCloseClient(this.client, (client) -> { - this.client = null; - }); - } - - public void writeMessageBlocking(String message) throws IOException { - if (this.client == null) { - throw new IOException("Cannot write until client connected"); - } - writeBlocking(client, message); - } - -} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/ClientSocketChannel.java b/mcp/src/main/java/io/modelcontextprotocol/util/ClientSocketChannel.java new file mode 100644 index 000000000..e3b7134e1 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/util/ClientSocketChannel.java @@ -0,0 +1,100 @@ +package io.modelcontextprotocol.util; + +import java.io.IOException; +import java.net.SocketAddress; +import java.net.StandardProtocolFamily; +import java.nio.channels.SelectionKey; +import java.nio.channels.Selector; +import java.nio.channels.SocketChannel; +import java.util.concurrent.ExecutorService; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class ClientSocketChannel extends AbstractSocketChannel { + + private static final Logger logger = LoggerFactory.getLogger(ClientSocketChannel.class); + + protected SocketChannel client; + + protected final Object connectLock = new Object(); + + public ClientSocketChannel(Selector selector, int incomingBufferSize, ExecutorService executor) { + super(selector, incomingBufferSize, executor); + } + + public ClientSocketChannel() throws IOException { + super(); + } + + public ClientSocketChannel(Selector selector, int incomingBufferSize) { + super(selector, incomingBufferSize); + } + + public ClientSocketChannel(Selector selector) { + super(selector); + } + + protected SocketChannel doConnect(SocketChannel client, SocketAddress address, + IOConsumer connectHandler, IOConsumer readHandler) throws IOException { + debug("connect targetAddress=%s", address); + client.configureBlocking(false); + client.register(selector, SelectionKey.OP_CONNECT); + configureConnectSocketChannel(client, address); + // Start the read thread before connect + // No/null accept handler for clients + start(null, (c) -> { + synchronized (connectLock) { + if (connectHandler != null) { + connectHandler.apply(c); + } + connectLock.notifyAll(); + } + }, readHandler); + + client.connect(address); + try { + debug("connect targetAddress=%s", address); + synchronized (connectLock) { + connectLock.wait(this.connectTimeout); + } + } + catch (InterruptedException e) { + throw new IOException("Connect to address=" + address + " timed out after " + String.valueOf(this.connectTimeout)+ "ms" ); + } + debug("connected client=%s", client); + return client; + } + + + public void connect(StandardProtocolFamily protocol, SocketAddress address, + IOConsumer connectHandler, IOConsumer readHandler) throws IOException { + if (this.client != null) { + throw new IOException("Already connected"); + } + this.client = doConnect(SocketChannel.open(protocol), address, connectHandler, readHandler); + } + + @Override + protected void handleException(SelectionKey key, Exception e) { + if (logger.isDebugEnabled()) { + logger.debug("handleException", e); + } + close(); + } + + @Override + public void close() { + hardCloseClient(this.client, (client) -> { + this.client = null; + }); + } + + public void writeMessage(String message) throws IOException { + if (this.client == null) { + throw new IOException("Cannot write until client connected"); + } + writeMessageToChannel(client, message); + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/Inet4ClientNonBlockingSocketChannel.java b/mcp/src/main/java/io/modelcontextprotocol/util/Inet4ClientNonBlockingSocketChannel.java deleted file mode 100644 index b1186e3cd..000000000 --- a/mcp/src/main/java/io/modelcontextprotocol/util/Inet4ClientNonBlockingSocketChannel.java +++ /dev/null @@ -1,35 +0,0 @@ -package io.modelcontextprotocol.util; - -import java.io.IOException; -import java.net.Inet4Address; -import java.net.InetSocketAddress; -import java.net.StandardProtocolFamily; -import java.nio.channels.Selector; -import java.nio.channels.SocketChannel; -import java.util.concurrent.ExecutorService; - -public class Inet4ClientNonBlockingSocketChannel extends ClientNonBlockingSocketChannel { - - public Inet4ClientNonBlockingSocketChannel() throws IOException { - super(); - } - - public Inet4ClientNonBlockingSocketChannel(Selector selector, int incomingBufferSize, ExecutorService executor) { - super(selector, incomingBufferSize, executor); - } - - public Inet4ClientNonBlockingSocketChannel(Selector selector, int incomingBufferSize) { - super(selector, incomingBufferSize); - } - - public Inet4ClientNonBlockingSocketChannel(Selector selector) { - super(selector); - } - - public void connectBlocking(Inet4Address address, int port, IOConsumer connectHandler, - IOConsumer readHandler) throws IOException { - super.connectBlocking(StandardProtocolFamily.INET, new InetSocketAddress(address, port), connectHandler, - readHandler); - } - -} diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/Inet4ServerNonBlockingSocketChannel.java b/mcp/src/main/java/io/modelcontextprotocol/util/Inet4ServerNonBlockingSocketChannel.java deleted file mode 100644 index a4b9c61f8..000000000 --- a/mcp/src/main/java/io/modelcontextprotocol/util/Inet4ServerNonBlockingSocketChannel.java +++ /dev/null @@ -1,34 +0,0 @@ -package io.modelcontextprotocol.util; - -import java.io.IOException; -import java.net.Inet4Address; -import java.net.InetSocketAddress; -import java.net.StandardProtocolFamily; -import java.nio.channels.Selector; -import java.nio.channels.SocketChannel; -import java.util.concurrent.ExecutorService; - -public class Inet4ServerNonBlockingSocketChannel extends ServerNonBlockingSocketChannel { - - public Inet4ServerNonBlockingSocketChannel() throws IOException { - super(); - } - - public Inet4ServerNonBlockingSocketChannel(Selector selector, int incomingBufferSize, ExecutorService executor) { - super(selector, incomingBufferSize, executor); - } - - public Inet4ServerNonBlockingSocketChannel(Selector selector, int incomingBufferSize) { - super(selector, incomingBufferSize); - } - - public Inet4ServerNonBlockingSocketChannel(Selector selector) { - super(selector); - } - - public void start(Inet4Address address, int port, IOConsumer acceptHandler, - IOConsumer readHandler) throws IOException { - super.start(StandardProtocolFamily.INET, new InetSocketAddress(address, port), acceptHandler, readHandler); - } - -} diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/Inet6ClientNonBlockingSocketChannel.java b/mcp/src/main/java/io/modelcontextprotocol/util/Inet6ClientNonBlockingSocketChannel.java deleted file mode 100644 index 9af484858..000000000 --- a/mcp/src/main/java/io/modelcontextprotocol/util/Inet6ClientNonBlockingSocketChannel.java +++ /dev/null @@ -1,35 +0,0 @@ -package io.modelcontextprotocol.util; - -import java.io.IOException; -import java.net.Inet6Address; -import java.net.InetSocketAddress; -import java.net.StandardProtocolFamily; -import java.nio.channels.Selector; -import java.nio.channels.SocketChannel; -import java.util.concurrent.ExecutorService; - -public class Inet6ClientNonBlockingSocketChannel extends ClientNonBlockingSocketChannel { - - public Inet6ClientNonBlockingSocketChannel() throws IOException { - super(); - } - - public Inet6ClientNonBlockingSocketChannel(Selector selector, int incomingBufferSize, ExecutorService executor) { - super(selector, incomingBufferSize, executor); - } - - public Inet6ClientNonBlockingSocketChannel(Selector selector, int incomingBufferSize) { - super(selector, incomingBufferSize); - } - - public Inet6ClientNonBlockingSocketChannel(Selector selector) { - super(selector); - } - - public void connectBlocking(Inet6Address address, int port, IOConsumer connectHandler, - IOConsumer readHandler) throws IOException { - super.connectBlocking(StandardProtocolFamily.INET6, new InetSocketAddress(address, port), connectHandler, - readHandler); - } - -} diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/Inet6ServerNonBlockingSocketChannel.java b/mcp/src/main/java/io/modelcontextprotocol/util/Inet6ServerNonBlockingSocketChannel.java deleted file mode 100644 index 8a1a95e27..000000000 --- a/mcp/src/main/java/io/modelcontextprotocol/util/Inet6ServerNonBlockingSocketChannel.java +++ /dev/null @@ -1,34 +0,0 @@ -package io.modelcontextprotocol.util; - -import java.io.IOException; -import java.net.Inet6Address; -import java.net.InetSocketAddress; -import java.net.StandardProtocolFamily; -import java.nio.channels.Selector; -import java.nio.channels.SocketChannel; -import java.util.concurrent.ExecutorService; - -public class Inet6ServerNonBlockingSocketChannel extends ServerNonBlockingSocketChannel { - - public Inet6ServerNonBlockingSocketChannel() throws IOException { - super(); - } - - public Inet6ServerNonBlockingSocketChannel(Selector selector, int incomingBufferSize, ExecutorService executor) { - super(selector, incomingBufferSize, executor); - } - - public Inet6ServerNonBlockingSocketChannel(Selector selector, int incomingBufferSize) { - super(selector, incomingBufferSize); - } - - public Inet6ServerNonBlockingSocketChannel(Selector selector) { - super(selector); - } - - public void start(Inet6Address address, int port, IOConsumer acceptHandler, - IOConsumer readHandler) throws IOException { - super.start(StandardProtocolFamily.INET6, new InetSocketAddress(address, port), acceptHandler, readHandler); - } - -} diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/ServerNonBlockingSocketChannel.java b/mcp/src/main/java/io/modelcontextprotocol/util/ServSocketChannel.java similarity index 76% rename from mcp/src/main/java/io/modelcontextprotocol/util/ServerNonBlockingSocketChannel.java rename to mcp/src/main/java/io/modelcontextprotocol/util/ServSocketChannel.java index 4c64e3d18..0b210c639 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/util/ServerNonBlockingSocketChannel.java +++ b/mcp/src/main/java/io/modelcontextprotocol/util/ServSocketChannel.java @@ -12,25 +12,25 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public class ServerNonBlockingSocketChannel extends NonBlockingSocketChannel { +public class ServSocketChannel extends AbstractSocketChannel { - private static final Logger logger = LoggerFactory.getLogger(ServerNonBlockingSocketChannel.class); + private static final Logger logger = LoggerFactory.getLogger(ServSocketChannel.class); protected SocketChannel acceptedClient; - public ServerNonBlockingSocketChannel() throws IOException { + public ServSocketChannel() throws IOException { super(); } - public ServerNonBlockingSocketChannel(Selector selector, int incomingBufferSize, ExecutorService executor) { + public ServSocketChannel(Selector selector, int incomingBufferSize, ExecutorService executor) { super(selector, incomingBufferSize, executor); } - public ServerNonBlockingSocketChannel(Selector selector, int incomingBufferSize) { + public ServSocketChannel(Selector selector, int incomingBufferSize) { super(selector, incomingBufferSize); } - public ServerNonBlockingSocketChannel(Selector selector) { + public ServSocketChannel(Selector selector) { super(selector); } @@ -66,11 +66,13 @@ protected void handleException(SelectionKey key, Exception e) { close(); } - public void writeMessageBlocking(String message) throws IOException { - if (this.acceptedClient == null) { - throw new IOException("Cannot write until client connected"); + public void writeMessage(String message) throws IOException { + SocketChannel c = this.acceptedClient; + if (c != null) { + writeMessageToChannel(c, message); + } else { + throw new IOException("not connected"); } - writeBlocking(acceptedClient, message); } @Override diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/UDSClientNonBlockingSocketChannel.java b/mcp/src/main/java/io/modelcontextprotocol/util/UDSClientNonBlockingSocketChannel.java deleted file mode 100644 index 2e279c2b9..000000000 --- a/mcp/src/main/java/io/modelcontextprotocol/util/UDSClientNonBlockingSocketChannel.java +++ /dev/null @@ -1,33 +0,0 @@ -package io.modelcontextprotocol.util; - -import java.io.IOException; -import java.net.StandardProtocolFamily; -import java.net.UnixDomainSocketAddress; -import java.nio.channels.Selector; -import java.nio.channels.SocketChannel; -import java.util.concurrent.ExecutorService; - -public class UDSClientNonBlockingSocketChannel extends ClientNonBlockingSocketChannel { - - public UDSClientNonBlockingSocketChannel(Selector selector, int incomingBufferSize, ExecutorService executor) { - super(selector, incomingBufferSize, executor); - } - - public UDSClientNonBlockingSocketChannel() throws IOException { - super(); - } - - public UDSClientNonBlockingSocketChannel(Selector selector, int incomingBufferSize) { - super(selector, incomingBufferSize); - } - - public UDSClientNonBlockingSocketChannel(Selector selector) { - super(selector); - } - - public void connectBlocking(UnixDomainSocketAddress address, IOConsumer connectHandler, - IOConsumer readHandler) throws IOException { - super.connectBlocking(StandardProtocolFamily.UNIX, address, connectHandler, readHandler); - } - -} diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/UDSClientSocketChannel.java b/mcp/src/main/java/io/modelcontextprotocol/util/UDSClientSocketChannel.java new file mode 100644 index 000000000..93539c852 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/util/UDSClientSocketChannel.java @@ -0,0 +1,33 @@ +package io.modelcontextprotocol.util; + +import java.io.IOException; +import java.net.StandardProtocolFamily; +import java.net.UnixDomainSocketAddress; +import java.nio.channels.Selector; +import java.nio.channels.SocketChannel; +import java.util.concurrent.ExecutorService; + +public class UDSClientSocketChannel extends ClientSocketChannel { + + public UDSClientSocketChannel(Selector selector, int incomingBufferSize, ExecutorService executor) { + super(selector, incomingBufferSize, executor); + } + + public UDSClientSocketChannel() throws IOException { + super(); + } + + public UDSClientSocketChannel(Selector selector, int incomingBufferSize) { + super(selector, incomingBufferSize); + } + + public UDSClientSocketChannel(Selector selector) { + super(selector); + } + + public void connect(UnixDomainSocketAddress address, IOConsumer connectHandler, + IOConsumer readHandler) throws IOException { + super.connect(StandardProtocolFamily.UNIX, address, connectHandler, readHandler); + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/UDSServerNonBlockingSocketChannel.java b/mcp/src/main/java/io/modelcontextprotocol/util/UDSServerSocketChannel.java similarity index 60% rename from mcp/src/main/java/io/modelcontextprotocol/util/UDSServerNonBlockingSocketChannel.java rename to mcp/src/main/java/io/modelcontextprotocol/util/UDSServerSocketChannel.java index 259315712..a6607cf17 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/util/UDSServerNonBlockingSocketChannel.java +++ b/mcp/src/main/java/io/modelcontextprotocol/util/UDSServerSocketChannel.java @@ -7,21 +7,21 @@ import java.nio.channels.SocketChannel; import java.util.concurrent.ExecutorService; -public class UDSServerNonBlockingSocketChannel extends ServerNonBlockingSocketChannel { +public class UDSServerSocketChannel extends ServSocketChannel { - public UDSServerNonBlockingSocketChannel() throws IOException { + public UDSServerSocketChannel() throws IOException { super(); } - public UDSServerNonBlockingSocketChannel(Selector selector, int incomingBufferSize, ExecutorService executor) { + public UDSServerSocketChannel(Selector selector, int incomingBufferSize, ExecutorService executor) { super(selector, incomingBufferSize, executor); } - public UDSServerNonBlockingSocketChannel(Selector selector, int incomingBufferSize) { + public UDSServerSocketChannel(Selector selector, int incomingBufferSize) { super(selector, incomingBufferSize); } - public UDSServerNonBlockingSocketChannel(Selector selector) { + public UDSServerSocketChannel(Selector selector) { super(selector); } diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/UDSMcpAsyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/UDSMcpAsyncClientTests.java index 99121e01c..94b0b8a70 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/UDSMcpAsyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/UDSMcpAsyncClientTests.java @@ -7,12 +7,13 @@ import java.io.IOException; import java.net.UnixDomainSocketAddress; import java.nio.file.Files; -import java.time.Duration; +import java.nio.file.Path; +import java.nio.file.Paths; import org.junit.jupiter.api.Timeout; import io.modelcontextprotocol.client.transport.UDSClientTransportProvider; -import io.modelcontextprotocol.server.EverythingServer; +import io.modelcontextprotocol.server.TestEverythingServer; import io.modelcontextprotocol.server.transport.UDSServerTransportProvider; import io.modelcontextprotocol.spec.McpClientTransport; @@ -23,47 +24,40 @@ * @author Dariusz Jędrzejczyk * @author Scott Lewis */ -@Timeout(15) // Giving extra time beyond the client timeout +@Timeout(150) // Giving extra time beyond the client timeout class UDSMcpAsyncClientTests extends AbstractMcpAsyncClientTests { - UnixDomainSocketAddress address; - EverythingServer server; + private Path socketPath = Paths.get(getClass().getName() + ".unix.socket"); - @Override - protected void onStart() { - this.address = UnixDomainSocketAddress.of(getClass().getName() + ".socket"); + private void deleteSocketPath() { try { - // Delete this file if exists from previous run - Files.deleteIfExists(this.address.getPath()); + Files.deleteIfExists(socketPath); } catch (IOException e) { throw new RuntimeException(e); } - this.server = new EverythingServer(new UDSServerTransportProvider(address)); + } + + protected void onStart() { + super.onStart(); + deleteSocketPath(); + this.server = new TestEverythingServer(new UDSServerTransportProvider(UnixDomainSocketAddress.of(socketPath))); } @Override protected void onClose() { - server.closeGracefully(); - server = null; - try { - Files.deleteIfExists(address.getPath()); - } catch (IOException e) { - throw new RuntimeException(e); + super.onClose(); + if (server != null) { + server.closeGracefully(); + server = null; } - address = null; + deleteSocketPath(); } + private TestEverythingServer server; + @Override protected McpClientTransport createMcpTransport() { - try { - return new UDSClientTransportProvider(address); - } catch (IOException e) { - throw new RuntimeException(e); - } - } - - protected Duration getInitializationTimeout() { - return Duration.ofSeconds(2); + return new UDSClientTransportProvider(UnixDomainSocketAddress.of(socketPath)); } } diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/UDSMcpSyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/UDSMcpSyncClientTests.java index c52d98a97..3ce8be05c 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/UDSMcpSyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/UDSMcpSyncClientTests.java @@ -7,12 +7,14 @@ import java.io.IOException; import java.net.UnixDomainSocketAddress; import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; import java.time.Duration; import org.junit.jupiter.api.Timeout; import io.modelcontextprotocol.client.transport.UDSClientTransportProvider; -import io.modelcontextprotocol.server.EverythingServer; +import io.modelcontextprotocol.server.TestEverythingServer; import io.modelcontextprotocol.server.transport.UDSServerTransportProvider; import io.modelcontextprotocol.spec.McpClientTransport; @@ -26,40 +28,37 @@ @Timeout(15) // Giving extra time beyond the client timeout class UDSMcpSyncClientTests extends AbstractMcpSyncClientTests { - UnixDomainSocketAddress address; - EverythingServer server; + private Path socketPath = Paths.get(getClass().getName() + ".unix.socket"); - @Override - protected void onStart() { - this.address = UnixDomainSocketAddress.of(getClass().getName() + ".socket"); + private void deleteSocketPath() { try { - // Delete this file if exists from previous run - Files.deleteIfExists(this.address.getPath()); + Files.deleteIfExists(socketPath); } catch (IOException e) { throw new RuntimeException(e); } - this.server = new EverythingServer(new UDSServerTransportProvider(address)); + } + + protected void onStart() { + super.onStart(); + deleteSocketPath(); + this.server = new TestEverythingServer(new UDSServerTransportProvider(UnixDomainSocketAddress.of(socketPath))); } @Override protected void onClose() { - server.closeGracefully(); - server = null; - try { - Files.deleteIfExists(address.getPath()); - } catch (IOException e) { - throw new RuntimeException(e); + super.onClose(); + if (server != null) { + server.closeGracefully(); + server = null; } - address = null; + deleteSocketPath(); } + private TestEverythingServer server; + @Override protected McpClientTransport createMcpTransport() { - try { - return new UDSClientTransportProvider(address); - } catch (IOException e) { - throw new RuntimeException(e); - } + return new UDSClientTransportProvider(UnixDomainSocketAddress.of(socketPath)); } protected Duration getInitializationTimeout() { diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/EverythingServer.java b/mcp/src/test/java/io/modelcontextprotocol/server/EverythingServer.java deleted file mode 100644 index a158ab2fb..000000000 --- a/mcp/src/test/java/io/modelcontextprotocol/server/EverythingServer.java +++ /dev/null @@ -1,131 +0,0 @@ -package io.modelcontextprotocol.server; - -import java.util.List; - -import io.modelcontextprotocol.server.McpServerFeatures.SyncPromptSpecification; -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpServerTransportProvider; -import io.modelcontextprotocol.spec.McpSchema.Annotations; -import io.modelcontextprotocol.spec.McpSchema.CallToolResult; -import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; -import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; -import io.modelcontextprotocol.spec.McpSchema.GetPromptResult; -import io.modelcontextprotocol.spec.McpSchema.ProgressNotification; -import io.modelcontextprotocol.spec.McpSchema.Prompt; -import io.modelcontextprotocol.spec.McpSchema.PromptMessage; -import io.modelcontextprotocol.spec.McpSchema.ReadResourceResult; -import io.modelcontextprotocol.spec.McpSchema.Resource; -import io.modelcontextprotocol.spec.McpSchema.Role; -import io.modelcontextprotocol.spec.McpSchema.SamplingMessage; -import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; -import io.modelcontextprotocol.spec.McpSchema.TextContent; -import io.modelcontextprotocol.spec.McpSchema.TextResourceContents; -import io.modelcontextprotocol.spec.McpSchema.Tool; -import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest.ContextInclusionStrategy; - -public class EverythingServer { - - private static final String TEST_RESOURCE_URI = "test://resources/"; - - private static final String emptyJsonSchema = """ - { - "$schema": "http://json-schema.org/draft-07/schema#", - "type": "object", - "properties": {} - } - """; - - private McpSyncServer server; - - public EverythingServer(McpServerTransportProvider transport) { - McpServerFeatures.SyncResourceSpecification[] specs = new McpServerFeatures.SyncResourceSpecification[10]; - for (int i = 0; i < 10; i++) { - String istr = String.valueOf(i); - String uri = TEST_RESOURCE_URI + istr; - specs[i] = new McpServerFeatures.SyncResourceSpecification( - Resource.builder().uri(uri).name("Test Resource").mimeType("text/plain") - .description("Test resource description").build(), - (exchange, - req) -> new ReadResourceResult(List.of(new TextResourceContents(uri, "text/plain", istr)))); - } - - this.server = McpServer.sync(transport).serverInfo(getClass().getName() + "-server", "1.0.0") - .capabilities( - ServerCapabilities.builder().logging().tools(true).prompts(true).resources(true, true).build()) - .toolCall(Tool.builder().name("echo").description("echo tool description").inputSchema(emptyJsonSchema) - .build(), (exchange, request) -> { - return CallToolResult.builder().addTextContent((String) request.arguments().get("message")) - .build(); - }) - .toolCall( - Tool.builder().name("add").description("add two integers").inputSchema(emptyJsonSchema).build(), - (exchange, request) -> { - Integer a = (Integer) request.arguments().get("a"); - Integer b = (Integer) request.arguments().get("b"); - - return CallToolResult.builder().addTextContent(String.valueOf(a + b)).build(); - }) - .toolCall(Tool.builder().name("sampleLLM").description("sampleLLM tool").inputSchema(emptyJsonSchema) - .build(), (exchange, request) -> { - String prompt = (String) request.arguments().get("prompt"); - Integer maxTokens = (Integer) request.arguments().get("maxTokens"); - SamplingMessage sm = new SamplingMessage(McpSchema.Role.USER, - new TextContent("Resource sampleLLM context: " + prompt)); - CreateMessageRequest cmRequest = CreateMessageRequest.builder().messages(List.of(sm)) - .systemPrompt("You are a helpful test server.").maxTokens(maxTokens) - .temperature(0.7).includeContext(ContextInclusionStrategy.THIS_SERVER).build(); - CreateMessageResult result = exchange.createMessage(cmRequest); - - return CallToolResult.builder() - .addTextContent("LLM sampling result: " + ((TextContent) result.content()).text()) - .build(); - }) - .toolCall(Tool.builder().name("longRunningOperation") - .description("Demonstrates a long running operation with progress updates") - .inputSchema(emptyJsonSchema).build(), (exchange, request) -> { - String progressToken = (String) request.progressToken(); - int steps = (Integer) request.arguments().get("steps"); - for (int i = 0; i < steps; i++) { - try { - Thread.sleep(1000); - } catch (InterruptedException e) { - throw new RuntimeException(e); - } - if (progressToken != null) { - exchange.progressNotification( - new ProgressNotification(progressToken, (double) i + 1, (double) steps, - "progress message " + String.valueOf(i + 1))); - } - } - return CallToolResult.builder().content(List.of(new TextContent("done"))).build(); - }) - .toolCall(Tool.builder().name("annotatedMessage").description("annotated message").build(), - (exchange, request) -> { - String messageType = (String) request.arguments().get("messageType"); - Annotations annotations = null; - if (messageType.equals("success")) { - annotations = new Annotations(List.of(McpSchema.Role.USER), 0.7); - } else if (messageType.equals("error")) { - annotations = new Annotations(List.of(McpSchema.Role.USER, McpSchema.Role.ASSISTANT), - 1.0); - } else if (messageType.equals("debug")) { - annotations = new Annotations(List.of(McpSchema.Role.ASSISTANT), 0.3); - } - return CallToolResult.builder().addContent(new TextContent(annotations, "some response")) - .build(); - }) - .prompts(List.of(new SyncPromptSpecification( - new Prompt("simple_prompt", "Simple prompt description", null), (exchange, request) -> { - return new GetPromptResult("description", - List.of(new PromptMessage(Role.USER, new TextContent("hello")))); - }))) - .resources(specs).build(); - } - - public void closeGracefully() { - if (this.server != null) { - this.server.closeGracefully(); - this.server = null; - } - } -} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/TestEverythingServer.java b/mcp/src/test/java/io/modelcontextprotocol/server/TestEverythingServer.java new file mode 100644 index 000000000..3b0ca8fca --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/server/TestEverythingServer.java @@ -0,0 +1,149 @@ +package io.modelcontextprotocol.server; + +import java.util.List; + +import io.modelcontextprotocol.server.McpServerFeatures.SyncPromptSpecification; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpServerTransportProvider; +import io.modelcontextprotocol.spec.McpSchema.Annotations; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; +import io.modelcontextprotocol.spec.McpSchema.GetPromptResult; +import io.modelcontextprotocol.spec.McpSchema.ProgressNotification; +import io.modelcontextprotocol.spec.McpSchema.Prompt; +import io.modelcontextprotocol.spec.McpSchema.PromptMessage; +import io.modelcontextprotocol.spec.McpSchema.ReadResourceResult; +import io.modelcontextprotocol.spec.McpSchema.Resource; +import io.modelcontextprotocol.spec.McpSchema.Role; +import io.modelcontextprotocol.spec.McpSchema.SamplingMessage; +import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; +import io.modelcontextprotocol.spec.McpSchema.TextContent; +import io.modelcontextprotocol.spec.McpSchema.TextResourceContents; +import io.modelcontextprotocol.spec.McpSchema.Tool; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest.ContextInclusionStrategy; + +public class TestEverythingServer { + + private static final String TEST_RESOURCE_URI = "test://resources/"; + + private static final String emptyJsonSchema = """ + { + "$schema": "http://json-schema.org/draft-07/schema#", + "type": "object", + "properties": {} + } + """; + + private McpSyncServer server; + + public TestEverythingServer(McpServerTransportProvider transport) { + McpServerFeatures.SyncResourceSpecification[] specs = new McpServerFeatures.SyncResourceSpecification[10]; + for (int i = 0; i < 10; i++) { + String istr = String.valueOf(i); + String uri = TEST_RESOURCE_URI + istr; + specs[i] = new McpServerFeatures.SyncResourceSpecification( + Resource.builder() + .uri(uri) + .name("Test Resource") + .mimeType("text/plain") + .description("Test resource description") + .build(), + (exchange, + req) -> new ReadResourceResult(List.of(new TextResourceContents(uri, "text/plain", istr)))); + } + + this.server = McpServer.sync(transport) + .serverInfo(getClass().getName() + "-server", "1.0.0") + .capabilities( + ServerCapabilities.builder().logging().tools(true).prompts(true).resources(true, true).build()) + .toolCall(Tool.builder() + .name("echo") + .description("echo tool description") + .inputSchema(emptyJsonSchema) + .build(), (exchange, request) -> { + return CallToolResult.builder().addTextContent((String) request.arguments().get("message")).build(); + }) + .toolCall(Tool.builder().name("add").description("add two integers").inputSchema(emptyJsonSchema).build(), + (exchange, request) -> { + Integer a = (Integer) request.arguments().get("a"); + Integer b = (Integer) request.arguments().get("b"); + + return CallToolResult.builder().addTextContent(String.valueOf(a + b)).build(); + }) + .toolCall( + Tool.builder().name("sampleLLM").description("sampleLLM tool").inputSchema(emptyJsonSchema).build(), + (exchange, request) -> { + String prompt = (String) request.arguments().get("prompt"); + Integer maxTokens = (Integer) request.arguments().get("maxTokens"); + SamplingMessage sm = new SamplingMessage(McpSchema.Role.USER, + new TextContent("Resource sampleLLM context: " + prompt)); + CreateMessageRequest cmRequest = CreateMessageRequest.builder() + .messages(List.of(sm)) + .systemPrompt("You are a helpful test server.") + .maxTokens(maxTokens) + .temperature(0.7) + .includeContext(ContextInclusionStrategy.THIS_SERVER) + .build(); + CreateMessageResult result = exchange.createMessage(cmRequest); + + return CallToolResult.builder() + .addTextContent("LLM sampling result: " + ((TextContent) result.content()).text()) + .build(); + }) + .toolCall(Tool.builder() + .name("longRunningOperation") + .description("Demonstrates a long running operation with progress updates") + .inputSchema(emptyJsonSchema) + .build(), (exchange, request) -> { + String progressToken = (String) request.progressToken(); + int steps = (Integer) request.arguments().get("steps"); + for (int i = 0; i < steps; i++) { + try { + Thread.sleep(1000); + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + if (progressToken != null) { + exchange.progressNotification(new ProgressNotification(progressToken, (double) i + 1, + (double) steps, "progress message " + String.valueOf(i + 1))); + } + } + return CallToolResult.builder().content(List.of(new TextContent("done"))).build(); + }) + .toolCall(Tool.builder().name("annotatedMessage").description("annotated message").build(), + (exchange, request) -> { + String messageType = (String) request.arguments().get("messageType"); + Annotations annotations = null; + if (messageType.equals("success")) { + annotations = new Annotations(List.of(McpSchema.Role.USER), 0.7); + } + else if (messageType.equals("error")) { + annotations = new Annotations(List.of(McpSchema.Role.USER, McpSchema.Role.ASSISTANT), 1.0); + } + else if (messageType.equals("debug")) { + annotations = new Annotations(List.of(McpSchema.Role.ASSISTANT), 0.3); + } + return CallToolResult.builder() + .addContent(new TextContent(annotations, "some response")) + .build(); + }) + .prompts(List.of(new SyncPromptSpecification(new Prompt("simple_prompt", "Simple prompt description", null), + (exchange, request) -> { + return new GetPromptResult("description", + List.of(new PromptMessage(Role.USER, new TextContent("hello")))); + }))) + .resources(specs) + .build(); + } + + + public void closeGracefully() { + if (this.server != null) { + this.server.closeGracefully(); + this.server = null; + } + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpAsyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpAsyncServerTests.java index cad1eae5b..505798502 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpAsyncServerTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpAsyncServerTests.java @@ -7,6 +7,8 @@ import java.io.IOException; import java.net.UnixDomainSocketAddress; import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; import org.junit.jupiter.api.Timeout; @@ -22,28 +24,29 @@ @Timeout(15) // Giving extra time beyond the client timeout class UDSMcpAsyncServerTests extends AbstractMcpAsyncServerTests { - private UnixDomainSocketAddress address; + private Path socketPath = Paths.get(getClass().getName() + ".unix.socket"); - @Override - protected void setUp() { + private void deleteSocketPath() { + try { + Files.deleteIfExists(socketPath); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + protected void onStart() { super.onStart(); - address = UnixDomainSocketAddress.of(getClass().getName() + ".unix.socket"); + deleteSocketPath(); } + @Override - protected void tearDown() { + protected void onClose() { super.onClose(); - if (address != null) { - try { - Files.deleteIfExists(address.getPath()); - } - catch (IOException e) { - } - } + deleteSocketPath(); } protected McpServerTransportProvider createMcpTransportProvider() { - return new UDSServerTransportProvider(address); + return new UDSServerTransportProvider(UnixDomainSocketAddress.of(socketPath)); } @Override diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpSyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpSyncServerTests.java index 57ec7b766..cacb7835a 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpSyncServerTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpSyncServerTests.java @@ -7,6 +7,8 @@ import java.io.IOException; import java.net.UnixDomainSocketAddress; import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; import org.junit.jupiter.api.Timeout; @@ -22,28 +24,29 @@ @Timeout(15) // Giving extra time beyond the client timeout class UDSMcpSyncServerTests extends AbstractMcpSyncServerTests { - private UnixDomainSocketAddress address; + private Path socketPath = Paths.get(getClass().getName() + ".unix.socket"); - @Override - protected void setUp() { + private void deleteSocketPath() { + try { + Files.deleteIfExists(socketPath); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + protected void onStart() { super.onStart(); - address = UnixDomainSocketAddress.of(getClass().getName() + ".unix.socket"); + deleteSocketPath(); } + @Override - protected void tearDown() { + protected void onClose() { super.onClose(); - if (address != null) { - try { - Files.deleteIfExists(address.getPath()); - } - catch (IOException e) { - } - } + deleteSocketPath(); } protected McpServerTransportProvider createMcpTransportProvider() { - return new UDSServerTransportProvider(address); + return new UDSServerTransportProvider(UnixDomainSocketAddress.of(socketPath)); } @Override From 6c96c765cfe73b88209fa76d1f9b86911ca4278b Mon Sep 17 00:00:00 2001 From: Scott Lewis Date: Wed, 13 Aug 2025 18:35:16 -0700 Subject: [PATCH 18/25] Refactor --- .../transport/UDSClientTransportProvider.java | 85 ++++++++++--------- .../transport/UDSServerTransportProvider.java | 17 ++-- .../util/AbstractSocketChannel.java | 35 +++++--- .../util/ClientSocketChannel.java | 6 +- .../util/ServSocketChannel.java | 3 +- .../client/UDSMcpAsyncClientTests.java | 5 +- .../client/UDSMcpSyncClientTests.java | 5 +- .../server/TestEverythingServer.java | 3 +- .../server/UDSMcpAsyncServerTests.java | 5 +- .../server/UDSMcpSyncServerTests.java | 5 +- 10 files changed, 93 insertions(+), 76 deletions(-) diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/UDSClientTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/UDSClientTransportProvider.java index 7d9dc3f80..26bbc7bd2 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/UDSClientTransportProvider.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/UDSClientTransportProvider.java @@ -69,19 +69,17 @@ protected void handleException(SelectionKey key, Exception e) { super.handleException(key, e); } }; - } catch (IOException e) { + } + catch (IOException e) { throw new RuntimeException(e); } this.outboundScheduler = Schedulers.fromExecutorService(Executors.newSingleThreadExecutor(), "outbound"); } /** - * Creates a new StdioClientTransport with the specified parameters and - * ObjectMapper. - * - * @param params The parameters for configuring the server process - * @param objectMapper The ObjectMapper to use for JSON - * serialization/deserialization + * Creates a new StdioClientTransport with the specified parameters and ObjectMapper. + * @param params The parameters for configuring the server process + * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization */ public UDSClientTransportProvider(ServerParameters params, ObjectMapper objectMapper) { Assert.notNull(params, "The params can not be null"); @@ -99,12 +97,11 @@ public UDSClientTransportProvider(ServerParameters params, ObjectMapper objectMa } /** - * Starts the server process and initializes the message processing streams. - * This method sets up the process with the configured command, arguments, and - * environment, then starts the inbound, outbound, and error processing threads. - * - * @throws RuntimeException if the process fails to start or if the process - * streams are null + * Starts the server process and initializes the message processing streams. This + * method sets up the process with the configured command, arguments, and environment, + * then starts the inbound, outbound, and error processing threads. + * @throws RuntimeException if the process fails to start or if the process streams + * are null */ @Override public Mono connect(Function, Mono> handler) { @@ -132,7 +129,8 @@ public Mono connect(Function, Mono> h } } }); - } catch (IOException e) { + } + catch (IOException e) { this.clientChannel.close(); throw new RuntimeException( "Connect to address=" + targetAddress + " failed message: " + e.getMessage()); @@ -147,10 +145,9 @@ public Mono connect(Function, Mono> h * Sets the handler for processing transport-level errors. * *

- * The provided handler will be called when errors occur during transport - * operations, such as connection failures or protocol violations. + * The provided handler will be called when errors occur during transport operations, + * such as connection failures or protocol violations. *

- * * @param errorHandler a consumer that processes error messages */ public void setStdErrorHandler(Consumer errorHandler) { @@ -158,8 +155,11 @@ public void setStdErrorHandler(Consumer errorHandler) { } private void handleIncomingMessages(Function, Mono> inboundMessageHandler) { - this.inboundSink.asFlux().flatMap(message -> Mono.just(message).transform(inboundMessageHandler) - .contextWrite(ctx -> ctx.put("observation", "myObservation"))).subscribe(); + this.inboundSink.asFlux() + .flatMap(message -> Mono.just(message) + .transform(inboundMessageHandler) + .contextWrite(ctx -> ctx.put("observation", "myObservation"))) + .subscribe(); } private void handleIncomingErrors() { @@ -179,28 +179,31 @@ public Mono sendMessage(JSONRPCMessage message) { /** * Starts the outbound processing thread that writes JSON-RPC messages to the - * process's output stream. Messages are serialized to JSON and written with a - * newline delimiter. + * process's output stream. Messages are serialized to JSON and written with a newline + * delimiter. */ private void startOutboundProcessing() { this.handleOutbound(messages -> messages - // this bit is important since writes come from user threads, and we - // want to ensure that the actual writing happens on a dedicated thread - .publishOn(outboundScheduler).handle((message, sink) -> { - if (message != null && !isClosing) { - try { - clientChannel.writeMessage(objectMapper.writeValueAsString(message)); - sink.next(message); - } catch (IOException e) { - if (!isClosing) { - logger.error("Error writing message", e); - sink.error(new RuntimeException(e)); - } else { - logger.debug("Stream closed during shutdown", e); - } + // this bit is important since writes come from user threads, and we + // want to ensure that the actual writing happens on a dedicated thread + .publishOn(outboundScheduler) + .handle((message, sink) -> { + if (message != null && !isClosing) { + try { + clientChannel.writeMessage(objectMapper.writeValueAsString(message)); + sink.next(message); + } + catch (IOException e) { + if (!isClosing) { + logger.error("Error writing message", e); + sink.error(new RuntimeException(e)); + } + else { + logger.debug("Stream closed during shutdown", e); } } - })); + } + })); } protected void handleOutbound(Function, Flux> outboundConsumer) { @@ -217,10 +220,9 @@ protected void handleOutbound(Function, Flux closeGracefully() { try { outboundScheduler.dispose(); logger.debug("Graceful shutdown completed"); - } catch (Exception e) { + } + catch (Exception e) { logger.error("Error during graceful shutdown", e); } })).then().subscribeOn(Schedulers.boundedElastic()); diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/UDSServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/UDSServerTransportProvider.java index 2d4908cb3..8911f2797 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/transport/UDSServerTransportProvider.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/UDSServerTransportProvider.java @@ -49,7 +49,6 @@ public class UDSServerTransportProvider implements McpServerTransportProvider { /** * Creates a new UDSServerTransportProvider with a default ObjectMapper - * * @param unixSocketAddress the UDS socket address to bind to. Must not be null. */ public UDSServerTransportProvider(UnixDomainSocketAddress unixSocketAddress) { @@ -58,9 +57,7 @@ public UDSServerTransportProvider(UnixDomainSocketAddress unixSocketAddress) { /** * Creates a new UDSServerTransportProvider with the specified ObjectMapper - * - * @param objectMapper The ObjectMapper to use for JSON - * serialization/deserialization + * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization */ public UDSServerTransportProvider(ObjectMapper objectMapper, UnixDomainSocketAddress unixSocketAddress) { Assert.notNull(objectMapper, "objectMapper cannot be null"); @@ -87,7 +84,7 @@ public Mono notifyClients(String method, Object params) { return Mono.error(new McpError("No session to close")); } return this.session.sendNotification(method, params) - .doOnError(e -> logger.error("Failed to send notification: {}", e.getMessage())); + .doOnError(e -> logger.error("Failed to send notification: {}", e.getMessage())); } @Override @@ -131,7 +128,8 @@ protected void handleException(SelectionKey key, Exception e) { inboundSink.tryEmitComplete(); } }; - } catch (IOException e) { + } + catch (IOException e) { throw new RuntimeException(e); } } @@ -204,15 +202,16 @@ private void startInboundProcessing() { throw new IOException("Error adding jsonMessge to inboundSink"); } }); - } catch (IOException e) { + } + catch (IOException e) { throw new RuntimeException(e); } } } /** - * Starts the outbound processing thread that writes JSON-RPC messages to - * stdout. Messages are serialized to JSON and written with a newline delimiter. + * Starts the outbound processing thread that writes JSON-RPC messages to stdout. + * Messages are serialized to JSON and written with a newline delimiter. */ private void startOutboundProcessing() { Function, Flux> outboundConsumer = messages -> messages // @formatter:off diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/AbstractSocketChannel.java b/mcp/src/main/java/io/modelcontextprotocol/util/AbstractSocketChannel.java index 1b25dff48..c6736c378 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/util/AbstractSocketChannel.java +++ b/mcp/src/main/java/io/modelcontextprotocol/util/AbstractSocketChannel.java @@ -113,17 +113,21 @@ protected Runnable getRunnableForProcessing(IOConsumer acceptHand key = iter.next(); if (key.isConnectable()) { handleConnectable(key, connectHandler); - } else if (key.isAcceptable()) { + } + else if (key.isAcceptable()) { handleAcceptable(key, acceptHandler); - } else if (key.isReadable()) { + } + else if (key.isReadable()) { handleReadable(key, readHandler); - } else if (key.isWritable()) { + } + else if (key.isWritable()) { handleWritable(key); } iter.remove(); } } - } catch (Exception e) { + } + catch (Exception e) { handleException(key, e); } }; @@ -215,7 +219,8 @@ protected void handleReadable(SelectionKey key, IOConsumer readHandler) readHandler.apply(messages[i]); } } - } else { + } + else { io.reading = sb; debug("read partial=%s", partial); } @@ -251,7 +256,8 @@ protected void doWrite(SelectionKey key, SocketChannel client, ByteBuffer buf, I debug("doWrite written=%s, remaining=%s", written, buf.remaining()); io.writing = buf.slice(); key.interestOpsOr(SelectionKey.OP_WRITE); - } else { + } + else { if (logger.isDebugEnabled()) { logger.debug("doWrite message=%s", new String(buf.array(), 0, written)); } @@ -270,7 +276,8 @@ protected void executorShutdown() { try { this.executor.awaitTermination(this.terminationTimeout, TimeUnit.MILLISECONDS); this.executor.shutdown(); - } catch (InterruptedException e) { + } + catch (InterruptedException e) { if (logger.isDebugEnabled()) { logger.debug("Exception in executor awaitTermination", e); } @@ -287,7 +294,8 @@ protected void hardCloseClient(SocketChannel client, IOConsumer c closeHandler.apply(client); } client.close(); - } catch (IOException e) { + } + catch (IOException e) { if (logger.isDebugEnabled()) { logger.debug("hardClose client socketchannel.close exception", e); } @@ -301,9 +309,11 @@ protected void writeMessageToChannel(SocketChannel client, String message) throw Objects.requireNonNull(client, "Client must not be null"); Objects.requireNonNull(message, "Message must not be null"); // Escape any embedded newlines in the JSON message - String outputMessage = message.replace("\r\n", "\\n").replace("\n", "\\n").replace("\r", "\\n") - // add message delimiter - .concat(DEFAULT_MESSAGE_DELIMITER); + String outputMessage = message.replace("\r\n", "\\n") + .replace("\n", "\\n") + .replace("\r", "\\n") + // add message delimiter + .concat(DEFAULT_MESSAGE_DELIMITER); debug("writing msg=%s", outputMessage); synchronized (writeLock) { // do the non blocking write in thread while holding lock. @@ -321,7 +331,8 @@ protected void writeMessageToChannel(SocketChannel client, String message) throw try { debug("writeBlocking WAITING(ms)=%s msg=%s", String.valueOf(waitTime / 10), outputMessage); writeLock.wait(waitTime / 10); - } catch (InterruptedException e) { + } + catch (InterruptedException e) { throw new InterruptedIOException("write message wait interrupted"); } } diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/ClientSocketChannel.java b/mcp/src/main/java/io/modelcontextprotocol/util/ClientSocketChannel.java index e3b7134e1..4ec69672c 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/util/ClientSocketChannel.java +++ b/mcp/src/main/java/io/modelcontextprotocol/util/ClientSocketChannel.java @@ -18,7 +18,7 @@ public class ClientSocketChannel extends AbstractSocketChannel { protected SocketChannel client; protected final Object connectLock = new Object(); - + public ClientSocketChannel(Selector selector, int incomingBufferSize, ExecutorService executor) { super(selector, incomingBufferSize, executor); } @@ -60,13 +60,13 @@ protected SocketChannel doConnect(SocketChannel client, SocketAddress address, } } catch (InterruptedException e) { - throw new IOException("Connect to address=" + address + " timed out after " + String.valueOf(this.connectTimeout)+ "ms" ); + throw new IOException( + "Connect to address=" + address + " timed out after " + String.valueOf(this.connectTimeout) + "ms"); } debug("connected client=%s", client); return client; } - public void connect(StandardProtocolFamily protocol, SocketAddress address, IOConsumer connectHandler, IOConsumer readHandler) throws IOException { if (this.client != null) { diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/ServSocketChannel.java b/mcp/src/main/java/io/modelcontextprotocol/util/ServSocketChannel.java index 0b210c639..78b18d9fe 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/util/ServSocketChannel.java +++ b/mcp/src/main/java/io/modelcontextprotocol/util/ServSocketChannel.java @@ -70,7 +70,8 @@ public void writeMessage(String message) throws IOException { SocketChannel c = this.acceptedClient; if (c != null) { writeMessageToChannel(c, message); - } else { + } + else { throw new IOException("not connected"); } } diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/UDSMcpAsyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/UDSMcpAsyncClientTests.java index 94b0b8a70..ad43cd8c9 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/UDSMcpAsyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/UDSMcpAsyncClientTests.java @@ -32,11 +32,12 @@ class UDSMcpAsyncClientTests extends AbstractMcpAsyncClientTests { private void deleteSocketPath() { try { Files.deleteIfExists(socketPath); - } catch (IOException e) { + } + catch (IOException e) { throw new RuntimeException(e); } } - + protected void onStart() { super.onStart(); deleteSocketPath(); diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/UDSMcpSyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/UDSMcpSyncClientTests.java index 3ce8be05c..6d2fc8b59 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/UDSMcpSyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/UDSMcpSyncClientTests.java @@ -33,11 +33,12 @@ class UDSMcpSyncClientTests extends AbstractMcpSyncClientTests { private void deleteSocketPath() { try { Files.deleteIfExists(socketPath); - } catch (IOException e) { + } + catch (IOException e) { throw new RuntimeException(e); } } - + protected void onStart() { super.onStart(); deleteSocketPath(); diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/TestEverythingServer.java b/mcp/src/test/java/io/modelcontextprotocol/server/TestEverythingServer.java index 3b0ca8fca..6d5fcf9b3 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/TestEverythingServer.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/TestEverythingServer.java @@ -137,8 +137,7 @@ else if (messageType.equals("debug")) { .resources(specs) .build(); } - - + public void closeGracefully() { if (this.server != null) { this.server.closeGracefully(); diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpAsyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpAsyncServerTests.java index 505798502..8d7931a65 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpAsyncServerTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpAsyncServerTests.java @@ -29,16 +29,17 @@ class UDSMcpAsyncServerTests extends AbstractMcpAsyncServerTests { private void deleteSocketPath() { try { Files.deleteIfExists(socketPath); - } catch (IOException e) { + } + catch (IOException e) { throw new RuntimeException(e); } } + protected void onStart() { super.onStart(); deleteSocketPath(); } - @Override protected void onClose() { super.onClose(); diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpSyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpSyncServerTests.java index cacb7835a..795b1b2e7 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpSyncServerTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpSyncServerTests.java @@ -29,16 +29,17 @@ class UDSMcpSyncServerTests extends AbstractMcpSyncServerTests { private void deleteSocketPath() { try { Files.deleteIfExists(socketPath); - } catch (IOException e) { + } + catch (IOException e) { throw new RuntimeException(e); } } + protected void onStart() { super.onStart(); deleteSocketPath(); } - @Override protected void onClose() { super.onClose(); From 03704246866fd5a59328834e5923f90eb0be6226 Mon Sep 17 00:00:00 2001 From: Scott Lewis Date: Wed, 13 Aug 2025 18:41:06 -0700 Subject: [PATCH 19/25] Refactoring for simplification and reliability. Also fixed failing tests in UDSMcpAsync/Sync Server/Client Tests. Now the only failing tests are those that are expecting specific behavior of javascript everything server for tests. io.modelcontextprotocol.client.AbstractMcpSyncClientTests.testCallTool() fails because of this assertion assertThat(result.isError()).isNull(); This is asserting that the isError is null...which it is in the successful case for the javascript server. The java TestEverythingServer, however the isError is false (not null). I believe that false is correct (the java builder will not allow me to set isError to null). io.modelcontextprotocol.client.AbstractMcpAsyncClientTests.testCallToolWithInvalidTool() line 247 has e -> assertThat(e).isInstanceOf(McpError.class).hasMessage("Unknown tool: nonexistent_tool")) And the assertion fails because the java TestEverythingServer (included in this pr) provides the message "Tool not found: nonexistent_tool" io.modelcontextprotocol.client.AbstractMcpAsyncClientTests.testCallTool() fails on line 232 with the same assertion error that io.modelcontextprotocol.client.AbstractMcpSyncClientTests.testCallTool() expecting isError to be null rather than false error. Signed-off-by: Scott Lewis --- .../transport/UDSClientTransportProvider.java | 261 +++++++++++++ .../transport/UDSServerTransportProvider.java | 258 +++++++++++++ .../util/AbstractSocketChannel.java | 355 ++++++++++++++++++ .../util/ClientSocketChannel.java | 100 +++++ .../util/ServSocketChannel.java | 92 +++++ .../util/UDSClientSocketChannel.java | 33 ++ .../util/UDSServerSocketChannel.java | 33 ++ .../client/UDSMcpAsyncClientTests.java | 64 ++++ .../client/UDSMcpSyncClientTests.java | 69 ++++ .../server/TestEverythingServer.java | 148 ++++++++ .../server/UDSMcpAsyncServerTests.java | 58 +++ .../server/UDSMcpSyncServerTests.java | 58 +++ 12 files changed, 1529 insertions(+) create mode 100644 mcp/src/main/java/io/modelcontextprotocol/client/transport/UDSClientTransportProvider.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/server/transport/UDSServerTransportProvider.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/util/AbstractSocketChannel.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/util/ClientSocketChannel.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/util/ServSocketChannel.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/util/UDSClientSocketChannel.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/util/UDSServerSocketChannel.java create mode 100644 mcp/src/test/java/io/modelcontextprotocol/client/UDSMcpAsyncClientTests.java create mode 100644 mcp/src/test/java/io/modelcontextprotocol/client/UDSMcpSyncClientTests.java create mode 100644 mcp/src/test/java/io/modelcontextprotocol/server/TestEverythingServer.java create mode 100644 mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpAsyncServerTests.java create mode 100644 mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpSyncServerTests.java diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/UDSClientTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/UDSClientTransportProvider.java new file mode 100644 index 000000000..26bbc7bd2 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/UDSClientTransportProvider.java @@ -0,0 +1,261 @@ +package io.modelcontextprotocol.client.transport; + +import java.io.IOException; +import java.net.UnixDomainSocketAddress; +import java.nio.channels.SelectionKey; +import java.time.Duration; +import java.util.concurrent.Executors; +import java.util.function.Consumer; +import java.util.function.Function; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; + +import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage; +import io.modelcontextprotocol.util.Assert; +import io.modelcontextprotocol.util.UDSClientSocketChannel; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; +import reactor.core.scheduler.Scheduler; +import reactor.core.scheduler.Schedulers; + +public class UDSClientTransportProvider implements McpClientTransport { + + private static final Logger logger = LoggerFactory.getLogger(StdioClientTransport.class); + + private final Sinks.Many inboundSink; + + private final Sinks.Many outboundSink; + + private ObjectMapper objectMapper; + + /** Scheduler for handling outbound messages to the server process */ + private Scheduler outboundScheduler; + + private final Sinks.Many errorSink; + + private UDSClientSocketChannel clientChannel; + + private UnixDomainSocketAddress targetAddress; + + private volatile boolean isClosing = false; + + // visible for tests + private Consumer stdErrorHandler = error -> logger.info("STDERR Message received: {}", error); + + public UDSClientTransportProvider(UnixDomainSocketAddress targetAddress) { + this(new ObjectMapper(), targetAddress); + } + + public UDSClientTransportProvider(ObjectMapper objectMapper, UnixDomainSocketAddress targetAddress) { + Assert.notNull(objectMapper, "objectMapper cannot be null"); + this.objectMapper = objectMapper; + Assert.notNull(objectMapper, "targetAddress cannot be null"); + this.targetAddress = targetAddress; + this.inboundSink = Sinks.many().unicast().onBackpressureBuffer(); + this.outboundSink = Sinks.many().unicast().onBackpressureBuffer(); + this.errorSink = Sinks.many().unicast().onBackpressureBuffer(); + try { + this.clientChannel = new UDSClientSocketChannel() { + @Override + protected void handleException(SelectionKey key, Exception e) { + isClosing = true; + super.handleException(key, e); + } + }; + } + catch (IOException e) { + throw new RuntimeException(e); + } + this.outboundScheduler = Schedulers.fromExecutorService(Executors.newSingleThreadExecutor(), "outbound"); + } + + /** + * Creates a new StdioClientTransport with the specified parameters and ObjectMapper. + * @param params The parameters for configuring the server process + * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization + */ + public UDSClientTransportProvider(ServerParameters params, ObjectMapper objectMapper) { + Assert.notNull(params, "The params can not be null"); + Assert.notNull(objectMapper, "The ObjectMapper can not be null"); + + this.inboundSink = Sinks.many().unicast().onBackpressureBuffer(); + this.outboundSink = Sinks.many().unicast().onBackpressureBuffer(); + + this.objectMapper = objectMapper; + + this.errorSink = Sinks.many().unicast().onBackpressureBuffer(); + + // Start thread for outbound + this.outboundScheduler = Schedulers.fromExecutorService(Executors.newSingleThreadExecutor(), "outbound"); + } + + /** + * Starts the server process and initializes the message processing streams. This + * method sets up the process with the configured command, arguments, and environment, + * then starts the inbound, outbound, and error processing threads. + * @throws RuntimeException if the process fails to start or if the process streams + * are null + */ + @Override + public Mono connect(Function, Mono> handler) { + return Mono.fromRunnable(() -> { + handleIncomingMessages(handler); + handleIncomingErrors(); + + // Connect client channel + try { + this.clientChannel.connect(targetAddress, (client) -> { + if (logger.isInfoEnabled()) { + logger.info("UDSClientTransportProvider CONNECTED to targetAddress=" + targetAddress); + } + }, (message) -> { + if (logger.isDebugEnabled()) { + logger.debug("received message=" + message); + } + // Incoming messages processed right here + McpSchema.JSONRPCMessage jsonMessage = McpSchema.deserializeJsonRpcMessage(objectMapper, message); + if (!this.inboundSink.tryEmitNext(jsonMessage).isSuccess()) { + if (!isClosing) { + if (logger.isDebugEnabled()) { + logger.error("Failed to enqueue inbound json message: {}", jsonMessage); + } + } + } + }); + } + catch (IOException e) { + this.clientChannel.close(); + throw new RuntimeException( + "Connect to address=" + targetAddress + " failed message: " + e.getMessage()); + } + + startOutboundProcessing(); + + }).subscribeOn(Schedulers.boundedElastic()); + } + + /** + * Sets the handler for processing transport-level errors. + * + *

+ * The provided handler will be called when errors occur during transport operations, + * such as connection failures or protocol violations. + *

+ * @param errorHandler a consumer that processes error messages + */ + public void setStdErrorHandler(Consumer errorHandler) { + this.stdErrorHandler = errorHandler; + } + + private void handleIncomingMessages(Function, Mono> inboundMessageHandler) { + this.inboundSink.asFlux() + .flatMap(message -> Mono.just(message) + .transform(inboundMessageHandler) + .contextWrite(ctx -> ctx.put("observation", "myObservation"))) + .subscribe(); + } + + private void handleIncomingErrors() { + this.errorSink.asFlux().subscribe(e -> { + this.stdErrorHandler.accept(e); + }); + } + + @Override + public Mono sendMessage(JSONRPCMessage message) { + outboundSink.emitNext(message, (signalType, emitResult) -> { + // Allow retry + return true; + }); + return Mono.empty(); + } + + /** + * Starts the outbound processing thread that writes JSON-RPC messages to the + * process's output stream. Messages are serialized to JSON and written with a newline + * delimiter. + */ + private void startOutboundProcessing() { + this.handleOutbound(messages -> messages + // this bit is important since writes come from user threads, and we + // want to ensure that the actual writing happens on a dedicated thread + .publishOn(outboundScheduler) + .handle((message, sink) -> { + if (message != null && !isClosing) { + try { + clientChannel.writeMessage(objectMapper.writeValueAsString(message)); + sink.next(message); + } + catch (IOException e) { + if (!isClosing) { + logger.error("Error writing message", e); + sink.error(new RuntimeException(e)); + } + else { + logger.debug("Stream closed during shutdown", e); + } + } + } + })); + } + + protected void handleOutbound(Function, Flux> outboundConsumer) { + outboundConsumer.apply(outboundSink.asFlux()).doOnComplete(() -> { + isClosing = true; + outboundSink.tryEmitComplete(); + }).doOnError(e -> { + if (!isClosing) { + logger.error("Error in outbound processing", e); + isClosing = true; + outboundSink.tryEmitComplete(); + } + }).subscribe(); + } + + /** + * Gracefully closes the transport by destroying the process and disposing of the + * schedulers. This method sends a TERM signal to the process and waits for it to exit + * before cleaning up resources. + * @return A Mono that completes when the transport is closed + */ + @Override + public Mono closeGracefully() { + return Mono.fromRunnable(() -> { + isClosing = true; + logger.debug("Initiating graceful shutdown"); + }).then(Mono.defer(() -> { + // First complete all sinks to stop accepting new messages + inboundSink.tryEmitComplete(); + outboundSink.tryEmitComplete(); + errorSink.tryEmitComplete(); + + // Give a short time for any pending messages to be processed + return Mono.delay(Duration.ofMillis(100)).then(); + })).then(Mono.fromRunnable(() -> { + try { + outboundScheduler.dispose(); + logger.debug("Graceful shutdown completed"); + } + catch (Exception e) { + logger.error("Error during graceful shutdown", e); + } + })).then().subscribeOn(Schedulers.boundedElastic()); + } + + public Sinks.Many getErrorSink() { + return this.errorSink; + } + + @Override + public T unmarshalFrom(Object data, TypeReference typeRef) { + return this.objectMapper.convertValue(data, typeRef); + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/UDSServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/UDSServerTransportProvider.java new file mode 100644 index 000000000..8911f2797 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/UDSServerTransportProvider.java @@ -0,0 +1,258 @@ +package io.modelcontextprotocol.server.transport; + +import java.io.IOException; +import java.net.UnixDomainSocketAddress; +import java.nio.channels.SelectionKey; +import java.util.List; +import java.util.concurrent.Executors; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Function; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; + +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage; +import io.modelcontextprotocol.spec.McpServerSession; +import io.modelcontextprotocol.spec.McpServerTransport; +import io.modelcontextprotocol.spec.McpServerTransportProvider; +import io.modelcontextprotocol.spec.ProtocolVersions; +import io.modelcontextprotocol.util.Assert; +import io.modelcontextprotocol.util.UDSServerSocketChannel; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; +import reactor.core.scheduler.Scheduler; +import reactor.core.scheduler.Schedulers; + +public class UDSServerTransportProvider implements McpServerTransportProvider { + + private static final Logger logger = LoggerFactory.getLogger(StdioServerTransportProvider.class); + + private final ObjectMapper objectMapper; + + private UDSMcpSessionTransport transport; + + private McpServerSession session; + + private final AtomicBoolean isClosing = new AtomicBoolean(false); + + private final Sinks.One inboundReady = Sinks.one(); + + private final Sinks.One outboundReady = Sinks.one(); + + private UnixDomainSocketAddress targetAddress; + + /** + * Creates a new UDSServerTransportProvider with a default ObjectMapper + * @param unixSocketAddress the UDS socket address to bind to. Must not be null. + */ + public UDSServerTransportProvider(UnixDomainSocketAddress unixSocketAddress) { + this(new ObjectMapper(), unixSocketAddress); + } + + /** + * Creates a new UDSServerTransportProvider with the specified ObjectMapper + * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization + */ + public UDSServerTransportProvider(ObjectMapper objectMapper, UnixDomainSocketAddress unixSocketAddress) { + Assert.notNull(objectMapper, "objectMapper cannot be null"); + this.objectMapper = objectMapper; + Assert.notNull(unixSocketAddress, "unixSocketAddress cannot be null"); + this.targetAddress = unixSocketAddress; + } + + @Override + public List protocolVersions() { + return List.of(ProtocolVersions.MCP_2024_11_05); + } + + @Override + public void setSessionFactory(McpServerSession.Factory sessionFactory) { + this.transport = new UDSMcpSessionTransport(); + this.session = sessionFactory.create(transport); + this.transport.initProcessing(); + } + + @Override + public Mono notifyClients(String method, Object params) { + if (this.session == null) { + return Mono.error(new McpError("No session to close")); + } + return this.session.sendNotification(method, params) + .doOnError(e -> logger.error("Failed to send notification: {}", e.getMessage())); + } + + @Override + public Mono closeGracefully() { + if (this.session == null) { + return Mono.empty(); + } + return this.session.closeGracefully(); + } + + /** + * Implementation of McpServerTransport for the uds session. + */ + private class UDSMcpSessionTransport implements McpServerTransport { + + private final Sinks.Many inboundSink; + + private final Sinks.Many outboundSink; + + /** Scheduler for handling outbound messages */ + private Scheduler outboundScheduler; + + private final AtomicBoolean isStarted = new AtomicBoolean(false); + + private final UDSServerSocketChannel serverSocketChannel; + + public UDSMcpSessionTransport() { + this.inboundSink = Sinks.many().unicast().onBackpressureBuffer(); + this.outboundSink = Sinks.many().unicast().onBackpressureBuffer(); + this.outboundScheduler = Schedulers.fromExecutorService(Executors.newSingleThreadExecutor(), + "uds-outbound"); + try { + this.serverSocketChannel = new UDSServerSocketChannel() { + @Override + protected void handleException(SelectionKey key, Exception e) { + isClosing.set(true); + if (session != null) { + session.close(); + session = null; + } + inboundSink.tryEmitComplete(); + } + }; + } + catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Override + public Mono sendMessage(McpSchema.JSONRPCMessage message) { + return Mono.zip(inboundReady.asMono(), outboundReady.asMono()).then(Mono.defer(() -> { + outboundSink.emitNext(message, (signalType, emitResult) -> { + // Allow retry + return true; + }); + return Mono.empty(); + })); + } + + @Override + public T unmarshalFrom(Object data, TypeReference typeRef) { + return objectMapper.convertValue(data, typeRef); + } + + @Override + public Mono closeGracefully() { + return Mono.fromRunnable(() -> { + isClosing.set(true); + logger.debug("Session transport closing gracefully"); + inboundSink.tryEmitComplete(); + }); + } + + @Override + public void close() { + isClosing.set(true); + logger.debug("Session transport closed"); + } + + private void initProcessing() { + handleIncomingMessages(); + startInboundProcessing(); + startOutboundProcessing(); + + inboundReady.tryEmitValue(null); + outboundReady.tryEmitValue(null); + } + + private void handleIncomingMessages() { + this.inboundSink.asFlux().flatMap(message -> session.handle(message)).doOnTerminate(() -> { + this.outboundSink.tryEmitComplete(); + }).subscribe(); + } + + /** + * Starts the inbound processing thread that reads JSON-RPC messages from stdin. + * Messages are deserialized and passed to the session for handling. + */ + private void startInboundProcessing() { + if (isStarted.compareAndSet(false, true)) { + try { + this.serverSocketChannel.start(targetAddress, (clientChannel) -> { + if (logger.isDebugEnabled()) { + logger.debug("Accepted connect from clientChannel=" + clientChannel); + } + }, (message) -> { + if (logger.isDebugEnabled()) { + logger.debug("Received message=" + message); + } + // Incoming messages processed right here + McpSchema.JSONRPCMessage jsonMessage = McpSchema.deserializeJsonRpcMessage(objectMapper, + message); + if (!this.inboundSink.tryEmitNext(jsonMessage).isSuccess()) { + throw new IOException("Error adding jsonMessge to inboundSink"); + } + }); + } + catch (IOException e) { + throw new RuntimeException(e); + } + } + } + + /** + * Starts the outbound processing thread that writes JSON-RPC messages to stdout. + * Messages are serialized to JSON and written with a newline delimiter. + */ + private void startOutboundProcessing() { + Function, Flux> outboundConsumer = messages -> messages // @formatter:off + .doOnSubscribe(subscription -> outboundReady.tryEmitValue(null)) + .publishOn(outboundScheduler) + .handle((message, sink) -> { + if (message != null && !isClosing.get()) { + try { + serverSocketChannel.writeMessage(objectMapper.writeValueAsString(message)); + sink.next(message); + } + catch (IOException e) { + if (!isClosing.get()) { + logger.error("Error writing message", e); + sink.error(new RuntimeException(e)); + } + else { + logger.debug("Stream closed during shutdown", e); + } + } + } + else if (isClosing.get()) { + sink.complete(); + } + }) + .doOnComplete(() -> { + isClosing.set(true); + outboundScheduler.dispose(); + }) + .doOnError(e -> { + if (!isClosing.get()) { + logger.error("Error in outbound processing", e); + isClosing.set(true); + outboundScheduler.dispose(); + } + }) + .map(msg -> (JSONRPCMessage) msg); + + outboundConsumer.apply(outboundSink.asFlux()).subscribe(); + } // @formatter:on + + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/AbstractSocketChannel.java b/mcp/src/main/java/io/modelcontextprotocol/util/AbstractSocketChannel.java new file mode 100644 index 000000000..c6736c378 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/util/AbstractSocketChannel.java @@ -0,0 +1,355 @@ +package io.modelcontextprotocol.util; + +import java.io.IOException; +import java.io.InterruptedIOException; +import java.net.SocketAddress; +import java.nio.ByteBuffer; +import java.nio.channels.SelectionKey; +import java.nio.channels.Selector; +import java.nio.channels.ServerSocketChannel; +import java.nio.channels.SocketChannel; +import java.nio.charset.StandardCharsets; +import java.util.Iterator; +import java.util.Objects; +import java.util.Set; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public abstract class AbstractSocketChannel { + + private static final Logger logger = LoggerFactory.getLogger(AbstractSocketChannel.class); + + public static final int DEFAULT_INBUFFER_SIZE = 1024; + + public static String DEFAULT_MESSAGE_DELIMITER = "\n"; + + protected String messageDelimiter = DEFAULT_MESSAGE_DELIMITER; + + protected void setMessageDelimiter(String delim) { + this.messageDelimiter = delim; + } + + public static int DEFAULT_WRITE_TIMEOUT = 5000; // ms + + protected int writeTimeout = DEFAULT_WRITE_TIMEOUT; + + protected void setWriteTimeout(int timeout) { + this.writeTimeout = timeout; + } + + public static int DEFAULT_CONNECT_TIMEOUT = 10000; // ms + + protected int connectTimeout = DEFAULT_CONNECT_TIMEOUT; + + protected void setConnectTimeout(int timeout) { + this.connectTimeout = timeout; + } + + public static int DEFAULT_TERMINATION_TIMEOUT = 2000; // ms + + protected int terminationTimeout = DEFAULT_TERMINATION_TIMEOUT; + + protected void setTerminationTimeout(int timeout) { + this.terminationTimeout = timeout; + } + + protected final Selector selector; + + protected final ByteBuffer inBuffer; + + protected final ExecutorService executor; + + private final Object writeLock = new Object(); + + @FunctionalInterface + public interface IOConsumer { + + void apply(T t) throws IOException; + + } + + protected class AttachedIO { + + public ByteBuffer writing; + + public StringBuffer reading; + + } + + public AbstractSocketChannel(Selector selector, int incomingBufferSize, ExecutorService executor) { + Assert.notNull(selector, "Selector must not be null"); + this.selector = selector; + this.inBuffer = ByteBuffer.allocate(incomingBufferSize); + this.executor = (executor == null) ? Executors.newSingleThreadExecutor() : executor; + } + + public AbstractSocketChannel(Selector selector, int incomingBufferSize) { + this(selector, incomingBufferSize, null); + } + + public AbstractSocketChannel(Selector selector) { + this(selector, DEFAULT_INBUFFER_SIZE); + } + + public AbstractSocketChannel() throws IOException { + this(Selector.open()); + } + + protected Runnable getRunnableForProcessing(IOConsumer acceptHandler, + IOConsumer connectHandler, IOConsumer readHandler) { + return () -> { + SelectionKey key = null; + try { + while (true) { + int count = this.selector.select(); + debug("Select returned count=%s", count); + Set selectedKeys = selector.selectedKeys(); + Iterator iter = selectedKeys.iterator(); + while (iter.hasNext()) { + key = iter.next(); + if (key.isConnectable()) { + handleConnectable(key, connectHandler); + } + else if (key.isAcceptable()) { + handleAcceptable(key, acceptHandler); + } + else if (key.isReadable()) { + handleReadable(key, readHandler); + } + else if (key.isWritable()) { + handleWritable(key); + } + iter.remove(); + } + } + } + catch (Exception e) { + handleException(key, e); + } + }; + } + + public abstract void close(); + + protected abstract void handleException(SelectionKey key, Exception e); + + protected void start(IOConsumer acceptHandler, IOConsumer connectHandler, + IOConsumer readHandler) throws IOException { + this.executor.execute(getRunnableForProcessing(acceptHandler, connectHandler, readHandler)); + } + + protected void debug(String format, Object... o) { + if (logger.isDebugEnabled()) { + logger.debug(format, o); + } + } + + // For client subclasses + protected void handleConnectable(SelectionKey key, IOConsumer connectHandler) throws IOException { + SocketChannel client = (SocketChannel) key.channel(); + debug("client=%s", client); + client.configureBlocking(false); + client.register(this.selector, SelectionKey.OP_READ, new AttachedIO()); + if (client.isConnectionPending()) { + client.finishConnect(); + debug("connected client=%s", client); + } + if (connectHandler != null) { + connectHandler.apply(client); + } + } + + protected void handleAcceptable(SelectionKey key, IOConsumer acceptHandler) throws IOException { + ServerSocketChannel serverSocket = (ServerSocketChannel) key.channel(); + SocketChannel client = serverSocket.accept(); + debug("client=%s", client); + client.configureBlocking(false); + client.register(this.selector, SelectionKey.OP_READ, new AttachedIO()); + configureAcceptSocketChannel(client); + if (client.isConnectionPending()) { + client.finishConnect(); + debug("accepted client=%s", client); + } + if (acceptHandler != null) { + acceptHandler.apply(client); + } + } + + protected void configureAcceptSocketChannel(SocketChannel client) throws IOException { + // Subclasses may override + } + + protected AttachedIO getAttachedIO(SelectionKey key) throws IOException { + AttachedIO io = (AttachedIO) key.attachment(); + if (io == null) { + throw new IOException("No AttachedIO object found on key"); + } + return io; + } + + protected void handleReadable(SelectionKey key, IOConsumer readHandler) throws IOException { + SocketChannel client = (SocketChannel) key.channel(); + AttachedIO io = getAttachedIO(key); + debug("read client=%s", client); + // read + int r = client.read(this.inBuffer); + // Check if we should expect any more reads + if (r == -1) { + throw new IOException("Channel read reached end of stream"); + } + this.inBuffer.flip(); + String partial = new String(this.inBuffer.array(), 0, r, StandardCharsets.UTF_8); + // If there is previous partial, get the io.reading string Buffer + StringBuffer sb = (io.reading != null) ? (StringBuffer) io.reading : new StringBuffer(); + // append the just read partial to the existing or new string buffer + sb.append(partial); + if (partial.endsWith(messageDelimiter)) { + // Get the entire message from the string buffer + String message = sb.toString(); + // Set the io.reading value to null as we are done with this message + io.reading = null; + debug("read client=%s msg=", client, message); + if (readHandler != null) { + String[] messages = splitMessage(message); + for (int i = 0; i < messages.length; i++) { + readHandler.apply(messages[i]); + } + } + } + else { + io.reading = sb; + debug("read partial=%s", partial); + } + // Clear inbuffer for next read + this.inBuffer.clear(); + } + + protected void handleWritable(SelectionKey key) throws IOException { + ByteBuffer buf = getAttachedIO(key).writing; + SocketChannel client = (SocketChannel) key.channel(); + if (buf != null) { + doWrite(key, client, buf, (o) -> { + synchronized (writeLock) { + writeLock.notifyAll(); + } + }); + } + } + + protected void doWrite(SocketChannel client, String message, IOConsumer writeHandler) throws IOException { + Assert.notNull(client, "Client must not be null"); + Assert.notNull(message, "Message must not be null"); + doWrite(client.keyFor(this.selector), client, ByteBuffer.wrap(message.getBytes(StandardCharsets.UTF_8)), + writeHandler); + } + + protected void doWrite(SelectionKey key, SocketChannel client, ByteBuffer buf, IOConsumer writeHandler) + throws IOException { + AttachedIO io = (AttachedIO) key.attachment(); + synchronized (writeLock) { + int written = client.write(buf); + if (buf.hasRemaining()) { + debug("doWrite written=%s, remaining=%s", written, buf.remaining()); + io.writing = buf.slice(); + key.interestOpsOr(SelectionKey.OP_WRITE); + } + else { + if (logger.isDebugEnabled()) { + logger.debug("doWrite message=%s", new String(buf.array(), 0, written)); + } + io.writing = null; + key.interestOps(SelectionKey.OP_READ); + if (writeHandler != null) { + writeHandler.apply(null); + } + } + } + } + + protected void executorShutdown() { + if (!this.executor.isShutdown()) { + debug("shutdown"); + try { + this.executor.awaitTermination(this.terminationTimeout, TimeUnit.MILLISECONDS); + this.executor.shutdown(); + } + catch (InterruptedException e) { + if (logger.isDebugEnabled()) { + logger.debug("Exception in executor awaitTermination", e); + } + } + } + } + + protected void hardCloseClient(SocketChannel client, IOConsumer closeHandler) { + if (client != null) { + debug("hardClose client=%s", client); + synchronized (writeLock) { + try { + if (closeHandler != null) { + closeHandler.apply(client); + } + client.close(); + } + catch (IOException e) { + if (logger.isDebugEnabled()) { + logger.debug("hardClose client socketchannel.close exception", e); + } + } + } + executorShutdown(); + } + } + + protected void writeMessageToChannel(SocketChannel client, String message) throws IOException { + Objects.requireNonNull(client, "Client must not be null"); + Objects.requireNonNull(message, "Message must not be null"); + // Escape any embedded newlines in the JSON message + String outputMessage = message.replace("\r\n", "\\n") + .replace("\n", "\\n") + .replace("\r", "\\n") + // add message delimiter + .concat(DEFAULT_MESSAGE_DELIMITER); + debug("writing msg=%s", outputMessage); + synchronized (writeLock) { + // do the non blocking write in thread while holding lock. + doWrite(client, outputMessage, null); + ByteBuffer bufRemaining = null; + long waitTime = System.currentTimeMillis() + this.writeTimeout; + while (waitTime - System.currentTimeMillis() > 0) { + // Before releasing lock, check for writing buffer remaining + bufRemaining = getAttachedIO(client.keyFor(this.selector)).writing; + if (bufRemaining == null || bufRemaining.remaining() == 0) { + // It's done + break; + } + // If write is *not* completed, then wait timeout /10 + try { + debug("writeBlocking WAITING(ms)=%s msg=%s", String.valueOf(waitTime / 10), outputMessage); + writeLock.wait(waitTime / 10); + } + catch (InterruptedException e) { + throw new InterruptedIOException("write message wait interrupted"); + } + } + if (bufRemaining != null && bufRemaining.remaining() > 0) { + throw new IOException("Write not completed. Non empty buffer remaining after timeout"); + } + } + debug("writing done msg=%s", outputMessage); + } + + protected void configureConnectSocketChannel(SocketChannel client, SocketAddress connectAddress) + throws IOException { + // Subclasses may override + } + + protected String[] splitMessage(String message) { + return (message == null) ? new String[0] : message.split(messageDelimiter); + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/ClientSocketChannel.java b/mcp/src/main/java/io/modelcontextprotocol/util/ClientSocketChannel.java new file mode 100644 index 000000000..4ec69672c --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/util/ClientSocketChannel.java @@ -0,0 +1,100 @@ +package io.modelcontextprotocol.util; + +import java.io.IOException; +import java.net.SocketAddress; +import java.net.StandardProtocolFamily; +import java.nio.channels.SelectionKey; +import java.nio.channels.Selector; +import java.nio.channels.SocketChannel; +import java.util.concurrent.ExecutorService; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class ClientSocketChannel extends AbstractSocketChannel { + + private static final Logger logger = LoggerFactory.getLogger(ClientSocketChannel.class); + + protected SocketChannel client; + + protected final Object connectLock = new Object(); + + public ClientSocketChannel(Selector selector, int incomingBufferSize, ExecutorService executor) { + super(selector, incomingBufferSize, executor); + } + + public ClientSocketChannel() throws IOException { + super(); + } + + public ClientSocketChannel(Selector selector, int incomingBufferSize) { + super(selector, incomingBufferSize); + } + + public ClientSocketChannel(Selector selector) { + super(selector); + } + + protected SocketChannel doConnect(SocketChannel client, SocketAddress address, + IOConsumer connectHandler, IOConsumer readHandler) throws IOException { + debug("connect targetAddress=%s", address); + client.configureBlocking(false); + client.register(selector, SelectionKey.OP_CONNECT); + configureConnectSocketChannel(client, address); + // Start the read thread before connect + // No/null accept handler for clients + start(null, (c) -> { + synchronized (connectLock) { + if (connectHandler != null) { + connectHandler.apply(c); + } + connectLock.notifyAll(); + } + }, readHandler); + + client.connect(address); + try { + debug("connect targetAddress=%s", address); + synchronized (connectLock) { + connectLock.wait(this.connectTimeout); + } + } + catch (InterruptedException e) { + throw new IOException( + "Connect to address=" + address + " timed out after " + String.valueOf(this.connectTimeout) + "ms"); + } + debug("connected client=%s", client); + return client; + } + + public void connect(StandardProtocolFamily protocol, SocketAddress address, + IOConsumer connectHandler, IOConsumer readHandler) throws IOException { + if (this.client != null) { + throw new IOException("Already connected"); + } + this.client = doConnect(SocketChannel.open(protocol), address, connectHandler, readHandler); + } + + @Override + protected void handleException(SelectionKey key, Exception e) { + if (logger.isDebugEnabled()) { + logger.debug("handleException", e); + } + close(); + } + + @Override + public void close() { + hardCloseClient(this.client, (client) -> { + this.client = null; + }); + } + + public void writeMessage(String message) throws IOException { + if (this.client == null) { + throw new IOException("Cannot write until client connected"); + } + writeMessageToChannel(client, message); + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/ServSocketChannel.java b/mcp/src/main/java/io/modelcontextprotocol/util/ServSocketChannel.java new file mode 100644 index 000000000..78b18d9fe --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/util/ServSocketChannel.java @@ -0,0 +1,92 @@ +package io.modelcontextprotocol.util; + +import java.io.IOException; +import java.net.SocketAddress; +import java.net.StandardProtocolFamily; +import java.nio.channels.SelectionKey; +import java.nio.channels.Selector; +import java.nio.channels.ServerSocketChannel; +import java.nio.channels.SocketChannel; +import java.util.concurrent.ExecutorService; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class ServSocketChannel extends AbstractSocketChannel { + + private static final Logger logger = LoggerFactory.getLogger(ServSocketChannel.class); + + protected SocketChannel acceptedClient; + + public ServSocketChannel() throws IOException { + super(); + } + + public ServSocketChannel(Selector selector, int incomingBufferSize, ExecutorService executor) { + super(selector, incomingBufferSize, executor); + } + + public ServSocketChannel(Selector selector, int incomingBufferSize) { + super(selector, incomingBufferSize); + } + + public ServSocketChannel(Selector selector) { + super(selector); + } + + protected void configureServerSocketChannel(ServerSocketChannel serverSocketChannel, SocketAddress acceptAddress) { + // Subclasses may override + } + + public void start(StandardProtocolFamily protocol, SocketAddress address, IOConsumer acceptHandler, + IOConsumer readHandler) throws IOException { + ServerSocketChannel serverChannel = ServerSocketChannel.open(protocol); + serverChannel.configureBlocking(false); + serverChannel.register(this.selector, SelectionKey.OP_ACCEPT); + configureServerSocketChannel(serverChannel, address); + serverChannel.bind(address); + // Start thread/processing of incoming accept, read + super.start((client) -> { + if (logger.isDebugEnabled()) { + logger.debug("Setting client=" + client); + } + this.acceptedClient = client; + if (acceptHandler != null) { + acceptHandler.apply(this.acceptedClient); + } + // No/null connect handler for Acceptors...only accepthandler + }, null, readHandler); + } + + @Override + protected void handleException(SelectionKey key, Exception e) { + if (logger.isDebugEnabled()) { + logger.debug("handleException", e); + } + close(); + } + + public void writeMessage(String message) throws IOException { + SocketChannel c = this.acceptedClient; + if (c != null) { + writeMessageToChannel(c, message); + } + else { + throw new IOException("not connected"); + } + } + + @Override + public void close() { + SocketChannel client = this.acceptedClient; + if (client != null) { + hardCloseClient(client, (c) -> { + if (logger.isDebugEnabled()) { + logger.debug("Unsetting client=" + c); + } + this.acceptedClient = null; + }); + } + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/UDSClientSocketChannel.java b/mcp/src/main/java/io/modelcontextprotocol/util/UDSClientSocketChannel.java new file mode 100644 index 000000000..93539c852 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/util/UDSClientSocketChannel.java @@ -0,0 +1,33 @@ +package io.modelcontextprotocol.util; + +import java.io.IOException; +import java.net.StandardProtocolFamily; +import java.net.UnixDomainSocketAddress; +import java.nio.channels.Selector; +import java.nio.channels.SocketChannel; +import java.util.concurrent.ExecutorService; + +public class UDSClientSocketChannel extends ClientSocketChannel { + + public UDSClientSocketChannel(Selector selector, int incomingBufferSize, ExecutorService executor) { + super(selector, incomingBufferSize, executor); + } + + public UDSClientSocketChannel() throws IOException { + super(); + } + + public UDSClientSocketChannel(Selector selector, int incomingBufferSize) { + super(selector, incomingBufferSize); + } + + public UDSClientSocketChannel(Selector selector) { + super(selector); + } + + public void connect(UnixDomainSocketAddress address, IOConsumer connectHandler, + IOConsumer readHandler) throws IOException { + super.connect(StandardProtocolFamily.UNIX, address, connectHandler, readHandler); + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/UDSServerSocketChannel.java b/mcp/src/main/java/io/modelcontextprotocol/util/UDSServerSocketChannel.java new file mode 100644 index 000000000..a6607cf17 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/util/UDSServerSocketChannel.java @@ -0,0 +1,33 @@ +package io.modelcontextprotocol.util; + +import java.io.IOException; +import java.net.StandardProtocolFamily; +import java.net.UnixDomainSocketAddress; +import java.nio.channels.Selector; +import java.nio.channels.SocketChannel; +import java.util.concurrent.ExecutorService; + +public class UDSServerSocketChannel extends ServSocketChannel { + + public UDSServerSocketChannel() throws IOException { + super(); + } + + public UDSServerSocketChannel(Selector selector, int incomingBufferSize, ExecutorService executor) { + super(selector, incomingBufferSize, executor); + } + + public UDSServerSocketChannel(Selector selector, int incomingBufferSize) { + super(selector, incomingBufferSize); + } + + public UDSServerSocketChannel(Selector selector) { + super(selector); + } + + public void start(UnixDomainSocketAddress address, IOConsumer acceptHandler, + IOConsumer readHandler) throws IOException { + super.start(StandardProtocolFamily.UNIX, address, acceptHandler, readHandler); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/UDSMcpAsyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/UDSMcpAsyncClientTests.java new file mode 100644 index 000000000..ad43cd8c9 --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/client/UDSMcpAsyncClientTests.java @@ -0,0 +1,64 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.client; + +import java.io.IOException; +import java.net.UnixDomainSocketAddress; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; + +import org.junit.jupiter.api.Timeout; + +import io.modelcontextprotocol.client.transport.UDSClientTransportProvider; +import io.modelcontextprotocol.server.TestEverythingServer; +import io.modelcontextprotocol.server.transport.UDSServerTransportProvider; +import io.modelcontextprotocol.spec.McpClientTransport; + +/** + * Tests for the {@link McpAyncClient} with {@link UDSClientTransport}. + * + * @author Christian Tzolov + * @author Dariusz Jędrzejczyk + * @author Scott Lewis + */ +@Timeout(150) // Giving extra time beyond the client timeout +class UDSMcpAsyncClientTests extends AbstractMcpAsyncClientTests { + + private Path socketPath = Paths.get(getClass().getName() + ".unix.socket"); + + private void deleteSocketPath() { + try { + Files.deleteIfExists(socketPath); + } + catch (IOException e) { + throw new RuntimeException(e); + } + } + + protected void onStart() { + super.onStart(); + deleteSocketPath(); + this.server = new TestEverythingServer(new UDSServerTransportProvider(UnixDomainSocketAddress.of(socketPath))); + } + + @Override + protected void onClose() { + super.onClose(); + if (server != null) { + server.closeGracefully(); + server = null; + } + deleteSocketPath(); + } + + private TestEverythingServer server; + + @Override + protected McpClientTransport createMcpTransport() { + return new UDSClientTransportProvider(UnixDomainSocketAddress.of(socketPath)); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/UDSMcpSyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/UDSMcpSyncClientTests.java new file mode 100644 index 000000000..6d2fc8b59 --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/client/UDSMcpSyncClientTests.java @@ -0,0 +1,69 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.client; + +import java.io.IOException; +import java.net.UnixDomainSocketAddress; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.time.Duration; + +import org.junit.jupiter.api.Timeout; + +import io.modelcontextprotocol.client.transport.UDSClientTransportProvider; +import io.modelcontextprotocol.server.TestEverythingServer; +import io.modelcontextprotocol.server.transport.UDSServerTransportProvider; +import io.modelcontextprotocol.spec.McpClientTransport; + +/** + * Tests for the {@link McpSyncClient} with {@link UDSClientTransport}. + * + * @author Christian Tzolov + * @author Dariusz Jędrzejczyk + * @author Scott Lewis + */ +@Timeout(15) // Giving extra time beyond the client timeout +class UDSMcpSyncClientTests extends AbstractMcpSyncClientTests { + + private Path socketPath = Paths.get(getClass().getName() + ".unix.socket"); + + private void deleteSocketPath() { + try { + Files.deleteIfExists(socketPath); + } + catch (IOException e) { + throw new RuntimeException(e); + } + } + + protected void onStart() { + super.onStart(); + deleteSocketPath(); + this.server = new TestEverythingServer(new UDSServerTransportProvider(UnixDomainSocketAddress.of(socketPath))); + } + + @Override + protected void onClose() { + super.onClose(); + if (server != null) { + server.closeGracefully(); + server = null; + } + deleteSocketPath(); + } + + private TestEverythingServer server; + + @Override + protected McpClientTransport createMcpTransport() { + return new UDSClientTransportProvider(UnixDomainSocketAddress.of(socketPath)); + } + + protected Duration getInitializationTimeout() { + return Duration.ofSeconds(2); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/TestEverythingServer.java b/mcp/src/test/java/io/modelcontextprotocol/server/TestEverythingServer.java new file mode 100644 index 000000000..6d5fcf9b3 --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/server/TestEverythingServer.java @@ -0,0 +1,148 @@ +package io.modelcontextprotocol.server; + +import java.util.List; + +import io.modelcontextprotocol.server.McpServerFeatures.SyncPromptSpecification; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpServerTransportProvider; +import io.modelcontextprotocol.spec.McpSchema.Annotations; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; +import io.modelcontextprotocol.spec.McpSchema.GetPromptResult; +import io.modelcontextprotocol.spec.McpSchema.ProgressNotification; +import io.modelcontextprotocol.spec.McpSchema.Prompt; +import io.modelcontextprotocol.spec.McpSchema.PromptMessage; +import io.modelcontextprotocol.spec.McpSchema.ReadResourceResult; +import io.modelcontextprotocol.spec.McpSchema.Resource; +import io.modelcontextprotocol.spec.McpSchema.Role; +import io.modelcontextprotocol.spec.McpSchema.SamplingMessage; +import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; +import io.modelcontextprotocol.spec.McpSchema.TextContent; +import io.modelcontextprotocol.spec.McpSchema.TextResourceContents; +import io.modelcontextprotocol.spec.McpSchema.Tool; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest.ContextInclusionStrategy; + +public class TestEverythingServer { + + private static final String TEST_RESOURCE_URI = "test://resources/"; + + private static final String emptyJsonSchema = """ + { + "$schema": "http://json-schema.org/draft-07/schema#", + "type": "object", + "properties": {} + } + """; + + private McpSyncServer server; + + public TestEverythingServer(McpServerTransportProvider transport) { + McpServerFeatures.SyncResourceSpecification[] specs = new McpServerFeatures.SyncResourceSpecification[10]; + for (int i = 0; i < 10; i++) { + String istr = String.valueOf(i); + String uri = TEST_RESOURCE_URI + istr; + specs[i] = new McpServerFeatures.SyncResourceSpecification( + Resource.builder() + .uri(uri) + .name("Test Resource") + .mimeType("text/plain") + .description("Test resource description") + .build(), + (exchange, + req) -> new ReadResourceResult(List.of(new TextResourceContents(uri, "text/plain", istr)))); + } + + this.server = McpServer.sync(transport) + .serverInfo(getClass().getName() + "-server", "1.0.0") + .capabilities( + ServerCapabilities.builder().logging().tools(true).prompts(true).resources(true, true).build()) + .toolCall(Tool.builder() + .name("echo") + .description("echo tool description") + .inputSchema(emptyJsonSchema) + .build(), (exchange, request) -> { + return CallToolResult.builder().addTextContent((String) request.arguments().get("message")).build(); + }) + .toolCall(Tool.builder().name("add").description("add two integers").inputSchema(emptyJsonSchema).build(), + (exchange, request) -> { + Integer a = (Integer) request.arguments().get("a"); + Integer b = (Integer) request.arguments().get("b"); + + return CallToolResult.builder().addTextContent(String.valueOf(a + b)).build(); + }) + .toolCall( + Tool.builder().name("sampleLLM").description("sampleLLM tool").inputSchema(emptyJsonSchema).build(), + (exchange, request) -> { + String prompt = (String) request.arguments().get("prompt"); + Integer maxTokens = (Integer) request.arguments().get("maxTokens"); + SamplingMessage sm = new SamplingMessage(McpSchema.Role.USER, + new TextContent("Resource sampleLLM context: " + prompt)); + CreateMessageRequest cmRequest = CreateMessageRequest.builder() + .messages(List.of(sm)) + .systemPrompt("You are a helpful test server.") + .maxTokens(maxTokens) + .temperature(0.7) + .includeContext(ContextInclusionStrategy.THIS_SERVER) + .build(); + CreateMessageResult result = exchange.createMessage(cmRequest); + + return CallToolResult.builder() + .addTextContent("LLM sampling result: " + ((TextContent) result.content()).text()) + .build(); + }) + .toolCall(Tool.builder() + .name("longRunningOperation") + .description("Demonstrates a long running operation with progress updates") + .inputSchema(emptyJsonSchema) + .build(), (exchange, request) -> { + String progressToken = (String) request.progressToken(); + int steps = (Integer) request.arguments().get("steps"); + for (int i = 0; i < steps; i++) { + try { + Thread.sleep(1000); + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + if (progressToken != null) { + exchange.progressNotification(new ProgressNotification(progressToken, (double) i + 1, + (double) steps, "progress message " + String.valueOf(i + 1))); + } + } + return CallToolResult.builder().content(List.of(new TextContent("done"))).build(); + }) + .toolCall(Tool.builder().name("annotatedMessage").description("annotated message").build(), + (exchange, request) -> { + String messageType = (String) request.arguments().get("messageType"); + Annotations annotations = null; + if (messageType.equals("success")) { + annotations = new Annotations(List.of(McpSchema.Role.USER), 0.7); + } + else if (messageType.equals("error")) { + annotations = new Annotations(List.of(McpSchema.Role.USER, McpSchema.Role.ASSISTANT), 1.0); + } + else if (messageType.equals("debug")) { + annotations = new Annotations(List.of(McpSchema.Role.ASSISTANT), 0.3); + } + return CallToolResult.builder() + .addContent(new TextContent(annotations, "some response")) + .build(); + }) + .prompts(List.of(new SyncPromptSpecification(new Prompt("simple_prompt", "Simple prompt description", null), + (exchange, request) -> { + return new GetPromptResult("description", + List.of(new PromptMessage(Role.USER, new TextContent("hello")))); + }))) + .resources(specs) + .build(); + } + + public void closeGracefully() { + if (this.server != null) { + this.server.closeGracefully(); + this.server = null; + } + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpAsyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpAsyncServerTests.java new file mode 100644 index 000000000..8d7931a65 --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpAsyncServerTests.java @@ -0,0 +1,58 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import java.io.IOException; +import java.net.UnixDomainSocketAddress; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; + +import org.junit.jupiter.api.Timeout; + +import io.modelcontextprotocol.server.transport.UDSServerTransportProvider; +import io.modelcontextprotocol.spec.McpServerTransportProvider; + +/** + * Tests for {@link McpAsyncServer} using {@link UDSServerTransport}. + * + * @author Christian Tzolov + * @author Scott Lewis + */ +@Timeout(15) // Giving extra time beyond the client timeout +class UDSMcpAsyncServerTests extends AbstractMcpAsyncServerTests { + + private Path socketPath = Paths.get(getClass().getName() + ".unix.socket"); + + private void deleteSocketPath() { + try { + Files.deleteIfExists(socketPath); + } + catch (IOException e) { + throw new RuntimeException(e); + } + } + + protected void onStart() { + super.onStart(); + deleteSocketPath(); + } + + @Override + protected void onClose() { + super.onClose(); + deleteSocketPath(); + } + + protected McpServerTransportProvider createMcpTransportProvider() { + return new UDSServerTransportProvider(UnixDomainSocketAddress.of(socketPath)); + } + + @Override + protected McpServer.AsyncSpecification prepareAsyncServerBuilder() { + return McpServer.async(createMcpTransportProvider()); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpSyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpSyncServerTests.java new file mode 100644 index 000000000..795b1b2e7 --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpSyncServerTests.java @@ -0,0 +1,58 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import java.io.IOException; +import java.net.UnixDomainSocketAddress; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; + +import org.junit.jupiter.api.Timeout; + +import io.modelcontextprotocol.server.transport.UDSServerTransportProvider; +import io.modelcontextprotocol.spec.McpServerTransportProvider; + +/** + * Tests for {@link McpSyncServer} using {@link UDSServerTransportProvider}. + * + * @author Christian Tzolov + * @author Scott Lewis + */ +@Timeout(15) // Giving extra time beyond the client timeout +class UDSMcpSyncServerTests extends AbstractMcpSyncServerTests { + + private Path socketPath = Paths.get(getClass().getName() + ".unix.socket"); + + private void deleteSocketPath() { + try { + Files.deleteIfExists(socketPath); + } + catch (IOException e) { + throw new RuntimeException(e); + } + } + + protected void onStart() { + super.onStart(); + deleteSocketPath(); + } + + @Override + protected void onClose() { + super.onClose(); + deleteSocketPath(); + } + + protected McpServerTransportProvider createMcpTransportProvider() { + return new UDSServerTransportProvider(UnixDomainSocketAddress.of(socketPath)); + } + + @Override + protected McpServer.SyncSpecification prepareSyncServerBuilder() { + return McpServer.sync(createMcpTransportProvider()); + } + +} From 570ad19bfd94840b9dd5c2d77191ab35c4b2c609 Mon Sep 17 00:00:00 2001 From: Scott Lewis Date: Wed, 13 Aug 2025 18:54:23 -0700 Subject: [PATCH 20/25] Removed unnecessary constructor --- .../transport/UDSClientTransportProvider.java | 20 ------------------- 1 file changed, 20 deletions(-) diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/UDSClientTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/UDSClientTransportProvider.java index 26bbc7bd2..e27599d41 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/UDSClientTransportProvider.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/UDSClientTransportProvider.java @@ -76,26 +76,6 @@ protected void handleException(SelectionKey key, Exception e) { this.outboundScheduler = Schedulers.fromExecutorService(Executors.newSingleThreadExecutor(), "outbound"); } - /** - * Creates a new StdioClientTransport with the specified parameters and ObjectMapper. - * @param params The parameters for configuring the server process - * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization - */ - public UDSClientTransportProvider(ServerParameters params, ObjectMapper objectMapper) { - Assert.notNull(params, "The params can not be null"); - Assert.notNull(objectMapper, "The ObjectMapper can not be null"); - - this.inboundSink = Sinks.many().unicast().onBackpressureBuffer(); - this.outboundSink = Sinks.many().unicast().onBackpressureBuffer(); - - this.objectMapper = objectMapper; - - this.errorSink = Sinks.many().unicast().onBackpressureBuffer(); - - // Start thread for outbound - this.outboundScheduler = Schedulers.fromExecutorService(Executors.newSingleThreadExecutor(), "outbound"); - } - /** * Starts the server process and initializes the message processing streams. This * method sets up the process with the configured command, arguments, and environment, From 4fc34177405ebdbf8a19ede60781bb882013ee5c Mon Sep 17 00:00:00 2001 From: Scott Lewis Date: Wed, 13 Aug 2025 19:39:57 -0700 Subject: [PATCH 21/25] Fixes for debug output --- .../transport/UDSClientTransportProvider.java | 6 +++- .../transport/UDSServerTransportProvider.java | 5 ++-- .../util/AbstractSocketChannel.java | 28 +++++++++---------- .../util/ClientSocketChannel.java | 6 ++-- 4 files changed, 24 insertions(+), 21 deletions(-) diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/UDSClientTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/UDSClientTransportProvider.java index e27599d41..d67c42d69 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/UDSClientTransportProvider.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/UDSClientTransportProvider.java @@ -27,7 +27,7 @@ public class UDSClientTransportProvider implements McpClientTransport { - private static final Logger logger = LoggerFactory.getLogger(StdioClientTransport.class); + private static final Logger logger = LoggerFactory.getLogger(UDSClientTransportProvider.class); private final Sinks.Many inboundSink; @@ -221,6 +221,10 @@ public Mono closeGracefully() { })).then(Mono.fromRunnable(() -> { try { outboundScheduler.dispose(); + if (this.clientChannel != null) { + this.clientChannel.close(); + this.clientChannel = null; + } logger.debug("Graceful shutdown completed"); } catch (Exception e) { diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/UDSServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/UDSServerTransportProvider.java index 8911f2797..2c3210bef 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/transport/UDSServerTransportProvider.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/UDSServerTransportProvider.java @@ -14,7 +14,6 @@ import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage; import io.modelcontextprotocol.spec.McpServerSession; @@ -31,7 +30,7 @@ public class UDSServerTransportProvider implements McpServerTransportProvider { - private static final Logger logger = LoggerFactory.getLogger(StdioServerTransportProvider.class); + private static final Logger logger = LoggerFactory.getLogger(UDSServerTransportProvider.class); private final ObjectMapper objectMapper; @@ -81,7 +80,7 @@ public void setSessionFactory(McpServerSession.Factory sessionFactory) { @Override public Mono notifyClients(String method, Object params) { if (this.session == null) { - return Mono.error(new McpError("No session to close")); + return Mono.error(new Exception("No session to close")); } return this.session.sendNotification(method, params) .doOnError(e -> logger.error("Failed to send notification: {}", e.getMessage())); diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/AbstractSocketChannel.java b/mcp/src/main/java/io/modelcontextprotocol/util/AbstractSocketChannel.java index c6736c378..87fa77841 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/util/AbstractSocketChannel.java +++ b/mcp/src/main/java/io/modelcontextprotocol/util/AbstractSocketChannel.java @@ -106,7 +106,7 @@ protected Runnable getRunnableForProcessing(IOConsumer acceptHand try { while (true) { int count = this.selector.select(); - debug("Select returned count=%s", count); + debug("select returned count={}", count); Set selectedKeys = selector.selectedKeys(); Iterator iter = selectedKeys.iterator(); while (iter.hasNext()) { @@ -151,12 +151,12 @@ protected void debug(String format, Object... o) { // For client subclasses protected void handleConnectable(SelectionKey key, IOConsumer connectHandler) throws IOException { SocketChannel client = (SocketChannel) key.channel(); - debug("client=%s", client); + debug("client={}", client); client.configureBlocking(false); client.register(this.selector, SelectionKey.OP_READ, new AttachedIO()); if (client.isConnectionPending()) { client.finishConnect(); - debug("connected client=%s", client); + debug("connected client={}", client); } if (connectHandler != null) { connectHandler.apply(client); @@ -166,13 +166,13 @@ protected void handleConnectable(SelectionKey key, IOConsumer con protected void handleAcceptable(SelectionKey key, IOConsumer acceptHandler) throws IOException { ServerSocketChannel serverSocket = (ServerSocketChannel) key.channel(); SocketChannel client = serverSocket.accept(); - debug("client=%s", client); + debug("client={}", client); client.configureBlocking(false); client.register(this.selector, SelectionKey.OP_READ, new AttachedIO()); configureAcceptSocketChannel(client); if (client.isConnectionPending()) { client.finishConnect(); - debug("accepted client=%s", client); + debug("accepted client={}", client); } if (acceptHandler != null) { acceptHandler.apply(client); @@ -194,7 +194,7 @@ protected AttachedIO getAttachedIO(SelectionKey key) throws IOException { protected void handleReadable(SelectionKey key, IOConsumer readHandler) throws IOException { SocketChannel client = (SocketChannel) key.channel(); AttachedIO io = getAttachedIO(key); - debug("read client=%s", client); + debug("read client={}", client); // read int r = client.read(this.inBuffer); // Check if we should expect any more reads @@ -212,7 +212,7 @@ protected void handleReadable(SelectionKey key, IOConsumer readHandler) String message = sb.toString(); // Set the io.reading value to null as we are done with this message io.reading = null; - debug("read client=%s msg=", client, message); + debug("read client={} msg=", client, message); if (readHandler != null) { String[] messages = splitMessage(message); for (int i = 0; i < messages.length; i++) { @@ -222,7 +222,7 @@ protected void handleReadable(SelectionKey key, IOConsumer readHandler) } else { io.reading = sb; - debug("read partial=%s", partial); + debug("read partial={}", partial); } // Clear inbuffer for next read this.inBuffer.clear(); @@ -253,13 +253,13 @@ protected void doWrite(SelectionKey key, SocketChannel client, ByteBuffer buf, I synchronized (writeLock) { int written = client.write(buf); if (buf.hasRemaining()) { - debug("doWrite written=%s, remaining=%s", written, buf.remaining()); + debug("doWrite written={}, remaining={}", written, buf.remaining()); io.writing = buf.slice(); key.interestOpsOr(SelectionKey.OP_WRITE); } else { if (logger.isDebugEnabled()) { - logger.debug("doWrite message=%s", new String(buf.array(), 0, written)); + logger.debug("doWrite message={}", new String(buf.array(), 0, written)); } io.writing = null; key.interestOps(SelectionKey.OP_READ); @@ -287,7 +287,7 @@ protected void executorShutdown() { protected void hardCloseClient(SocketChannel client, IOConsumer closeHandler) { if (client != null) { - debug("hardClose client=%s", client); + debug("hardClose client={}", client); synchronized (writeLock) { try { if (closeHandler != null) { @@ -314,7 +314,7 @@ protected void writeMessageToChannel(SocketChannel client, String message) throw .replace("\r", "\\n") // add message delimiter .concat(DEFAULT_MESSAGE_DELIMITER); - debug("writing msg=%s", outputMessage); + debug("writing msg={}", outputMessage); synchronized (writeLock) { // do the non blocking write in thread while holding lock. doWrite(client, outputMessage, null); @@ -329,7 +329,7 @@ protected void writeMessageToChannel(SocketChannel client, String message) throw } // If write is *not* completed, then wait timeout /10 try { - debug("writeBlocking WAITING(ms)=%s msg=%s", String.valueOf(waitTime / 10), outputMessage); + debug("writeBlocking WAITING(ms)={} msg={}", String.valueOf(waitTime / 10), outputMessage); writeLock.wait(waitTime / 10); } catch (InterruptedException e) { @@ -340,7 +340,7 @@ protected void writeMessageToChannel(SocketChannel client, String message) throw throw new IOException("Write not completed. Non empty buffer remaining after timeout"); } } - debug("writing done msg=%s", outputMessage); + debug("writing done msg={}", outputMessage); } protected void configureConnectSocketChannel(SocketChannel client, SocketAddress connectAddress) diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/ClientSocketChannel.java b/mcp/src/main/java/io/modelcontextprotocol/util/ClientSocketChannel.java index 4ec69672c..2e2f32ec8 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/util/ClientSocketChannel.java +++ b/mcp/src/main/java/io/modelcontextprotocol/util/ClientSocketChannel.java @@ -37,7 +37,7 @@ public ClientSocketChannel(Selector selector) { protected SocketChannel doConnect(SocketChannel client, SocketAddress address, IOConsumer connectHandler, IOConsumer readHandler) throws IOException { - debug("connect targetAddress=%s", address); + debug("connect targetAddress={}", address); client.configureBlocking(false); client.register(selector, SelectionKey.OP_CONNECT); configureConnectSocketChannel(client, address); @@ -54,7 +54,7 @@ protected SocketChannel doConnect(SocketChannel client, SocketAddress address, client.connect(address); try { - debug("connect targetAddress=%s", address); + debug("connect targetAddress={}", address); synchronized (connectLock) { connectLock.wait(this.connectTimeout); } @@ -63,7 +63,7 @@ protected SocketChannel doConnect(SocketChannel client, SocketAddress address, throw new IOException( "Connect to address=" + address + " timed out after " + String.valueOf(this.connectTimeout) + "ms"); } - debug("connected client=%s", client); + debug("connected client={}", client); return client; } From 0440e1851872a321b8f22a3b7da2399582f5db47 Mon Sep 17 00:00:00 2001 From: Scott Lewis Date: Wed, 13 Aug 2025 20:37:40 -0700 Subject: [PATCH 22/25] Generalization for handleException --- .../client/transport/UDSClientTransportProvider.java | 2 +- .../server/transport/UDSServerTransportProvider.java | 2 +- .../io/modelcontextprotocol/util/AbstractSocketChannel.java | 4 ++-- .../io/modelcontextprotocol/util/ClientSocketChannel.java | 2 +- .../java/io/modelcontextprotocol/util/ServSocketChannel.java | 2 +- 5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/UDSClientTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/UDSClientTransportProvider.java index d67c42d69..f59af5219 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/UDSClientTransportProvider.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/UDSClientTransportProvider.java @@ -64,7 +64,7 @@ public UDSClientTransportProvider(ObjectMapper objectMapper, UnixDomainSocketAdd try { this.clientChannel = new UDSClientSocketChannel() { @Override - protected void handleException(SelectionKey key, Exception e) { + protected void handleException(SelectionKey key, Throwable e) { isClosing = true; super.handleException(key, e); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/UDSServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/UDSServerTransportProvider.java index 2c3210bef..e65be89e6 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/transport/UDSServerTransportProvider.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/UDSServerTransportProvider.java @@ -118,7 +118,7 @@ public UDSMcpSessionTransport() { try { this.serverSocketChannel = new UDSServerSocketChannel() { @Override - protected void handleException(SelectionKey key, Exception e) { + protected void handleException(SelectionKey key, Throwable e) { isClosing.set(true); if (session != null) { session.close(); diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/AbstractSocketChannel.java b/mcp/src/main/java/io/modelcontextprotocol/util/AbstractSocketChannel.java index 87fa77841..36501502d 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/util/AbstractSocketChannel.java +++ b/mcp/src/main/java/io/modelcontextprotocol/util/AbstractSocketChannel.java @@ -127,7 +127,7 @@ else if (key.isWritable()) { } } } - catch (Exception e) { + catch (Throwable e) { handleException(key, e); } }; @@ -135,7 +135,7 @@ else if (key.isWritable()) { public abstract void close(); - protected abstract void handleException(SelectionKey key, Exception e); + protected abstract void handleException(SelectionKey key, Throwable e); protected void start(IOConsumer acceptHandler, IOConsumer connectHandler, IOConsumer readHandler) throws IOException { diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/ClientSocketChannel.java b/mcp/src/main/java/io/modelcontextprotocol/util/ClientSocketChannel.java index 2e2f32ec8..edda16717 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/util/ClientSocketChannel.java +++ b/mcp/src/main/java/io/modelcontextprotocol/util/ClientSocketChannel.java @@ -76,7 +76,7 @@ public void connect(StandardProtocolFamily protocol, SocketAddress address, } @Override - protected void handleException(SelectionKey key, Exception e) { + protected void handleException(SelectionKey key, Throwable e) { if (logger.isDebugEnabled()) { logger.debug("handleException", e); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/ServSocketChannel.java b/mcp/src/main/java/io/modelcontextprotocol/util/ServSocketChannel.java index 78b18d9fe..0cdfd4ee0 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/util/ServSocketChannel.java +++ b/mcp/src/main/java/io/modelcontextprotocol/util/ServSocketChannel.java @@ -59,7 +59,7 @@ public void start(StandardProtocolFamily protocol, SocketAddress address, IOCons } @Override - protected void handleException(SelectionKey key, Exception e) { + protected void handleException(SelectionKey key, Throwable e) { if (logger.isDebugEnabled()) { logger.debug("handleException", e); } From a1cea715a6b583b16607c82b61f438b005f36edc Mon Sep 17 00:00:00 2001 From: Scott Lewis Date: Thu, 14 Aug 2025 16:50:40 -0700 Subject: [PATCH 23/25] Class renaming and added UdsMcpServerTransportProvider and UdsMcpClientTransport interfaces. --- .../client/transport/UdsMcpClientTransport.java | 11 +++++++++++ ...ider.java => UdsMcpClientTransportImpl.java} | 15 +++++++++------ .../UdsMcpServerTransportProvider.java | 11 +++++++++++ ...a => UdsMcpServerTransportProviderImpl.java} | 17 ++++++++++------- .../client/UDSMcpAsyncClientTests.java | 8 ++++---- .../client/UDSMcpSyncClientTests.java | 8 ++++---- .../server/UDSMcpAsyncServerTests.java | 9 ++++----- .../server/UDSMcpSyncServerTests.java | 6 +++--- 8 files changed, 56 insertions(+), 29 deletions(-) create mode 100644 mcp/src/main/java/io/modelcontextprotocol/client/transport/UdsMcpClientTransport.java rename mcp/src/main/java/io/modelcontextprotocol/client/transport/{UDSClientTransportProvider.java => UdsMcpClientTransportImpl.java} (93%) create mode 100644 mcp/src/main/java/io/modelcontextprotocol/server/transport/UdsMcpServerTransportProvider.java rename mcp/src/main/java/io/modelcontextprotocol/server/transport/{UDSServerTransportProvider.java => UdsMcpServerTransportProviderImpl.java} (92%) diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/UdsMcpClientTransport.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/UdsMcpClientTransport.java new file mode 100644 index 000000000..96de8229d --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/UdsMcpClientTransport.java @@ -0,0 +1,11 @@ +package io.modelcontextprotocol.client.transport; + +import java.net.UnixDomainSocketAddress; + +import io.modelcontextprotocol.spec.McpClientTransport; + +public interface UdsMcpClientTransport extends McpClientTransport { + + UnixDomainSocketAddress getUdsAddress(); + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/UDSClientTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/UdsMcpClientTransportImpl.java similarity index 93% rename from mcp/src/main/java/io/modelcontextprotocol/client/transport/UDSClientTransportProvider.java rename to mcp/src/main/java/io/modelcontextprotocol/client/transport/UdsMcpClientTransportImpl.java index f59af5219..3ea50bb56 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/UDSClientTransportProvider.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/UdsMcpClientTransportImpl.java @@ -14,7 +14,6 @@ import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.spec.McpClientTransport; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage; import io.modelcontextprotocol.util.Assert; @@ -25,9 +24,9 @@ import reactor.core.scheduler.Scheduler; import reactor.core.scheduler.Schedulers; -public class UDSClientTransportProvider implements McpClientTransport { +public class UdsMcpClientTransportImpl implements UdsMcpClientTransport { - private static final Logger logger = LoggerFactory.getLogger(UDSClientTransportProvider.class); + private static final Logger logger = LoggerFactory.getLogger(UdsMcpClientTransportImpl.class); private final Sinks.Many inboundSink; @@ -49,11 +48,15 @@ public class UDSClientTransportProvider implements McpClientTransport { // visible for tests private Consumer stdErrorHandler = error -> logger.info("STDERR Message received: {}", error); - public UDSClientTransportProvider(UnixDomainSocketAddress targetAddress) { + public UnixDomainSocketAddress getUdsAddress() { + return this.targetAddress; + } + + public UdsMcpClientTransportImpl(UnixDomainSocketAddress targetAddress) { this(new ObjectMapper(), targetAddress); } - public UDSClientTransportProvider(ObjectMapper objectMapper, UnixDomainSocketAddress targetAddress) { + public UdsMcpClientTransportImpl(ObjectMapper objectMapper, UnixDomainSocketAddress targetAddress) { Assert.notNull(objectMapper, "objectMapper cannot be null"); this.objectMapper = objectMapper; Assert.notNull(objectMapper, "targetAddress cannot be null"); @@ -93,7 +96,7 @@ public Mono connect(Function, Mono> h try { this.clientChannel.connect(targetAddress, (client) -> { if (logger.isInfoEnabled()) { - logger.info("UDSClientTransportProvider CONNECTED to targetAddress=" + targetAddress); + logger.info("UdsMcpClientTransportImpl CONNECTED to targetAddress=" + targetAddress); } }, (message) -> { if (logger.isDebugEnabled()) { diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/UdsMcpServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/UdsMcpServerTransportProvider.java new file mode 100644 index 000000000..3501ce4ab --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/UdsMcpServerTransportProvider.java @@ -0,0 +1,11 @@ +package io.modelcontextprotocol.server.transport; + +import java.net.UnixDomainSocketAddress; + +import io.modelcontextprotocol.spec.McpServerTransportProvider; + +public interface UdsMcpServerTransportProvider extends McpServerTransportProvider { + + UnixDomainSocketAddress getUdsAddress(); + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/UDSServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/UdsMcpServerTransportProviderImpl.java similarity index 92% rename from mcp/src/main/java/io/modelcontextprotocol/server/transport/UDSServerTransportProvider.java rename to mcp/src/main/java/io/modelcontextprotocol/server/transport/UdsMcpServerTransportProviderImpl.java index e65be89e6..9eb4203ab 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/transport/UDSServerTransportProvider.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/UdsMcpServerTransportProviderImpl.java @@ -18,7 +18,6 @@ import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage; import io.modelcontextprotocol.spec.McpServerSession; import io.modelcontextprotocol.spec.McpServerTransport; -import io.modelcontextprotocol.spec.McpServerTransportProvider; import io.modelcontextprotocol.spec.ProtocolVersions; import io.modelcontextprotocol.util.Assert; import io.modelcontextprotocol.util.UDSServerSocketChannel; @@ -28,9 +27,9 @@ import reactor.core.scheduler.Scheduler; import reactor.core.scheduler.Schedulers; -public class UDSServerTransportProvider implements McpServerTransportProvider { +public class UdsMcpServerTransportProviderImpl implements UdsMcpServerTransportProvider { - private static final Logger logger = LoggerFactory.getLogger(UDSServerTransportProvider.class); + private static final Logger logger = LoggerFactory.getLogger(UdsMcpServerTransportProviderImpl.class); private final ObjectMapper objectMapper; @@ -46,19 +45,23 @@ public class UDSServerTransportProvider implements McpServerTransportProvider { private UnixDomainSocketAddress targetAddress; + public UnixDomainSocketAddress getUdsAddress() { + return targetAddress; + } + /** - * Creates a new UDSServerTransportProvider with a default ObjectMapper + * Creates a new UdsMcpServerTransportProviderImpl with a default ObjectMapper * @param unixSocketAddress the UDS socket address to bind to. Must not be null. */ - public UDSServerTransportProvider(UnixDomainSocketAddress unixSocketAddress) { + public UdsMcpServerTransportProviderImpl(UnixDomainSocketAddress unixSocketAddress) { this(new ObjectMapper(), unixSocketAddress); } /** - * Creates a new UDSServerTransportProvider with the specified ObjectMapper + * Creates a new UdsMcpServerTransportProviderImpl with the specified ObjectMapper * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization */ - public UDSServerTransportProvider(ObjectMapper objectMapper, UnixDomainSocketAddress unixSocketAddress) { + public UdsMcpServerTransportProviderImpl(ObjectMapper objectMapper, UnixDomainSocketAddress unixSocketAddress) { Assert.notNull(objectMapper, "objectMapper cannot be null"); this.objectMapper = objectMapper; Assert.notNull(unixSocketAddress, "unixSocketAddress cannot be null"); diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/UDSMcpAsyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/UDSMcpAsyncClientTests.java index ad43cd8c9..79c97169a 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/UDSMcpAsyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/UDSMcpAsyncClientTests.java @@ -12,9 +12,9 @@ import org.junit.jupiter.api.Timeout; -import io.modelcontextprotocol.client.transport.UDSClientTransportProvider; +import io.modelcontextprotocol.client.transport.UdsMcpClientTransportImpl; import io.modelcontextprotocol.server.TestEverythingServer; -import io.modelcontextprotocol.server.transport.UDSServerTransportProvider; +import io.modelcontextprotocol.server.transport.UdsMcpServerTransportProviderImpl; import io.modelcontextprotocol.spec.McpClientTransport; /** @@ -41,7 +41,7 @@ private void deleteSocketPath() { protected void onStart() { super.onStart(); deleteSocketPath(); - this.server = new TestEverythingServer(new UDSServerTransportProvider(UnixDomainSocketAddress.of(socketPath))); + this.server = new TestEverythingServer(new UdsMcpServerTransportProviderImpl(UnixDomainSocketAddress.of(socketPath))); } @Override @@ -58,7 +58,7 @@ protected void onClose() { @Override protected McpClientTransport createMcpTransport() { - return new UDSClientTransportProvider(UnixDomainSocketAddress.of(socketPath)); + return new UdsMcpClientTransportImpl(UnixDomainSocketAddress.of(socketPath)); } } diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/UDSMcpSyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/UDSMcpSyncClientTests.java index 6d2fc8b59..d4bd95295 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/UDSMcpSyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/UDSMcpSyncClientTests.java @@ -13,9 +13,9 @@ import org.junit.jupiter.api.Timeout; -import io.modelcontextprotocol.client.transport.UDSClientTransportProvider; +import io.modelcontextprotocol.client.transport.UdsMcpClientTransportImpl; import io.modelcontextprotocol.server.TestEverythingServer; -import io.modelcontextprotocol.server.transport.UDSServerTransportProvider; +import io.modelcontextprotocol.server.transport.UdsMcpServerTransportProviderImpl; import io.modelcontextprotocol.spec.McpClientTransport; /** @@ -42,7 +42,7 @@ private void deleteSocketPath() { protected void onStart() { super.onStart(); deleteSocketPath(); - this.server = new TestEverythingServer(new UDSServerTransportProvider(UnixDomainSocketAddress.of(socketPath))); + this.server = new TestEverythingServer(new UdsMcpServerTransportProviderImpl(UnixDomainSocketAddress.of(socketPath))); } @Override @@ -59,7 +59,7 @@ protected void onClose() { @Override protected McpClientTransport createMcpTransport() { - return new UDSClientTransportProvider(UnixDomainSocketAddress.of(socketPath)); + return new UdsMcpClientTransportImpl(UnixDomainSocketAddress.of(socketPath)); } protected Duration getInitializationTimeout() { diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpAsyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpAsyncServerTests.java index 8d7931a65..826ad33f8 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpAsyncServerTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpAsyncServerTests.java @@ -12,7 +12,7 @@ import org.junit.jupiter.api.Timeout; -import io.modelcontextprotocol.server.transport.UDSServerTransportProvider; +import io.modelcontextprotocol.server.transport.UdsMcpServerTransportProviderImpl; import io.modelcontextprotocol.spec.McpServerTransportProvider; /** @@ -29,17 +29,16 @@ class UDSMcpAsyncServerTests extends AbstractMcpAsyncServerTests { private void deleteSocketPath() { try { Files.deleteIfExists(socketPath); - } - catch (IOException e) { + } catch (IOException e) { throw new RuntimeException(e); } } - protected void onStart() { super.onStart(); deleteSocketPath(); } + @Override protected void onClose() { super.onClose(); @@ -47,7 +46,7 @@ protected void onClose() { } protected McpServerTransportProvider createMcpTransportProvider() { - return new UDSServerTransportProvider(UnixDomainSocketAddress.of(socketPath)); + return new UdsMcpServerTransportProviderImpl(UnixDomainSocketAddress.of(socketPath)); } @Override diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpSyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpSyncServerTests.java index 795b1b2e7..e36febafb 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpSyncServerTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpSyncServerTests.java @@ -12,11 +12,11 @@ import org.junit.jupiter.api.Timeout; -import io.modelcontextprotocol.server.transport.UDSServerTransportProvider; +import io.modelcontextprotocol.server.transport.UdsMcpServerTransportProviderImpl; import io.modelcontextprotocol.spec.McpServerTransportProvider; /** - * Tests for {@link McpSyncServer} using {@link UDSServerTransportProvider}. + * Tests for {@link McpSyncServer} using {@link UdsMcpServerTransportProviderImpl}. * * @author Christian Tzolov * @author Scott Lewis @@ -47,7 +47,7 @@ protected void onClose() { } protected McpServerTransportProvider createMcpTransportProvider() { - return new UDSServerTransportProvider(UnixDomainSocketAddress.of(socketPath)); + return new UdsMcpServerTransportProviderImpl(UnixDomainSocketAddress.of(socketPath)); } @Override From 905b88015bc961797810b8da57209a82073615ed Mon Sep 17 00:00:00 2001 From: Scott Lewis Date: Thu, 14 Aug 2025 16:55:00 -0700 Subject: [PATCH 24/25] Formatting --- .../client/transport/UdsMcpClientTransport.java | 2 +- .../client/transport/UdsMcpClientTransportImpl.java | 2 +- .../server/transport/UdsMcpServerTransportProvider.java | 2 +- .../server/transport/UdsMcpServerTransportProviderImpl.java | 2 +- .../modelcontextprotocol/client/UDSMcpAsyncClientTests.java | 3 ++- .../modelcontextprotocol/client/UDSMcpSyncClientTests.java | 3 ++- .../modelcontextprotocol/server/UDSMcpAsyncServerTests.java | 5 +++-- 7 files changed, 11 insertions(+), 8 deletions(-) diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/UdsMcpClientTransport.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/UdsMcpClientTransport.java index 96de8229d..b6d8e9919 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/UdsMcpClientTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/UdsMcpClientTransport.java @@ -7,5 +7,5 @@ public interface UdsMcpClientTransport extends McpClientTransport { UnixDomainSocketAddress getUdsAddress(); - + } diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/UdsMcpClientTransportImpl.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/UdsMcpClientTransportImpl.java index 3ea50bb56..6599d69a0 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/UdsMcpClientTransportImpl.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/UdsMcpClientTransportImpl.java @@ -51,7 +51,7 @@ public class UdsMcpClientTransportImpl implements UdsMcpClientTransport { public UnixDomainSocketAddress getUdsAddress() { return this.targetAddress; } - + public UdsMcpClientTransportImpl(UnixDomainSocketAddress targetAddress) { this(new ObjectMapper(), targetAddress); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/UdsMcpServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/UdsMcpServerTransportProvider.java index 3501ce4ab..84d69cd09 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/transport/UdsMcpServerTransportProvider.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/UdsMcpServerTransportProvider.java @@ -7,5 +7,5 @@ public interface UdsMcpServerTransportProvider extends McpServerTransportProvider { UnixDomainSocketAddress getUdsAddress(); - + } diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/UdsMcpServerTransportProviderImpl.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/UdsMcpServerTransportProviderImpl.java index 9eb4203ab..7554bd918 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/transport/UdsMcpServerTransportProviderImpl.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/UdsMcpServerTransportProviderImpl.java @@ -48,7 +48,7 @@ public class UdsMcpServerTransportProviderImpl implements UdsMcpServerTransportP public UnixDomainSocketAddress getUdsAddress() { return targetAddress; } - + /** * Creates a new UdsMcpServerTransportProviderImpl with a default ObjectMapper * @param unixSocketAddress the UDS socket address to bind to. Must not be null. diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/UDSMcpAsyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/UDSMcpAsyncClientTests.java index 79c97169a..701ad3f0c 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/UDSMcpAsyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/UDSMcpAsyncClientTests.java @@ -41,7 +41,8 @@ private void deleteSocketPath() { protected void onStart() { super.onStart(); deleteSocketPath(); - this.server = new TestEverythingServer(new UdsMcpServerTransportProviderImpl(UnixDomainSocketAddress.of(socketPath))); + this.server = new TestEverythingServer( + new UdsMcpServerTransportProviderImpl(UnixDomainSocketAddress.of(socketPath))); } @Override diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/UDSMcpSyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/UDSMcpSyncClientTests.java index d4bd95295..93e19f90c 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/UDSMcpSyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/UDSMcpSyncClientTests.java @@ -42,7 +42,8 @@ private void deleteSocketPath() { protected void onStart() { super.onStart(); deleteSocketPath(); - this.server = new TestEverythingServer(new UdsMcpServerTransportProviderImpl(UnixDomainSocketAddress.of(socketPath))); + this.server = new TestEverythingServer( + new UdsMcpServerTransportProviderImpl(UnixDomainSocketAddress.of(socketPath))); } @Override diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpAsyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpAsyncServerTests.java index 826ad33f8..74b088a81 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpAsyncServerTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpAsyncServerTests.java @@ -29,16 +29,17 @@ class UDSMcpAsyncServerTests extends AbstractMcpAsyncServerTests { private void deleteSocketPath() { try { Files.deleteIfExists(socketPath); - } catch (IOException e) { + } + catch (IOException e) { throw new RuntimeException(e); } } + protected void onStart() { super.onStart(); deleteSocketPath(); } - @Override protected void onClose() { super.onClose(); From 1f77e7c3a25b1a828017af6dd1c8bbf716f7363f Mon Sep 17 00:00:00 2001 From: Scott Lewis Date: Mon, 18 Aug 2025 17:20:48 -0700 Subject: [PATCH 25/25] Fix to prevent server from sending notifications before a connection has occurred --- mcp/pom.xml | 2 +- .../server/transport/UdsMcpServerTransportProviderImpl.java | 3 --- .../java/io/modelcontextprotocol/util/ServSocketChannel.java | 5 ++++- pom.xml | 2 +- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/mcp/pom.xml b/mcp/pom.xml index 1cf61c48f..dc85a419e 100644 --- a/mcp/pom.xml +++ b/mcp/pom.xml @@ -6,7 +6,7 @@ io.modelcontextprotocol.sdk mcp-parent - 0.12.0-SNAPSHOT + 0.13.0-SNAPSHOT mcp jar diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/UdsMcpServerTransportProviderImpl.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/UdsMcpServerTransportProviderImpl.java index 7554bd918..ab486e8eb 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/transport/UdsMcpServerTransportProviderImpl.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/UdsMcpServerTransportProviderImpl.java @@ -82,9 +82,6 @@ public void setSessionFactory(McpServerSession.Factory sessionFactory) { @Override public Mono notifyClients(String method, Object params) { - if (this.session == null) { - return Mono.error(new Exception("No session to close")); - } return this.session.sendNotification(method, params) .doOnError(e -> logger.error("Failed to send notification: {}", e.getMessage())); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/ServSocketChannel.java b/mcp/src/main/java/io/modelcontextprotocol/util/ServSocketChannel.java index 0cdfd4ee0..35dba15bc 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/util/ServSocketChannel.java +++ b/mcp/src/main/java/io/modelcontextprotocol/util/ServSocketChannel.java @@ -72,7 +72,10 @@ public void writeMessage(String message) throws IOException { writeMessageToChannel(c, message); } else { - throw new IOException("not connected"); + if (logger.isDebugEnabled()) { + logger.debug("No connected client to send message={}", message); + } + ; } } diff --git a/pom.xml b/pom.xml index c0b1f7a44..9990a4663 100644 --- a/pom.xml +++ b/pom.xml @@ -6,7 +6,7 @@ io.modelcontextprotocol.sdk mcp-parent - 0.12.0-SNAPSHOT + 0.13.0-SNAPSHOT pom https://github.com/modelcontextprotocol/java-sdk