Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,10 @@ https://smithery.ai/server/@kenliao94/mcp-server-rabbitmq

## Roadmap
1. Expose admin API tools and pika SDK tools
1. Support Streamable HTTP when it is GA in Python SDK
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think this one leaked from the other PR. Let's remove it from this PR. Otherwise LGTM.

1. Support OAuth 2.1 and use it with RabbitMQ OAuth

✅ Support Streamable HTTP now that is GA in Python SDK

## Development

### Setup Development Environment
Expand Down
31 changes: 31 additions & 0 deletions mcp_server_rabbitmq/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,34 @@ def handle_get_exchange_info(
rabbitmq_admin: RabbitMQAdmin, exchange: str, vhost: str = "/"
) -> dict:
return rabbitmq_admin.get_exchange_info(exchange, vhost)


def handle_get_messages(
rabbitmq: RabbitMQConnection, queue: str, ack: bool = False, num_messages: int = 1
) -> list[dict]:
"""
Get up to num_messages from a queue and either ack (dequeue) or nack (requeue) each message after all are fetched.
Returns a list of dicts with 'body' and 'delivery_tag'.
Matches RabbitMQ Management UI behavior.
"""
connection, channel = rabbitmq.get_channel()
messages = []
method_frames = []
try:
for _ in range(num_messages):
method_frame, header_frame, body = channel.basic_get(queue=queue, auto_ack=False)
if method_frame is None:
break
messages.append(
{"body": body.decode() if body else "", "delivery_tag": method_frame.delivery_tag}
)
method_frames.append(method_frame)
# After fetching, ack or nack (requeue) each message as per the flag
for method_frame in method_frames:
if ack:
channel.basic_ack(delivery_tag=method_frame.delivery_tag)
else:
channel.basic_nack(delivery_tag=method_frame.delivery_tag, requeue=True)
return messages
finally:
connection.close()
45 changes: 44 additions & 1 deletion mcp_server_rabbitmq/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
handle_enqueue,
handle_fanout,
handle_get_exchange_info,
handle_get_messages,
handle_get_queue_info,
handle_list_exchanges,
handle_list_queues,
Expand Down Expand Up @@ -215,14 +216,50 @@ def get_exchange_info(exchange: str, vhost: str = "/") -> str:
self.logger.error(f"{e}")
return f"Failed to get exchange info: {e}"

@self.mcp.tool()
def get_messages(queue: str, ack: bool = False, num_messages: int = 1) -> str:
"""
Get up to num_messages from a queue and either ack (dequeue) or nack (requeue) each message.
WARNING: If ack=True, messages will be permanently removed from the queue.
Use with caution and confirm before proceeding.
Returns a list of messages and their delivery tags.
"""
validate_rabbitmq_name(queue, "Queue name")
try:
rabbitmq = RabbitMQConnection(
self.rabbitmq_host,
self.rabbitmq_port,
self.rabbitmq_username,
self.rabbitmq_password,
self.rabbitmq_use_tls,
)
messages = handle_get_messages(rabbitmq, queue, ack=ack, num_messages=num_messages)
if not messages:
return "No message available in queue"
return "\n".join(
[
f"Message {i + 1} (delivery_tag={msg['delivery_tag']}): {msg['body']}"
for i, msg in enumerate(messages)
]
)
except Exception as e:
self.logger.error(f"{e}")
return f"Failed to read message(s): {e}"

def run(self, args):
"""Run the MCP server with the provided arguments."""
self.logger.info(f"Starting RabbitMQ MCP Server v{MCP_SERVER_VERSION}")
self.logger.info(f"Connecting to RabbitMQ at {self.rabbitmq_host}:{self.rabbitmq_port}")

if args.sse:
# Set port if specified
if args.server_port:
self.mcp.settings.port = args.server_port

# Determine transport type and run
if args.sse:
self.mcp.run(transport="sse")
elif args.streamable_http:
self.mcp.run(transport="streamable-http")
else:
self.mcp.run()

