Skip to content

Commit 648f6cf

Browse files
zmyzxbBinWang28Copilot
authored
feat(logging): support tool logs and per-task log storage (#69)
* feat(logging): support tool logs and per-task log storage * Update src/logging/logger.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update src/logging/logger.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update src/logging/logger.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --------- Co-authored-by: Bin Wang <bwang28c@gmail.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent e1dc1be commit 648f6cf

File tree

12 files changed

+262
-37
lines changed

12 files changed

+262
-37
lines changed

common_benchmark.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,13 @@
1919
from omegaconf import DictConfig, OmegaConf
2020

2121
from utils.eval_utils import verify_answer_for_datasets
22-
from src.logging.logger import bootstrap_logger
22+
from src.logging.logger import bootstrap_logger, task_logging_context, init_logging_for_benchmark_evaluation
2323
from config import config_name, config_path
2424
from src.core.pipeline import (
2525
create_pipeline_components,
2626
execute_task_pipeline,
2727
)
28-
28+
init_logging_for_benchmark_evaluation(print_task_logs=False)
2929

3030
class TaskStatus(StrEnum):
3131
PENDING = "pending"
@@ -179,7 +179,8 @@ async def run_single_task(self, task: BenchmarkTask) -> BenchmarkResult:
179179
found_correct_answer = False
180180

181181
# Print debug info about log directory
182-
print(f" Current log directory: {self.output_dir}")
182+
print(f" Current result directory: {self.output_dir}")
183+
print(f" Current task log directory: {self.output_dir}/task_logs")
183184

184185
try:
185186
# Prepare task
@@ -371,7 +372,9 @@ async def run_parallel_inference(
371372

372373
async def run_with_semaphore(task):
373374
async with semaphore:
374-
return await self.run_single_task(task)
375+
with task_logging_context(task.task_id, self.get_log_dir()):
376+
result = await self.run_single_task(task)
377+
return result
375378

376379
# Shuffle tasks to avoid order bias and improve balancing
377380
shuffled_tasks = tasks.copy()

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ dependencies = [
3030
"requests>=2.32.4",
3131
"filetype>=1.2.0",
3232
"Pillow",
33+
"pyzmq"
3334
]
3435

3536
[build-system]

src/logging/logger.py

Lines changed: 178 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2,44 +2,201 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5+
import os
6+
import zmq
7+
import zmq.asyncio
58
import logging
69
from functools import lru_cache
7-
from typing import Literal
8-
10+
from pathlib import Path
11+
from typing import Literal, Dict
12+
from contextvars import ContextVar
913
import hydra
1014
from rich.console import Console
1115
from rich.logging import RichHandler
16+
import asyncio
17+
import threading
18+
from contextlib import contextmanager
19+
TASK_CONTEXT_VAR: ContextVar[str | None] = ContextVar("CURRENT_TASK_ID", default=None)
20+
21+
class ZMQLogHandler(logging.Handler):
22+
def __init__(self, addr="tcp://127.0.0.1:6000", tool_name="unknown_tool"):
23+
super().__init__()
24+
ctx = zmq.Context()
25+
self.sock = ctx.socket(zmq.PUSH)
26+
self.sock.connect(addr)
27+
self.task_id = os.environ.get("TASK_ID", "0")
28+
self.tool_name = tool_name
29+
30+
def emit(self, record):
31+
try:
32+
msg = f"{record.getMessage()}"
33+
self.sock.send_string(f"{self.task_id}||{self.tool_name}||{msg}")
34+
except Exception:
35+
self.handleError(record)
36+
37+
async def zmq_log_listener(bind_addr="tcp://127.0.0.1:6000"):
38+
ctx = zmq.asyncio.Context()
39+
sock = ctx.socket(zmq.PULL)
40+
sock.bind(bind_addr)
41+
42+
root_logger = logging.getLogger()
43+
44+
while True:
45+
raw = await sock.recv_string()
46+
if "|" in raw:
47+
task_id, tool_name, msg = raw.split("||", 2)
48+
49+
record = root_logger.makeRecord(
50+
name=f'[TOOL] {tool_name}',
51+
level=logging.INFO,
52+
fn="", lno=0, msg=msg, args=(),
53+
exc_info=None
54+
)
55+
record.task_id = task_id
56+
57+
root_logger.handle(record)
58+
else:
59+
root_logger.info(raw)
60+
61+
def start_zmq_listener():
62+
loop = asyncio.new_event_loop()
63+
asyncio.set_event_loop(loop)
64+
loop.run_until_complete(zmq_log_listener())
65+
66+
def setup_mcp_logging(level="INFO", addr="tcp://127.0.0.1:6000", tool_name="unknown_tool"):
67+
root = logging.getLogger()
68+
root.setLevel(level)
69+
70+
# Remove root handlers
71+
for h in root.handlers[:]:
72+
root.removeHandler(h)
73+
h.close()
1274

75+
# Remove all handlers from fastmcp child loggers
76+
for name, logger in logging.Logger.manager.loggerDict.items():
77+
if isinstance(logger, logging.Logger):
78+
for h in logger.handlers[:]:
79+
logger.removeHandler(h)
80+
h.close()
81+
logger.propagate = True # 确保冒泡到 root
82+
83+
# Re-add the ZMQ handler
84+
handler = ZMQLogHandler(addr=addr, tool_name=tool_name)
85+
handler.setFormatter(logging.Formatter("[TOOL] %(asctime)s %(levelname)s: %(message)s"))
86+
root.addHandler(handler)
87+
88+
def setup_log_record_factory():
89+
old_factory = logging.getLogRecordFactory()
90+
def record_factory(*args, **kwargs):
91+
record = old_factory(*args, **kwargs)
92+
record.task_id = TASK_CONTEXT_VAR.get()
93+
return record
94+
logging.setLogRecordFactory(record_factory)
95+
96+
class TaskFilter(logging.Filter):
97+
def __init__(self, task_id: str):
98+
super().__init__()
99+
self.task_id = task_id
100+
101+
def filter(self, record: logging.LogRecord) -> bool:
102+
return getattr(record, "task_id", None) == self.task_id
103+
104+
def make_task_logger(task_id: str, log_dir: Path) -> logging.Handler:
105+
log_dir.mkdir(parents=True, exist_ok=True)
106+
file_path = log_dir / f"task_{task_id}.log"
107+
fh = logging.FileHandler(file_path, encoding="utf-8")
108+
fmt = logging.Formatter("%(asctime)s [%(levelname)s] %(name)s: %(message)s")
109+
fh.setFormatter(fmt)
110+
fh.addFilter(TaskFilter(task_id))
111+
logging.getLogger().addHandler(fh)
112+
return fh
113+
114+
def remove_all_console_handlers():
115+
"""
116+
移除当前进程中所有 logger 上的 console handler (StreamHandler/RichHandler)。
117+
"""
118+
for name, logger in logging.Logger.manager.loggerDict.items():
119+
if isinstance(logger, logging.Logger):
120+
handlers_to_remove = []
121+
for h in logger.handlers:
122+
if isinstance(h, logging.StreamHandler) or isinstance(h, RichHandler):
123+
handlers_to_remove.append(h)
124+
for h in handlers_to_remove:
125+
logger.removeHandler(h)
126+
h.close()
127+
128+
root_logger = logging.getLogger()
129+
handlers_to_remove = []
130+
for h in root_logger.handlers:
131+
if isinstance(h, logging.StreamHandler):
132+
handlers_to_remove.append(h)
133+
for h in handlers_to_remove:
134+
root_logger.removeHandler(h)
135+
h.close()
136+
137+
@contextmanager
138+
def task_logging_context(task_id: str, log_dir: Path):
139+
token = TASK_CONTEXT_VAR.set(task_id)
140+
handler = make_task_logger(task_id, log_dir / "task_logs")
141+
try:
142+
yield
143+
finally:
144+
TASK_CONTEXT_VAR.reset(token)
145+
logging.getLogger().removeHandler(handler)
146+
handler.close()
147+
148+
def init_logging_for_benchmark_evaluation(print_task_logs=False):
149+
threading.Thread(target=start_zmq_listener, daemon=True).start() #monitoring tool logs
150+
logging.basicConfig(handlers=[])
151+
setup_log_record_factory()
152+
if not print_task_logs:
153+
remove_all_console_handlers()
13154

14155
@lru_cache
15156
def bootstrap_logger(
16157
level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] | int = "INFO",
158+
logger_name: str = "miroflow",
17159
logger: logging.Logger | None = None,
160+
log_dir: str | Path | None = None, # 日志存储目录
161+
log_filename: str = "miroflow.log", # 默认日志文件名
162+
to_console: bool = True, # 是否显示到 console
18163
) -> logging.Logger:
19164
"""Configure only this logger, not the root logger"""
20165
if logger is None:
21-
logger = logging.getLogger("miroflow")
166+
logger = logging.getLogger(logger_name)
167+
22168
for handler in logger.handlers[:]:
23169
logger.removeHandler(handler)
24170

25-
# use rich for better readability of stack trace.
26-
handler = RichHandler(
27-
console=Console(
28-
stderr=True,
29-
width=200,
30-
color_system=None, # Disable colors to avoid ANSI escape sequences in log files
31-
force_terminal=False, # Don't force terminal mode
32-
legacy_windows=False,
33-
),
34-
rich_tracebacks=True,
35-
tracebacks_suppress=[hydra],
36-
tracebacks_show_locals=True,
37-
show_level=False,
38-
)
39-
formatter = logging.Formatter("[%(levelname)s] %(message)s")
40-
handler.setFormatter(formatter)
41-
logger.addHandler(handler)
171+
if to_console:
172+
handler = RichHandler(
173+
console=Console(
174+
stderr=True,
175+
width=200,
176+
color_system=None,
177+
force_terminal=False,
178+
legacy_windows=False,
179+
),
180+
rich_tracebacks=True,
181+
tracebacks_suppress=[hydra],
182+
tracebacks_show_locals=True,
183+
show_level=False,
184+
)
185+
formatter = logging.Formatter("[%(levelname)s] %(message)s")
186+
handler.setFormatter(formatter)
187+
logger.addHandler(handler)
188+
189+
if log_dir is not None:
190+
log_dir = Path(log_dir)
191+
log_dir.mkdir(parents=True, exist_ok=True)
192+
file_path = log_dir / log_filename
193+
file_handler = logging.FileHandler(file_path, encoding="utf-8")
194+
file_handler.setFormatter(logging.Formatter(
195+
"%(asctime)s [%(levelname)s] %(name)s: %(message)s"
196+
))
197+
logger.addHandler(file_handler)
198+
42199
logger.setLevel(level)
43-
logger.propagate = False
200+
logger.propagate = True
44201

45202
return logger

src/tool/manager.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,14 @@
2020

2121
R = TypeVar("R")
2222

23+
def update_server_params_with_context_var(server_params: StdioServerParameters) -> StdioServerParameters:
24+
"""
25+
Update the server params with the context var.
26+
"""
27+
from src.logging.logger import TASK_CONTEXT_VAR
28+
if TASK_CONTEXT_VAR.get() is not None:
29+
server_params.env["TASK_ID"] = TASK_CONTEXT_VAR.get()
30+
return server_params
2331

2432
def with_timeout(timeout_s: float = 300.0):
2533
"""
@@ -108,7 +116,7 @@ async def _find_servers_with_tool(self, tool_name):
108116

109117
try:
110118
if isinstance(server_params, StdioServerParameters):
111-
async with stdio_client(server_params) as (read, write):
119+
async with stdio_client(update_server_params_with_context_var(server_params)) as (read, write):
112120
async with ClientSession(
113121
read, write, sampling_callback=None
114122
) as session:
@@ -168,7 +176,7 @@ async def get_all_tool_definitions(self):
168176

169177
try:
170178
if isinstance(server_params, StdioServerParameters):
171-
async with stdio_client(server_params) as (read, write):
179+
async with stdio_client(update_server_params_with_context_var(server_params)) as (read, write):
172180
async with ClientSession(
173181
read, write, sampling_callback=None
174182
) as session:
@@ -342,7 +350,7 @@ async def execute_tool_call(self, server_name, tool_name, arguments) -> Any:
342350
try:
343351
result_content = None
344352
if isinstance(server_params, StdioServerParameters):
345-
async with stdio_client(server_params) as (read, write):
353+
async with stdio_client(update_server_params_with_context_var(server_params)) as (read, write):
346354
async with ClientSession(
347355
read, write, sampling_callback=None
348356
) as session:

src/tool/mcp_servers/audio_mcp_server.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,10 @@
2525
)
2626

2727
# Initialize FastMCP server
28+
from src.logging.logger import setup_mcp_logging
29+
setup_mcp_logging(tool_name=os.path.basename(__file__))
2830
mcp = FastMCP("audio-mcp-server")
2931

30-
3132
def _get_audio_extension(url: str, content_type: str = None) -> str:
3233
"""
3334
Determine the appropriate audio file extension from URL or content type.
@@ -289,4 +290,4 @@ async def audio_question_answering(audio_path_or_url: str, question: str) -> str
289290

290291

291292
if __name__ == "__main__":
292-
mcp.run(transport="stdio")
293+
mcp.run(transport="stdio",show_banner=False)

src/tool/mcp_servers/browser_session.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,4 +97,4 @@ async def test_persistent_session():
9797

9898

9999
if __name__ == "__main__":
100-
asyncio.run(test_persistent_session())
100+
asyncio.run(test_persistent_session(),show_banner=False)

src/tool/mcp_servers/python_server.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
from fastmcp import FastMCP
1010

1111
# Initialize FastMCP server
12+
from src.logging.logger import setup_mcp_logging
13+
setup_mcp_logging(tool_name=os.path.basename(__file__))
1214
mcp = FastMCP("e2b-python-interpreter")
1315

1416
# API keys
@@ -411,4 +413,4 @@ async def download_file_from_sandbox_to_local(
411413

412414

413415
if __name__ == "__main__":
414-
mcp.run(transport="stdio")
416+
mcp.run(transport="stdio",show_banner=False)

src/tool/mcp_servers/reading_mcp_server.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
from .utils.smart_request import smart_request
1616

1717
# Initialize FastMCP server
18+
from src.logging.logger import setup_mcp_logging
19+
setup_mcp_logging(tool_name=os.path.basename(__file__))
1820
mcp = FastMCP("reading-mcp-server")
1921
SERPER_API_KEY = os.environ.get("SERPER_API_KEY", "")
2022
JINA_API_KEY = os.environ.get("JINA_API_KEY", "")
@@ -153,7 +155,7 @@ def _cleanup_tempfile(path):
153155

154156
# Run the server with the specified transport method
155157
if args.transport == "stdio":
156-
mcp.run(transport="stdio")
158+
mcp.run(transport="stdio",show_banner=False)
157159
else:
158160
# For HTTP transport, include port and path options
159-
mcp.run(transport="streamable-http", port=args.port, path=args.path)
161+
mcp.run(transport="streamable-http", port=args.port, path=args.path,show_banner=False)

src/tool/mcp_servers/reasoning_mcp_server.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
OPENAI_MODEL_NAME = os.environ.get("OPENAI_MODEL_NAME", "o3")
2020

2121
# Initialize FastMCP server
22+
from src.logging.logger import setup_mcp_logging
23+
setup_mcp_logging(tool_name=os.path.basename(__file__))
2224
mcp = FastMCP("reasoning-mcp-server")
2325

2426

@@ -124,4 +126,4 @@ async def reasoning(question: str) -> str:
124126

125127

126128
if __name__ == "__main__":
127-
mcp.run(transport="stdio")
129+
mcp.run(transport="stdio",show_banner=False)

src/tool/mcp_servers/searching_mcp_server.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@
3131
)
3232

3333
# Initialize FastMCP server
34+
from src.logging.logger import setup_mcp_logging
35+
setup_mcp_logging(tool_name=os.path.basename(__file__))
3436
mcp = FastMCP("searching-mcp-server")
3537

3638

@@ -666,4 +668,4 @@ async def scrape_website(url: str) -> str:
666668

667669

668670
if __name__ == "__main__":
669-
mcp.run(transport="stdio")
671+
mcp.run(transport="stdio",show_banner=False)

0 commit comments

Comments
 (0)