fromRunnable(() -> {
+ handleIncomingMessages(handler);
+ handleIncomingErrors();
+
+ // Connect client channel
+ try {
+ this.clientChannel.connect(targetAddress, (client) -> {
+ if (logger.isInfoEnabled()) {
+ logger.info("UdsMcpClientTransportImpl CONNECTED to targetAddress=" + targetAddress);
+ }
+ }, (message) -> {
+ if (logger.isDebugEnabled()) {
+ logger.debug("received message=" + message);
+ }
+ // Incoming messages processed right here
+ McpSchema.JSONRPCMessage jsonMessage = McpSchema.deserializeJsonRpcMessage(objectMapper, message);
+ if (!this.inboundSink.tryEmitNext(jsonMessage).isSuccess()) {
+ if (!isClosing) {
+ if (logger.isDebugEnabled()) {
+ logger.error("Failed to enqueue inbound json message: {}", jsonMessage);
+ }
+ }
+ }
+ });
+ }
+ catch (IOException e) {
+ this.clientChannel.close();
+ throw new RuntimeException(
+ "Connect to address=" + targetAddress + " failed message: " + e.getMessage());
+ }
+
+ startOutboundProcessing();
+
+ }).subscribeOn(Schedulers.boundedElastic());
+ }
+
+ /**
+ * Sets the handler for processing transport-level errors.
+ *
+ *
+ * The provided handler will be called when errors occur during transport operations,
+ * such as connection failures or protocol violations.
+ *
+ * @param errorHandler a consumer that processes error messages
+ */
+ public void setStdErrorHandler(Consumer errorHandler) {
+ this.stdErrorHandler = errorHandler;
+ }
+
+ private void handleIncomingMessages(Function, Mono> inboundMessageHandler) {
+ this.inboundSink.asFlux()
+ .flatMap(message -> Mono.just(message)
+ .transform(inboundMessageHandler)
+ .contextWrite(ctx -> ctx.put("observation", "myObservation")))
+ .subscribe();
+ }
+
+ private void handleIncomingErrors() {
+ this.errorSink.asFlux().subscribe(e -> {
+ this.stdErrorHandler.accept(e);
+ });
+ }
+
+ @Override
+ public Mono sendMessage(JSONRPCMessage message) {
+ outboundSink.emitNext(message, (signalType, emitResult) -> {
+ // Allow retry
+ return true;
+ });
+ return Mono.empty();
+ }
+
+ /**
+ * Starts the outbound processing thread that writes JSON-RPC messages to the
+ * process's output stream. Messages are serialized to JSON and written with a newline
+ * delimiter.
+ */
+ private void startOutboundProcessing() {
+ this.handleOutbound(messages -> messages
+ // this bit is important since writes come from user threads, and we
+ // want to ensure that the actual writing happens on a dedicated thread
+ .publishOn(outboundScheduler)
+ .handle((message, sink) -> {
+ if (message != null && !isClosing) {
+ try {
+ clientChannel.writeMessage(objectMapper.writeValueAsString(message));
+ sink.next(message);
+ }
+ catch (IOException e) {
+ if (!isClosing) {
+ logger.error("Error writing message", e);
+ sink.error(new RuntimeException(e));
+ }
+ else {
+ logger.debug("Stream closed during shutdown", e);
+ }
+ }
+ }
+ }));
+ }
+
+ protected void handleOutbound(Function, Flux> outboundConsumer) {
+ outboundConsumer.apply(outboundSink.asFlux()).doOnComplete(() -> {
+ isClosing = true;
+ outboundSink.tryEmitComplete();
+ }).doOnError(e -> {
+ if (!isClosing) {
+ logger.error("Error in outbound processing", e);
+ isClosing = true;
+ outboundSink.tryEmitComplete();
+ }
+ }).subscribe();
+ }
+
+ /**
+ * Gracefully closes the transport by destroying the process and disposing of the
+ * schedulers. This method sends a TERM signal to the process and waits for it to exit
+ * before cleaning up resources.
+ * @return A Mono that completes when the transport is closed
+ */
+ @Override
+ public Mono closeGracefully() {
+ return Mono.fromRunnable(() -> {
+ isClosing = true;
+ logger.debug("Initiating graceful shutdown");
+ }).then(Mono.defer(() -> {
+ // First complete all sinks to stop accepting new messages
+ inboundSink.tryEmitComplete();
+ outboundSink.tryEmitComplete();
+ errorSink.tryEmitComplete();
+
+ // Give a short time for any pending messages to be processed
+ return Mono.delay(Duration.ofMillis(100)).then();
+ })).then(Mono.fromRunnable(() -> {
+ try {
+ outboundScheduler.dispose();
+ if (this.clientChannel != null) {
+ this.clientChannel.close();
+ this.clientChannel = null;
+ }
+ logger.debug("Graceful shutdown completed");
+ }
+ catch (Exception e) {
+ logger.error("Error during graceful shutdown", e);
+ }
+ })).then().subscribeOn(Schedulers.boundedElastic());
+ }
+
+ public Sinks.Many getErrorSink() {
+ return this.errorSink;
+ }
+
+ @Override
+ public T unmarshalFrom(Object data, TypeReference typeRef) {
+ return this.objectMapper.convertValue(data, typeRef);
+ }
+
+}
diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/UdsMcpServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/UdsMcpServerTransportProvider.java
new file mode 100644
index 000000000..84d69cd09
--- /dev/null
+++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/UdsMcpServerTransportProvider.java
@@ -0,0 +1,11 @@
+package io.modelcontextprotocol.server.transport;
+
+import java.net.UnixDomainSocketAddress;
+
+import io.modelcontextprotocol.spec.McpServerTransportProvider;
+
+public interface UdsMcpServerTransportProvider extends McpServerTransportProvider {
+
+ UnixDomainSocketAddress getUdsAddress();
+
+}
diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/UdsMcpServerTransportProviderImpl.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/UdsMcpServerTransportProviderImpl.java
new file mode 100644
index 000000000..ab486e8eb
--- /dev/null
+++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/UdsMcpServerTransportProviderImpl.java
@@ -0,0 +1,257 @@
+package io.modelcontextprotocol.server.transport;
+
+import java.io.IOException;
+import java.net.UnixDomainSocketAddress;
+import java.nio.channels.SelectionKey;
+import java.util.List;
+import java.util.concurrent.Executors;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.function.Function;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import com.fasterxml.jackson.core.type.TypeReference;
+import com.fasterxml.jackson.databind.ObjectMapper;
+
+import io.modelcontextprotocol.spec.McpSchema;
+import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage;
+import io.modelcontextprotocol.spec.McpServerSession;
+import io.modelcontextprotocol.spec.McpServerTransport;
+import io.modelcontextprotocol.spec.ProtocolVersions;
+import io.modelcontextprotocol.util.Assert;
+import io.modelcontextprotocol.util.UDSServerSocketChannel;
+import reactor.core.publisher.Flux;
+import reactor.core.publisher.Mono;
+import reactor.core.publisher.Sinks;
+import reactor.core.scheduler.Scheduler;
+import reactor.core.scheduler.Schedulers;
+
+public class UdsMcpServerTransportProviderImpl implements UdsMcpServerTransportProvider {
+
+ private static final Logger logger = LoggerFactory.getLogger(UdsMcpServerTransportProviderImpl.class);
+
+ private final ObjectMapper objectMapper;
+
+ private UDSMcpSessionTransport transport;
+
+ private McpServerSession session;
+
+ private final AtomicBoolean isClosing = new AtomicBoolean(false);
+
+ private final Sinks.One inboundReady = Sinks.one();
+
+ private final Sinks.One outboundReady = Sinks.one();
+
+ private UnixDomainSocketAddress targetAddress;
+
+ public UnixDomainSocketAddress getUdsAddress() {
+ return targetAddress;
+ }
+
+ /**
+ * Creates a new UdsMcpServerTransportProviderImpl with a default ObjectMapper
+ * @param unixSocketAddress the UDS socket address to bind to. Must not be null.
+ */
+ public UdsMcpServerTransportProviderImpl(UnixDomainSocketAddress unixSocketAddress) {
+ this(new ObjectMapper(), unixSocketAddress);
+ }
+
+ /**
+ * Creates a new UdsMcpServerTransportProviderImpl with the specified ObjectMapper
+ * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization
+ */
+ public UdsMcpServerTransportProviderImpl(ObjectMapper objectMapper, UnixDomainSocketAddress unixSocketAddress) {
+ Assert.notNull(objectMapper, "objectMapper cannot be null");
+ this.objectMapper = objectMapper;
+ Assert.notNull(unixSocketAddress, "unixSocketAddress cannot be null");
+ this.targetAddress = unixSocketAddress;
+ }
+
+ @Override
+ public List protocolVersions() {
+ return List.of(ProtocolVersions.MCP_2024_11_05);
+ }
+
+ @Override
+ public void setSessionFactory(McpServerSession.Factory sessionFactory) {
+ this.transport = new UDSMcpSessionTransport();
+ this.session = sessionFactory.create(transport);
+ this.transport.initProcessing();
+ }
+
+ @Override
+ public Mono notifyClients(String method, Object params) {
+ return this.session.sendNotification(method, params)
+ .doOnError(e -> logger.error("Failed to send notification: {}", e.getMessage()));
+ }
+
+ @Override
+ public Mono closeGracefully() {
+ if (this.session == null) {
+ return Mono.empty();
+ }
+ return this.session.closeGracefully();
+ }
+
+ /**
+ * Implementation of McpServerTransport for the uds session.
+ */
+ private class UDSMcpSessionTransport implements McpServerTransport {
+
+ private final Sinks.Many inboundSink;
+
+ private final Sinks.Many outboundSink;
+
+ /** Scheduler for handling outbound messages */
+ private Scheduler outboundScheduler;
+
+ private final AtomicBoolean isStarted = new AtomicBoolean(false);
+
+ private final UDSServerSocketChannel serverSocketChannel;
+
+ public UDSMcpSessionTransport() {
+ this.inboundSink = Sinks.many().unicast().onBackpressureBuffer();
+ this.outboundSink = Sinks.many().unicast().onBackpressureBuffer();
+ this.outboundScheduler = Schedulers.fromExecutorService(Executors.newSingleThreadExecutor(),
+ "uds-outbound");
+ try {
+ this.serverSocketChannel = new UDSServerSocketChannel() {
+ @Override
+ protected void handleException(SelectionKey key, Throwable e) {
+ isClosing.set(true);
+ if (session != null) {
+ session.close();
+ session = null;
+ }
+ inboundSink.tryEmitComplete();
+ }
+ };
+ }
+ catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ @Override
+ public Mono sendMessage(McpSchema.JSONRPCMessage message) {
+ return Mono.zip(inboundReady.asMono(), outboundReady.asMono()).then(Mono.defer(() -> {
+ outboundSink.emitNext(message, (signalType, emitResult) -> {
+ // Allow retry
+ return true;
+ });
+ return Mono.empty();
+ }));
+ }
+
+ @Override
+ public T unmarshalFrom(Object data, TypeReference typeRef) {
+ return objectMapper.convertValue(data, typeRef);
+ }
+
+ @Override
+ public Mono closeGracefully() {
+ return Mono.fromRunnable(() -> {
+ isClosing.set(true);
+ logger.debug("Session transport closing gracefully");
+ inboundSink.tryEmitComplete();
+ });
+ }
+
+ @Override
+ public void close() {
+ isClosing.set(true);
+ logger.debug("Session transport closed");
+ }
+
+ private void initProcessing() {
+ handleIncomingMessages();
+ startInboundProcessing();
+ startOutboundProcessing();
+
+ inboundReady.tryEmitValue(null);
+ outboundReady.tryEmitValue(null);
+ }
+
+ private void handleIncomingMessages() {
+ this.inboundSink.asFlux().flatMap(message -> session.handle(message)).doOnTerminate(() -> {
+ this.outboundSink.tryEmitComplete();
+ }).subscribe();
+ }
+
+ /**
+ * Starts the inbound processing thread that reads JSON-RPC messages from stdin.
+ * Messages are deserialized and passed to the session for handling.
+ */
+ private void startInboundProcessing() {
+ if (isStarted.compareAndSet(false, true)) {
+ try {
+ this.serverSocketChannel.start(targetAddress, (clientChannel) -> {
+ if (logger.isDebugEnabled()) {
+ logger.debug("Accepted connect from clientChannel=" + clientChannel);
+ }
+ }, (message) -> {
+ if (logger.isDebugEnabled()) {
+ logger.debug("Received message=" + message);
+ }
+ // Incoming messages processed right here
+ McpSchema.JSONRPCMessage jsonMessage = McpSchema.deserializeJsonRpcMessage(objectMapper,
+ message);
+ if (!this.inboundSink.tryEmitNext(jsonMessage).isSuccess()) {
+ throw new IOException("Error adding jsonMessge to inboundSink");
+ }
+ });
+ }
+ catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ }
+ }
+
+ /**
+ * Starts the outbound processing thread that writes JSON-RPC messages to stdout.
+ * Messages are serialized to JSON and written with a newline delimiter.
+ */
+ private void startOutboundProcessing() {
+ Function, Flux> outboundConsumer = messages -> messages // @formatter:off
+ .doOnSubscribe(subscription -> outboundReady.tryEmitValue(null))
+ .publishOn(outboundScheduler)
+ .handle((message, sink) -> {
+ if (message != null && !isClosing.get()) {
+ try {
+ serverSocketChannel.writeMessage(objectMapper.writeValueAsString(message));
+ sink.next(message);
+ }
+ catch (IOException e) {
+ if (!isClosing.get()) {
+ logger.error("Error writing message", e);
+ sink.error(new RuntimeException(e));
+ }
+ else {
+ logger.debug("Stream closed during shutdown", e);
+ }
+ }
+ }
+ else if (isClosing.get()) {
+ sink.complete();
+ }
+ })
+ .doOnComplete(() -> {
+ isClosing.set(true);
+ outboundScheduler.dispose();
+ })
+ .doOnError(e -> {
+ if (!isClosing.get()) {
+ logger.error("Error in outbound processing", e);
+ isClosing.set(true);
+ outboundScheduler.dispose();
+ }
+ })
+ .map(msg -> (JSONRPCMessage) msg);
+
+ outboundConsumer.apply(outboundSink.asFlux()).subscribe();
+ } // @formatter:on
+
+ }
+
+}
diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/AbstractSocketChannel.java b/mcp/src/main/java/io/modelcontextprotocol/util/AbstractSocketChannel.java
new file mode 100644
index 000000000..36501502d
--- /dev/null
+++ b/mcp/src/main/java/io/modelcontextprotocol/util/AbstractSocketChannel.java
@@ -0,0 +1,355 @@
+package io.modelcontextprotocol.util;
+
+import java.io.IOException;
+import java.io.InterruptedIOException;
+import java.net.SocketAddress;
+import java.nio.ByteBuffer;
+import java.nio.channels.SelectionKey;
+import java.nio.channels.Selector;
+import java.nio.channels.ServerSocketChannel;
+import java.nio.channels.SocketChannel;
+import java.nio.charset.StandardCharsets;
+import java.util.Iterator;
+import java.util.Objects;
+import java.util.Set;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.TimeUnit;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+public abstract class AbstractSocketChannel {
+
+ private static final Logger logger = LoggerFactory.getLogger(AbstractSocketChannel.class);
+
+ public static final int DEFAULT_INBUFFER_SIZE = 1024;
+
+ public static String DEFAULT_MESSAGE_DELIMITER = "\n";
+
+ protected String messageDelimiter = DEFAULT_MESSAGE_DELIMITER;
+
+ protected void setMessageDelimiter(String delim) {
+ this.messageDelimiter = delim;
+ }
+
+ public static int DEFAULT_WRITE_TIMEOUT = 5000; // ms
+
+ protected int writeTimeout = DEFAULT_WRITE_TIMEOUT;
+
+ protected void setWriteTimeout(int timeout) {
+ this.writeTimeout = timeout;
+ }
+
+ public static int DEFAULT_CONNECT_TIMEOUT = 10000; // ms
+
+ protected int connectTimeout = DEFAULT_CONNECT_TIMEOUT;
+
+ protected void setConnectTimeout(int timeout) {
+ this.connectTimeout = timeout;
+ }
+
+ public static int DEFAULT_TERMINATION_TIMEOUT = 2000; // ms
+
+ protected int terminationTimeout = DEFAULT_TERMINATION_TIMEOUT;
+
+ protected void setTerminationTimeout(int timeout) {
+ this.terminationTimeout = timeout;
+ }
+
+ protected final Selector selector;
+
+ protected final ByteBuffer inBuffer;
+
+ protected final ExecutorService executor;
+
+ private final Object writeLock = new Object();
+
+ @FunctionalInterface
+ public interface IOConsumer {
+
+ void apply(T t) throws IOException;
+
+ }
+
+ protected class AttachedIO {
+
+ public ByteBuffer writing;
+
+ public StringBuffer reading;
+
+ }
+
+ public AbstractSocketChannel(Selector selector, int incomingBufferSize, ExecutorService executor) {
+ Assert.notNull(selector, "Selector must not be null");
+ this.selector = selector;
+ this.inBuffer = ByteBuffer.allocate(incomingBufferSize);
+ this.executor = (executor == null) ? Executors.newSingleThreadExecutor() : executor;
+ }
+
+ public AbstractSocketChannel(Selector selector, int incomingBufferSize) {
+ this(selector, incomingBufferSize, null);
+ }
+
+ public AbstractSocketChannel(Selector selector) {
+ this(selector, DEFAULT_INBUFFER_SIZE);
+ }
+
+ public AbstractSocketChannel() throws IOException {
+ this(Selector.open());
+ }
+
+ protected Runnable getRunnableForProcessing(IOConsumer acceptHandler,
+ IOConsumer connectHandler, IOConsumer readHandler) {
+ return () -> {
+ SelectionKey key = null;
+ try {
+ while (true) {
+ int count = this.selector.select();
+ debug("select returned count={}", count);
+ Set selectedKeys = selector.selectedKeys();
+ Iterator iter = selectedKeys.iterator();
+ while (iter.hasNext()) {
+ key = iter.next();
+ if (key.isConnectable()) {
+ handleConnectable(key, connectHandler);
+ }
+ else if (key.isAcceptable()) {
+ handleAcceptable(key, acceptHandler);
+ }
+ else if (key.isReadable()) {
+ handleReadable(key, readHandler);
+ }
+ else if (key.isWritable()) {
+ handleWritable(key);
+ }
+ iter.remove();
+ }
+ }
+ }
+ catch (Throwable e) {
+ handleException(key, e);
+ }
+ };
+ }
+
+ public abstract void close();
+
+ protected abstract void handleException(SelectionKey key, Throwable e);
+
+ protected void start(IOConsumer acceptHandler, IOConsumer connectHandler,
+ IOConsumer readHandler) throws IOException {
+ this.executor.execute(getRunnableForProcessing(acceptHandler, connectHandler, readHandler));
+ }
+
+ protected void debug(String format, Object... o) {
+ if (logger.isDebugEnabled()) {
+ logger.debug(format, o);
+ }
+ }
+
+ // For client subclasses
+ protected void handleConnectable(SelectionKey key, IOConsumer connectHandler) throws IOException {
+ SocketChannel client = (SocketChannel) key.channel();
+ debug("client={}", client);
+ client.configureBlocking(false);
+ client.register(this.selector, SelectionKey.OP_READ, new AttachedIO());
+ if (client.isConnectionPending()) {
+ client.finishConnect();
+ debug("connected client={}", client);
+ }
+ if (connectHandler != null) {
+ connectHandler.apply(client);
+ }
+ }
+
+ protected void handleAcceptable(SelectionKey key, IOConsumer acceptHandler) throws IOException {
+ ServerSocketChannel serverSocket = (ServerSocketChannel) key.channel();
+ SocketChannel client = serverSocket.accept();
+ debug("client={}", client);
+ client.configureBlocking(false);
+ client.register(this.selector, SelectionKey.OP_READ, new AttachedIO());
+ configureAcceptSocketChannel(client);
+ if (client.isConnectionPending()) {
+ client.finishConnect();
+ debug("accepted client={}", client);
+ }
+ if (acceptHandler != null) {
+ acceptHandler.apply(client);
+ }
+ }
+
+ protected void configureAcceptSocketChannel(SocketChannel client) throws IOException {
+ // Subclasses may override
+ }
+
+ protected AttachedIO getAttachedIO(SelectionKey key) throws IOException {
+ AttachedIO io = (AttachedIO) key.attachment();
+ if (io == null) {
+ throw new IOException("No AttachedIO object found on key");
+ }
+ return io;
+ }
+
+ protected void handleReadable(SelectionKey key, IOConsumer readHandler) throws IOException {
+ SocketChannel client = (SocketChannel) key.channel();
+ AttachedIO io = getAttachedIO(key);
+ debug("read client={}", client);
+ // read
+ int r = client.read(this.inBuffer);
+ // Check if we should expect any more reads
+ if (r == -1) {
+ throw new IOException("Channel read reached end of stream");
+ }
+ this.inBuffer.flip();
+ String partial = new String(this.inBuffer.array(), 0, r, StandardCharsets.UTF_8);
+ // If there is previous partial, get the io.reading string Buffer
+ StringBuffer sb = (io.reading != null) ? (StringBuffer) io.reading : new StringBuffer();
+ // append the just read partial to the existing or new string buffer
+ sb.append(partial);
+ if (partial.endsWith(messageDelimiter)) {
+ // Get the entire message from the string buffer
+ String message = sb.toString();
+ // Set the io.reading value to null as we are done with this message
+ io.reading = null;
+ debug("read client={} msg=", client, message);
+ if (readHandler != null) {
+ String[] messages = splitMessage(message);
+ for (int i = 0; i < messages.length; i++) {
+ readHandler.apply(messages[i]);
+ }
+ }
+ }
+ else {
+ io.reading = sb;
+ debug("read partial={}", partial);
+ }
+ // Clear inbuffer for next read
+ this.inBuffer.clear();
+ }
+
+ protected void handleWritable(SelectionKey key) throws IOException {
+ ByteBuffer buf = getAttachedIO(key).writing;
+ SocketChannel client = (SocketChannel) key.channel();
+ if (buf != null) {
+ doWrite(key, client, buf, (o) -> {
+ synchronized (writeLock) {
+ writeLock.notifyAll();
+ }
+ });
+ }
+ }
+
+ protected void doWrite(SocketChannel client, String message, IOConsumer