Skip to content

Commit 305c43c

Browse files
committed
refactor connection into its own class
1 parent ac16997 commit 305c43c

File tree

3 files changed

+37
-32
lines changed

3 files changed

+37
-32
lines changed
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import pika
2+
import ssl
3+
from typing import Optional
4+
5+
class RabbitMQConnection:
6+
def __init__(self, host: str, port: int, username: str, password: str, use_tls: bool):
7+
self.protocol = "amqps" if use_tls else "amqp"
8+
self.url = f"{self.protocol}://{username}:{password}@{host}:{port}"
9+
self.parameters = pika.URLParameters(self.url)
10+
11+
if use_tls:
12+
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
13+
ssl_context.set_ciphers('ECDHE+AESGCM:!ECDSA')
14+
self.parameters.ssl_options = pika.SSLOptions(context=ssl_context)
15+
16+
def get_channel(self) -> tuple[pika.BlockingConnection, pika.channel.Channel]:
17+
connection = pika.BlockingConnection(self.parameters)
18+
channel = connection.channel()
19+
return connection, channel
20+
21+
def validate_rabbitmq_name(name: str, field_name: str) -> None:
22+
"""Validate RabbitMQ queue/exchange names"""
23+
if not name or not name.strip():
24+
raise ValueError(f"{field_name} cannot be empty")
25+
if not all(c.isalnum() or c in '-_.:' for c in name):
26+
raise ValueError(f"{field_name} can only contain letters, digits, hyphen, underscore, period, or colon")
27+
if len(name) > 255:
28+
raise ValueError(f"{field_name} must be less than 255 characters")

src/mcp_server_rabbitmq/handlers.py

Whitespace-only changes.

src/mcp_server_rabbitmq/server.py

Lines changed: 9 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import ssl
99
from .models import Enqueue, Fanout
1010
from .logger import Logger, LOG_LEVEL
11+
from .connection import RabbitMQConnection, validate_rabbitmq_name
1112

1213

1314
async def serve(rabbitmq_host: str, port: int, username: str, password: str, use_tls: bool, log_level: str = LOG_LEVEL.DEBUG.name) -> None:
@@ -23,14 +24,8 @@ async def serve(rabbitmq_host: str, port: int, username: str, password: str, use
2324
logger = Logger("server.log", log_level)
2425
if is_log_level_exception:
2526
logger.warning("Wrong log_level received. Default to WARNING")
26-
# Setup RabbitMQ connection metadata
27-
protocol = "amqps" if use_tls else "amqp"
28-
url = f"{protocol}://{username}:{password}@{rabbitmq_host}:{port}"
29-
parameters = pika.URLParameters(url)
30-
if use_tls:
31-
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
32-
ssl_context.set_ciphers('ECDHE+AESGCM:!ECDSA')
33-
parameters.ssl_options = pika.SSLOptions(context=ssl_context)
27+
# Setup RabbitMQ connection
28+
rabbitmq = RabbitMQConnection(rabbitmq_host, port, username, password, use_tls)
3429

3530
@server.list_tools()
3631
async def list_tools() -> list[Tool]:
@@ -57,22 +52,13 @@ async def call_tool(
5752
message = arguments["message"]
5853
queue = arguments["queue"]
5954

60-
if not message or not message.strip():
61-
raise ValueError("Message cannot be empty")
62-
if not queue or not queue.strip():
63-
raise ValueError("Queue name cannot be empty")
64-
# RabbitMQ queue names can only contain letters, digits, hyphen, underscore, period, or colon
65-
# and must be less than 255 characters
66-
if not all(c.isalnum() or c in '-_.:' for c in queue):
67-
raise ValueError("Queue name can only contain letters, digits, hyphen, underscore, period, or colon")
68-
if len(queue) > 255:
69-
raise ValueError("Queue name must be less than 255 characters")
55+
validate_rabbitmq_name(queue, "Queue name")
7056

7157
try:
72-
connection = pika.BlockingConnection(parameters)
73-
channel = connection.channel()
58+
connection, channel = rabbitmq.get_channel()
7459
channel.queue_declare(queue)
7560
channel.basic_publish(exchange="", routing_key=queue, body=message)
61+
connection.close()
7662
return [TextContent(type="text", text=str("suceeded"))]
7763
except Exception as e:
7864
logger.error(f"{e}")
@@ -82,22 +68,13 @@ async def call_tool(
8268
message = arguments["message"]
8369
exchange = arguments["exchange"]
8470

85-
if not message or not message.strip():
86-
raise ValueError("Message cannot be empty")
87-
if not exchange or not exchange.strip():
88-
raise ValueError("Exchange name cannot be empty")
89-
# RabbitMQ exchange names can only contain letters, digits, hyphen, underscore, period, or colon
90-
# and must be less than 255 characters
91-
if not all(c.isalnum() or c in '-_.:' for c in exchange):
92-
raise ValueError("Exchange name can only contain letters, digits, hyphen, underscore, period, or colon")
93-
if len(exchange) > 255:
94-
raise ValueError("Exchange name must be less than 255 characters")
71+
validate_rabbitmq_name(exchange, "Exchange name")
9572

9673
try:
97-
connection = pika.BlockingConnection(parameters)
98-
channel = connection.channel()
74+
connection, channel = rabbitmq.get_channel()
9975
channel.exchange_declare(exchange=exchange, exchange_type="fanout")
10076
channel.basic_publish(exchange=exchange, routing_key="", body=message)
77+
connection.close()
10178
return [TextContent(type="text", text=str("suceeded"))]
10279
except Exception as e:
10380
logger.error(f"{e}")

0 commit comments

Comments
 (0)