Expand All @@ -243,6 +280,12 @@ def main():
"--api-port", type=int, default=15671, help="Port for the RabbitMQ management API"
)
parser.add_argument("--sse", action="store_true", help="Use SSE transport")
parser.add_argument(
"--streamable-http",
dest="streamable_http",
action="store_true",
help="Use Streamable HTTP transport",
)
parser.add_argument(
"--server-port", type=int, default=8888, help="Port to run the MCP server on"
)
Expand Down
85 changes: 85 additions & 0 deletions tests/test_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
handle_enqueue,
handle_fanout,
handle_get_exchange_info,
handle_get_messages,
handle_get_queue_info,
handle_list_exchanges,
handle_list_queues,
Expand Down Expand Up @@ -148,6 +149,90 @@ def test_handle_purge_queue(self):
# Verify the call
mock_admin.purge_queue.assert_called_once_with("test-queue", "custom-vhost")

@patch("mcp_server_rabbitmq.handlers.RabbitMQConnection")
def test_handle_get_messages(self, mock_connection_class):
"""Test that handle_get_messages retrieves messages and acks/nacks as expected for single and multiple messages."""
# Setup mocks
mock_connection = MagicMock()
mock_channel = MagicMock()
mock_connection.get_channel.return_value = (mock_connection, mock_channel)
mock_method_frame1 = MagicMock()
mock_method_frame1.delivery_tag = 123
mock_method_frame2 = MagicMock()
mock_method_frame2.delivery_tag = 456
mock_header_frame = MagicMock()
mock_body1 = b"hello"
mock_body2 = b"world"

# --- Single message, ack=True ---
mock_channel.basic_get.side_effect = [
(mock_method_frame1, mock_header_frame, mock_body1),
(None, None, None),
]
result = handle_get_messages(mock_connection, "test-queue", ack=True, num_messages=1)
assert result == [
{"body": "hello", "delivery_tag": 123},
]
mock_channel.basic_ack.assert_called_once_with(delivery_tag=123)
mock_channel.basic_nack.assert_not_called()
mock_connection.close.assert_called_once()

# --- Single message, ack=False ---
mock_channel.basic_ack.reset_mock()
mock_channel.basic_nack.reset_mock()
mock_connection.close.reset_mock()
mock_channel.basic_get.side_effect = [
(mock_method_frame1, mock_header_frame, mock_body1),
(None, None, None),
]
result = handle_get_messages(mock_connection, "test-queue", ack=False, num_messages=1)
assert result == [
{"body": "hello", "delivery_tag": 123},
]
mock_channel.basic_ack.assert_not_called()
mock_channel.basic_nack.assert_called_once_with(delivery_tag=123, requeue=True)
mock_connection.close.assert_called_once()

# --- Multiple messages, ack=True ---
mock_channel.basic_ack.reset_mock()
mock_channel.basic_nack.reset_mock()
mock_connection.close.reset_mock()
mock_channel.basic_get.side_effect = [
(mock_method_frame1, mock_header_frame, mock_body1),
(mock_method_frame2, mock_header_frame, mock_body2),
(None, None, None),
]
result = handle_get_messages(mock_connection, "test-queue", ack=True, num_messages=2)
assert result == [
{"body": "hello", "delivery_tag": 123},
{"body": "world", "delivery_tag": 456},
]
assert mock_channel.basic_ack.call_count == 2
mock_channel.basic_ack.assert_any_call(delivery_tag=123)
mock_channel.basic_ack.assert_any_call(delivery_tag=456)
mock_channel.basic_nack.assert_not_called()
mock_connection.close.assert_called_once()

# --- Multiple messages, ack=False ---
mock_channel.basic_ack.reset_mock()
mock_channel.basic_nack.reset_mock()
mock_connection.close.reset_mock()
mock_channel.basic_get.side_effect = [
(mock_method_frame1, mock_header_frame, mock_body1),
(mock_method_frame2, mock_header_frame, mock_body2),
(None, None, None),
]
result = handle_get_messages(mock_connection, "test-queue", ack=False, num_messages=2)
assert result == [
{"body": "hello", "delivery_tag": 123},
{"body": "world", "delivery_tag": 456},
]
assert mock_channel.basic_ack.call_count == 0
assert mock_channel.basic_nack.call_count == 2
mock_channel.basic_nack.assert_any_call(delivery_tag=123, requeue=True)
mock_channel.basic_nack.assert_any_call(delivery_tag=456, requeue=True)
mock_connection.close.assert_called_once()


class TestExchangeHandlers:
"""Test the exchange-related handler functions."""
Expand Down