1616import asyncio
1717import threading
1818from contextlib import contextmanager
19+ import socket
1920
2021TASK_CONTEXT_VAR : ContextVar [str | None ] = ContextVar ("CURRENT_TASK_ID" , default = None )
2122
23+ # Global variable to store the actual ZMQ address being used
24+ _ZMQ_ADDRESS : str = "tcp://127.0.0.1:6000"
25+
26+
27+ def find_available_port (start_port : int = 6000 , max_attempts : int = 10 ) -> int :
28+ """Find an available port starting from start_port."""
29+ for port in range (start_port , start_port + max_attempts ):
30+ try :
31+ with socket .socket (socket .AF_INET , socket .SOCK_STREAM ) as s :
32+ s .bind (("127.0.0.1" , port ))
33+ return port
34+ except OSError :
35+ continue
36+ raise RuntimeError (
37+ f"Could not find an available port in range { start_port } -{ start_port + max_attempts - 1 } "
38+ )
39+
40+
41+ def get_zmq_address () -> str :
42+ """Get the current ZMQ address."""
43+ return _ZMQ_ADDRESS
44+
45+
46+ def set_zmq_address (address : str ) -> None :
47+ """Set the ZMQ address."""
48+ global _ZMQ_ADDRESS
49+ _ZMQ_ADDRESS = address
50+
51+
52+ def _extract_port_from_address (addr : str ) -> int :
53+ """Extract port number from ZMQ address."""
54+ try :
55+ return int (addr .split (":" )[- 1 ])
56+ except (ValueError , IndexError ):
57+ return 6000
58+
59+
60+ def _bind_zmq_socket (sock , bind_addr : str ) -> str :
61+ """Bind ZMQ socket to an available port and return the actual address."""
62+ port = _extract_port_from_address (bind_addr )
63+
64+ try :
65+ available_port = find_available_port (port )
66+ actual_addr = f"tcp://127.0.0.1:{ available_port } "
67+ sock .bind (actual_addr )
68+ return actual_addr
69+ except RuntimeError :
70+ # Fallback to random port
71+ port = sock .bind_to_random_port ("tcp://127.0.0.1" )
72+ return f"tcp://127.0.0.1:{ port } "
73+
2274
2375class ZMQLogHandler (logging .Handler ):
24- def __init__ (self , addr = "tcp://127.0.0.1:6000" , tool_name = "unknown_tool" ):
76+ def __init__ (self , addr = None , tool_name = "unknown_tool" ):
2577 super ().__init__ ()
2678 ctx = zmq .Context ()
2779 self .sock = ctx .socket (zmq .PUSH )
28- self .sock .connect (addr )
80+
81+ # Use the global ZMQ address if no specific address is provided
82+ if addr is None :
83+ addr = get_zmq_address ()
84+
85+ # Try to connect to the address
86+ try :
87+ self .sock .connect (addr )
88+ logging .getLogger (__name__ ).info (f"ZMQ handler connected to: { addr } " )
89+ except zmq .error .ZMQError as e :
90+ # If connection fails, disable the handler
91+ logging .getLogger (__name__ ).warning (
92+ f"Could not connect to ZMQ listener at { addr } : { e } "
93+ )
94+ logging .getLogger (__name__ ).warning (
95+ "Disabling ZMQ logging for this handler"
96+ )
97+ self .sock = None
98+
2999 self .task_id = os .environ .get ("TASK_ID" , "0" )
30100 self .tool_name = tool_name
31101
32102 def emit (self , record ):
103+ if self .sock is None :
104+ return
105+
33106 try :
34107 msg = f"{ record .getMessage ()} "
35108 self .sock .send_string (f"{ self .task_id } ||{ self .tool_name } ||{ msg } " )
@@ -40,13 +113,17 @@ def emit(self, record):
40113async def zmq_log_listener (bind_addr = "tcp://127.0.0.1:6000" ):
41114 ctx = zmq .asyncio .Context ()
42115 sock = ctx .socket (zmq .PULL )
43- sock .bind (bind_addr )
116+
117+ # Bind to available port
118+ actual_addr = _bind_zmq_socket (sock , bind_addr )
119+ set_zmq_address (actual_addr )
120+ logging .getLogger (__name__ ).info (f"ZMQ listener bound to: { actual_addr } " )
44121
45122 root_logger = logging .getLogger ()
46123
47124 while True :
48125 raw = await sock .recv_string ()
49- if "|" in raw :
126+ if "|| " in raw :
50127 task_id , tool_name , msg = raw .split ("||" , 2 )
51128
52129 record = root_logger .makeRecord (
@@ -71,9 +148,7 @@ def start_zmq_listener():
71148 loop .run_until_complete (zmq_log_listener ())
72149
73150
74- def setup_mcp_logging (
75- level = "INFO" , addr = "tcp://127.0.0.1:6000" , tool_name = "unknown_tool"
76- ):
151+ def setup_mcp_logging (level = "INFO" , addr = None , tool_name = "unknown_tool" ):
77152 root = logging .getLogger ()
78153 root .setLevel (level )
79154
@@ -90,7 +165,7 @@ def setup_mcp_logging(
90165 h .close ()
91166 logger .propagate = True # Ensure bubbling to root
92167
93- # Re-add the ZMQ handler
168+ # Re-add the ZMQ handler (will use global address if addr is None)
94169 handler = ZMQLogHandler (addr = addr , tool_name = tool_name )
95170 handler .setFormatter (
96171 logging .Formatter ("[TOOL] %(asctime)s %(levelname)s: %(message)s" )
0 commit comments