Skip to content

Commit 3080e2a

Browse files
authored
fix(logger): to support logger for multiple runs through port assignment (MiroMindAI#71)
* to pass lint * accomodate copilot suggests * pass lint
1 parent eca0c19 commit 3080e2a

File tree

1 file changed

+83
-8
lines changed

1 file changed

+83
-8
lines changed

src/logging/logger.py

Lines changed: 83 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,20 +16,93 @@
1616
import asyncio
1717
import threading
1818
from contextlib import contextmanager
19+
import socket
1920

2021
TASK_CONTEXT_VAR: ContextVar[str | None] = ContextVar("CURRENT_TASK_ID", default=None)
2122

23+
# Global variable to store the actual ZMQ address being used
24+
_ZMQ_ADDRESS: str = "tcp://127.0.0.1:6000"
25+
26+
27+
def find_available_port(start_port: int = 6000, max_attempts: int = 10) -> int:
28+
"""Find an available port starting from start_port."""
29+
for port in range(start_port, start_port + max_attempts):
30+
try:
31+
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
32+
s.bind(("127.0.0.1", port))
33+
return port
34+
except OSError:
35+
continue
36+
raise RuntimeError(
37+
f"Could not find an available port in range {start_port}-{start_port + max_attempts - 1}"
38+
)
39+
40+
41+
def get_zmq_address() -> str:
42+
"""Get the current ZMQ address."""
43+
return _ZMQ_ADDRESS
44+
45+
46+
def set_zmq_address(address: str) -> None:
47+
"""Set the ZMQ address."""
48+
global _ZMQ_ADDRESS
49+
_ZMQ_ADDRESS = address
50+
51+
52+
def _extract_port_from_address(addr: str) -> int:
53+
"""Extract port number from ZMQ address."""
54+
try:
55+
return int(addr.split(":")[-1])
56+
except (ValueError, IndexError):
57+
return 6000
58+
59+
60+
def _bind_zmq_socket(sock, bind_addr: str) -> str:
61+
"""Bind ZMQ socket to an available port and return the actual address."""
62+
port = _extract_port_from_address(bind_addr)
63+
64+
try:
65+
available_port = find_available_port(port)
66+
actual_addr = f"tcp://127.0.0.1:{available_port}"
67+
sock.bind(actual_addr)
68+
return actual_addr
69+
except RuntimeError:
70+
# Fallback to random port
71+
port = sock.bind_to_random_port("tcp://127.0.0.1")
72+
return f"tcp://127.0.0.1:{port}"
73+
2274

2375
class ZMQLogHandler(logging.Handler):
24-
def __init__(self, addr="tcp://127.0.0.1:6000", tool_name="unknown_tool"):
76+
def __init__(self, addr=None, tool_name="unknown_tool"):
2577
super().__init__()
2678
ctx = zmq.Context()
2779
self.sock = ctx.socket(zmq.PUSH)
28-
self.sock.connect(addr)
80+
81+
# Use the global ZMQ address if no specific address is provided
82+
if addr is None:
83+
addr = get_zmq_address()
84+
85+
# Try to connect to the address
86+
try:
87+
self.sock.connect(addr)
88+
logging.getLogger(__name__).info(f"ZMQ handler connected to: {addr}")
89+
except zmq.error.ZMQError as e:
90+
# If connection fails, disable the handler
91+
logging.getLogger(__name__).warning(
92+
f"Could not connect to ZMQ listener at {addr}: {e}"
93+
)
94+
logging.getLogger(__name__).warning(
95+
"Disabling ZMQ logging for this handler"
96+
)
97+
self.sock = None
98+
2999
self.task_id = os.environ.get("TASK_ID", "0")
30100
self.tool_name = tool_name
31101

32102
def emit(self, record):
103+
if self.sock is None:
104+
return
105+
33106
try:
34107
msg = f"{record.getMessage()}"
35108
self.sock.send_string(f"{self.task_id}||{self.tool_name}||{msg}")
@@ -40,13 +113,17 @@ def emit(self, record):
40113
async def zmq_log_listener(bind_addr="tcp://127.0.0.1:6000"):
41114
ctx = zmq.asyncio.Context()
42115
sock = ctx.socket(zmq.PULL)
43-
sock.bind(bind_addr)
116+
117+
# Bind to available port
118+
actual_addr = _bind_zmq_socket(sock, bind_addr)
119+
set_zmq_address(actual_addr)
120+
logging.getLogger(__name__).info(f"ZMQ listener bound to: {actual_addr}")
44121

45122
root_logger = logging.getLogger()
46123

47124
while True:
48125
raw = await sock.recv_string()
49-
if "|" in raw:
126+
if "||" in raw:
50127
task_id, tool_name, msg = raw.split("||", 2)
51128

52129
record = root_logger.makeRecord(
@@ -71,9 +148,7 @@ def start_zmq_listener():
71148
loop.run_until_complete(zmq_log_listener())
72149

73150

74-
def setup_mcp_logging(
75-
level="INFO", addr="tcp://127.0.0.1:6000", tool_name="unknown_tool"
76-
):
151+
def setup_mcp_logging(level="INFO", addr=None, tool_name="unknown_tool"):
77152
root = logging.getLogger()
78153
root.setLevel(level)
79154

@@ -90,7 +165,7 @@ def setup_mcp_logging(
90165
h.close()
91166
logger.propagate = True # Ensure bubbling to root
92167

93-
# Re-add the ZMQ handler
168+
# Re-add the ZMQ handler (will use global address if addr is None)
94169
handler = ZMQLogHandler(addr=addr, tool_name=tool_name)
95170
handler.setFormatter(
96171
logging.Formatter("[TOOL] %(asctime)s %(levelname)s: %(message)s")

0 commit comments

Comments
 (0)