From ab92e4227b6ef19bb41475f573650d17d47b1220 Mon Sep 17 00:00:00 2001 From: ivanfioravanti Date: Sun, 1 Jun 2025 18:19:55 +0200 Subject: [PATCH 1/2] feat: add support for Streamable HTTP transport and update README --- README.md | 3 ++- mcp_server_rabbitmq/server.py | 14 +++++++++++++- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index e819484..a1e8e74 100644 --- a/README.md +++ b/README.md @@ -52,9 +52,10 @@ https://smithery.ai/server/@kenliao94/mcp-server-rabbitmq ## Roadmap 1. Expose admin API tools and pika SDK tools -1. Support Streamable HTTP when it is GA in Python SDK 1. Support OAuth 2.1 and use it with RabbitMQ OAuth +✅ Support Streamable HTTP now that is GA in Python SDK + ## Development ### Setup Development Environment diff --git a/mcp_server_rabbitmq/server.py b/mcp_server_rabbitmq/server.py index 245d30a..e10df49 100644 --- a/mcp_server_rabbitmq/server.py +++ b/mcp_server_rabbitmq/server.py @@ -220,9 +220,15 @@ def run(self, args): self.logger.info(f"Starting RabbitMQ MCP Server v{MCP_SERVER_VERSION}") self.logger.info(f"Connecting to RabbitMQ at {self.rabbitmq_host}:{self.rabbitmq_port}") - if args.sse: + # Set port if specified + if args.server_port: self.mcp.settings.port = args.server_port + + # Determine transport type and run + if args.sse: self.mcp.run(transport="sse") + elif args.streamable_http: + self.mcp.run(transport="streamable-http") else: self.mcp.run() @@ -243,6 +249,12 @@ def main(): "--api-port", type=int, default=15671, help="Port for the RabbitMQ management API" ) parser.add_argument("--sse", action="store_true", help="Use SSE transport") + parser.add_argument( + "--streamable-http", + dest="streamable_http", + action="store_true", + help="Use Streamable HTTP transport", + ) parser.add_argument( "--server-port", type=int, default=8888, help="Port to run the MCP server on" ) From 1a29b488dceba8cc550a390e448b5af182416e97 Mon Sep 17 00:00:00 2001 From: ivanfioravanti Date: Sun, 1 Jun 2025 19:21:08 +0200 Subject: [PATCH 2/2] feat: add handle_get_messages function to retrieve and acknowledge/nack messages from RabbitMQ queue --- mcp_server_rabbitmq/handlers.py | 31 ++++++++++++ mcp_server_rabbitmq/server.py | 31 ++++++++++++ tests/test_handlers.py | 85 +++++++++++++++++++++++++++++++++ 3 files changed, 147 insertions(+) diff --git a/mcp_server_rabbitmq/handlers.py b/mcp_server_rabbitmq/handlers.py index 134f780..d0038bb 100644 --- a/mcp_server_rabbitmq/handlers.py +++ b/mcp_server_rabbitmq/handlers.py @@ -48,3 +48,34 @@ def handle_get_exchange_info( rabbitmq_admin: RabbitMQAdmin, exchange: str, vhost: str = "/" ) -> dict: return rabbitmq_admin.get_exchange_info(exchange, vhost) + + +def handle_get_messages( + rabbitmq: RabbitMQConnection, queue: str, ack: bool = False, num_messages: int = 1 +) -> list[dict]: + """ + Get up to num_messages from a queue and either ack (dequeue) or nack (requeue) each message after all are fetched. + Returns a list of dicts with 'body' and 'delivery_tag'. + Matches RabbitMQ Management UI behavior. + """ + connection, channel = rabbitmq.get_channel() + messages = [] + method_frames = [] + try: + for _ in range(num_messages): + method_frame, header_frame, body = channel.basic_get(queue=queue, auto_ack=False) + if method_frame is None: + break + messages.append( + {"body": body.decode() if body else "", "delivery_tag": method_frame.delivery_tag} + ) + method_frames.append(method_frame) + # After fetching, ack or nack (requeue) each message as per the flag + for method_frame in method_frames: + if ack: + channel.basic_ack(delivery_tag=method_frame.delivery_tag) + else: + channel.basic_nack(delivery_tag=method_frame.delivery_tag, requeue=True) + return messages + finally: + connection.close() diff --git a/mcp_server_rabbitmq/server.py b/mcp_server_rabbitmq/server.py index e10df49..2f7948c 100644 --- a/mcp_server_rabbitmq/server.py +++ b/mcp_server_rabbitmq/server.py @@ -14,6 +14,7 @@ handle_enqueue, handle_fanout, handle_get_exchange_info, + handle_get_messages, handle_get_queue_info, handle_list_exchanges, handle_list_queues, @@ -215,6 +216,36 @@ def get_exchange_info(exchange: str, vhost: str = "/") -> str: self.logger.error(f"{e}") return f"Failed to get exchange info: {e}" + @self.mcp.tool() + def get_messages(queue: str, ack: bool = False, num_messages: int = 1) -> str: + """ + Get up to num_messages from a queue and either ack (dequeue) or nack (requeue) each message. + WARNING: If ack=True, messages will be permanently removed from the queue. + Use with caution and confirm before proceeding. + Returns a list of messages and their delivery tags. + """ + validate_rabbitmq_name(queue, "Queue name") + try: + rabbitmq = RabbitMQConnection( + self.rabbitmq_host, + self.rabbitmq_port, + self.rabbitmq_username, + self.rabbitmq_password, + self.rabbitmq_use_tls, + ) + messages = handle_get_messages(rabbitmq, queue, ack=ack, num_messages=num_messages) + if not messages: + return "No message available in queue" + return "\n".join( + [ + f"Message {i + 1} (delivery_tag={msg['delivery_tag']}): {msg['body']}" + for i, msg in enumerate(messages) + ] + ) + except Exception as e: + self.logger.error(f"{e}") + return f"Failed to read message(s): {e}" + def run(self, args): """Run the MCP server with the provided arguments.""" self.logger.info(f"Starting RabbitMQ MCP Server v{MCP_SERVER_VERSION}") diff --git a/tests/test_handlers.py b/tests/test_handlers.py index 6591fcc..84aff4a 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -8,6 +8,7 @@ handle_enqueue, handle_fanout, handle_get_exchange_info, + handle_get_messages, handle_get_queue_info, handle_list_exchanges, handle_list_queues, @@ -148,6 +149,90 @@ def test_handle_purge_queue(self): # Verify the call mock_admin.purge_queue.assert_called_once_with("test-queue", "custom-vhost") + @patch("mcp_server_rabbitmq.handlers.RabbitMQConnection") + def test_handle_get_messages(self, mock_connection_class): + """Test that handle_get_messages retrieves messages and acks/nacks as expected for single and multiple messages.""" + # Setup mocks + mock_connection = MagicMock() + mock_channel = MagicMock() + mock_connection.get_channel.return_value = (mock_connection, mock_channel) + mock_method_frame1 = MagicMock() + mock_method_frame1.delivery_tag = 123 + mock_method_frame2 = MagicMock() + mock_method_frame2.delivery_tag = 456 + mock_header_frame = MagicMock() + mock_body1 = b"hello" + mock_body2 = b"world" + + # --- Single message, ack=True --- + mock_channel.basic_get.side_effect = [ + (mock_method_frame1, mock_header_frame, mock_body1), + (None, None, None), + ] + result = handle_get_messages(mock_connection, "test-queue", ack=True, num_messages=1) + assert result == [ + {"body": "hello", "delivery_tag": 123}, + ] + mock_channel.basic_ack.assert_called_once_with(delivery_tag=123) + mock_channel.basic_nack.assert_not_called() + mock_connection.close.assert_called_once() + + # --- Single message, ack=False --- + mock_channel.basic_ack.reset_mock() + mock_channel.basic_nack.reset_mock() + mock_connection.close.reset_mock() + mock_channel.basic_get.side_effect = [ + (mock_method_frame1, mock_header_frame, mock_body1), + (None, None, None), + ] + result = handle_get_messages(mock_connection, "test-queue", ack=False, num_messages=1) + assert result == [ + {"body": "hello", "delivery_tag": 123}, + ] + mock_channel.basic_ack.assert_not_called() + mock_channel.basic_nack.assert_called_once_with(delivery_tag=123, requeue=True) + mock_connection.close.assert_called_once() + + # --- Multiple messages, ack=True --- + mock_channel.basic_ack.reset_mock() + mock_channel.basic_nack.reset_mock() + mock_connection.close.reset_mock() + mock_channel.basic_get.side_effect = [ + (mock_method_frame1, mock_header_frame, mock_body1), + (mock_method_frame2, mock_header_frame, mock_body2), + (None, None, None), + ] + result = handle_get_messages(mock_connection, "test-queue", ack=True, num_messages=2) + assert result == [ + {"body": "hello", "delivery_tag": 123}, + {"body": "world", "delivery_tag": 456}, + ] + assert mock_channel.basic_ack.call_count == 2 + mock_channel.basic_ack.assert_any_call(delivery_tag=123) + mock_channel.basic_ack.assert_any_call(delivery_tag=456) + mock_channel.basic_nack.assert_not_called() + mock_connection.close.assert_called_once() + + # --- Multiple messages, ack=False --- + mock_channel.basic_ack.reset_mock() + mock_channel.basic_nack.reset_mock() + mock_connection.close.reset_mock() + mock_channel.basic_get.side_effect = [ + (mock_method_frame1, mock_header_frame, mock_body1), + (mock_method_frame2, mock_header_frame, mock_body2), + (None, None, None), + ] + result = handle_get_messages(mock_connection, "test-queue", ack=False, num_messages=2) + assert result == [ + {"body": "hello", "delivery_tag": 123}, + {"body": "world", "delivery_tag": 456}, + ] + assert mock_channel.basic_ack.call_count == 0 + assert mock_channel.basic_nack.call_count == 2 + mock_channel.basic_nack.assert_any_call(delivery_tag=123, requeue=True) + mock_channel.basic_nack.assert_any_call(delivery_tag=456, requeue=True) + mock_connection.close.assert_called_once() + class TestExchangeHandlers: """Test the exchange-related handler functions."""