Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 38 additions & 2 deletions framework/py/flwr/cli/log.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,16 @@
"""Flower command line interface `log` command."""


import re
import time
from logging import DEBUG, ERROR, INFO
from typing import Annotated, cast

import click
import grpc
import typer
from rich.console import Console
from rich.text import Text

from flwr.cli.config_migration import migrate, warn_if_federation_config_overrides
from flwr.cli.constant import FEDERATION_CONFIG_HELP_MESSAGE
Expand All @@ -34,6 +37,16 @@

from .utils import flwr_cli_grpc_exc_handler, init_channel_from_connection

CONSOLE = Console(highlight=False, markup=False)
LOG_STYLES = {
"DEBUG": "blue",
"INFO": "green",
"WARNING": "yellow",
"ERROR": "red",
"CRITICAL": "magenta",
}
LOG_LEVEL_RE = re.compile(r"^(DEBUG|INFO|WARNING|ERROR|CRITICAL)(?=[:\s])")


class AllLogsRetrieved(BaseException):
"""Exception raised when all available logs have been retrieved.
Expand All @@ -43,6 +56,29 @@ class AllLogsRetrieved(BaseException):
"""


def _print_log_output(log_output: str, end: str = "\n") -> None:
"""Render streamed log output, including ANSI colors."""
CONSOLE.print(_render_log_output(log_output), end=end)


def _render_log_output(log_output: str) -> Text:
"""Render streamed output from ANSI or plain log lines."""
if "\x1b[" in log_output:
return Text.from_ansi(log_output)

text = Text()
for line in log_output.splitlines(keepends=True):
level_match = LOG_LEVEL_RE.match(line)
head_end = line.find(":")
if level_match and head_end >= level_match.end():
level = level_match.group(1)
text.append(line[:head_end], style=LOG_STYLES[level])
text.append(line[head_end:])
else:
text.append(line)
return text


def start_stream(
run_id: int, channel: grpc.Channel, refresh_period: int = CONN_REFRESH_PERIOD
) -> None:
Expand Down Expand Up @@ -108,7 +144,7 @@ def stream_logs(
try:
with flwr_cli_grpc_exc_handler():
for res in stub.StreamLogs(req, timeout=duration):
print(res.log_output, end="")
_print_log_output(res.log_output, end="")
raise AllLogsRetrieved()
except grpc.RpcError as e:
# pylint: disable=E1101
Expand Down Expand Up @@ -140,7 +176,7 @@ def print_logs(run_id: int, channel: grpc.Channel, timeout: int) -> None:
with flwr_cli_grpc_exc_handler():
# Enforce timeout for graceful exit
for res in stub.StreamLogs(req, timeout=timeout):
print(res.log_output)
_print_log_output(res.log_output)
break
except grpc.RpcError as e:
if e.code() == grpc.StatusCode.NOT_FOUND: # pylint: disable=E1101
Expand Down
36 changes: 29 additions & 7 deletions framework/py/flwr/cli/log_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@
from unittest.mock import Mock, call, patch

from flwr.proto.control_pb2 import StreamLogsResponse # pylint: disable=E0611
from rich.text import Text

from .log import print_logs, stream_logs
from .log import _render_log_output, print_logs, stream_logs


class InterruptedStreamLogsResponse:
Expand Down Expand Up @@ -70,17 +71,38 @@ def tearDown(self) -> None:

def test_flwr_log_stream_method(self) -> None:
"""Test stream_logs."""
with patch("builtins.print") as mock_print:
with patch("flwr.cli.log._print_log_output") as mock_print:
with self.assertRaises(KeyboardInterrupt):
stream_logs(
run_id=123, stub=self.mock_stub, duration=1, after_timestamp=0.0
)
# Assert that mock print was called with the expected arguments
mock_print.assert_has_calls(self.expected_stream_call)
# Assert that log chunks were printed as a stream
mock_print.assert_has_calls(
[
call("log_output_1", end=""),
call("log_output_2", end=""),
call("log_output_3", end=""),
]
)

def test_flwr_log_print_method(self) -> None:
"""Test print_logs."""
with patch("builtins.print") as mock_print:
with patch("flwr.cli.log._print_log_output") as mock_print:
print_logs(run_id=123, channel=self.mock_channel, timeout=0)
# Assert that mock print was called with the expected arguments
mock_print.assert_has_calls(self.expected_print_call)
# Assert that only the first log chunk was printed in show mode
mock_print.assert_has_calls([call("log_output_1")])

def test_render_log_output_styles_plain_header(self) -> None:
"""Test coloring plain log headers for streamed logs."""
text = _render_log_output("INFO: hello\nplain line\n")
assert isinstance(text, Text)
assert text.plain == "INFO: hello\nplain line\n"
assert text.spans
assert all(span.end <= len("INFO") for span in text.spans)

def test_render_log_output_from_ansi(self) -> None:
"""Test parsing ANSI log output."""
text = _render_log_output("\x1b[32mINFO\x1b[0m: hello\n")
assert isinstance(text, Text)
assert text.plain == "INFO: hello"
assert text.spans
105 changes: 58 additions & 47 deletions framework/py/flwr/common/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,12 @@
from logging import ERROR, WARN, LogRecord
from logging.handlers import HTTPHandler
from queue import Empty, Queue
from typing import TYPE_CHECKING, Any, TextIO
from typing import Any, TextIO

import grpc
import typer
from rich.console import Console
from rich.text import Text

from flwr.proto.log_pb2 import PushLogsRequest # pylint: disable=E0611
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
Expand All @@ -45,22 +46,18 @@
FLOWER_LOGGER.setLevel(logging.DEBUG)
log = FLOWER_LOGGER.log # pylint: disable=invalid-name

LOG_COLORS = {
"DEBUG": "\033[94m", # Blue
"INFO": "\033[92m", # Green
"WARNING": "\033[93m", # Yellow
"ERROR": "\033[91m", # Red
"CRITICAL": "\033[95m", # Magenta
"RESET": "\033[0m", # Reset to default
LOG_STYLES = {
"DEBUG": "blue",
"INFO": "green",
"WARNING": "yellow",
"ERROR": "red",
"CRITICAL": "magenta",
}
MESSAGE_FORMATTER = logging.Formatter("%(message)s")
TIME_FORMATTER = logging.Formatter()

if TYPE_CHECKING:
StreamHandler = logging.StreamHandler[Any]
else:
StreamHandler = logging.StreamHandler


class ConsoleHandler(StreamHandler):
class ConsoleHandler(logging.StreamHandler): # type: ignore[type-arg]
"""Console handler that allows configurable formatting."""

def __init__(
Expand All @@ -74,32 +71,52 @@ def __init__(
self.timestamps = timestamps
self.json = json
self.colored = colored
self.console: Console | None = None
self._console_stream: TextIO | None = None
self._console_colored: bool | None = None

def _get_console(self) -> Console:
if (
self.console is None
or self._console_stream is not self.stream
or self._console_colored != self.colored
):
self.console = Console(
file=self.stream,
highlight=False,
markup=False,
no_color=not self.colored,
force_terminal=self.colored,
)
self._console_stream = self.stream
self._console_colored = self.colored
return self.console

Copy link

Copilot AI Feb 25, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Console is initialized with file=self.stream, but elsewhere in this module console_handler.stream is reassigned (e.g., mirror_output_to_queue, redirect_output, restore_output) to capture/redirect output. rich.Console keeps its own file reference, so log output will continue going to the original stream instead of the updated handler.stream, which can break CLI JSON output capture/redirection. Consider overriding setStream/adding a setter that updates self.console.file whenever self.stream changes, or re-create/update the Console in emit() based on the current stream.

Suggested change
def setStream(self, stream: TextIO | None) -> None: # type: ignore[override]
"""Set the stream and keep the Console's file in sync."""
super().setStream(stream)
# Ensure the rich Console writes to the same stream as the handler.
# logging.StreamHandler.setStream sets self.stream, which may be None.
# rich.Console expects a file-like object; if self.stream is None,
# Console will continue using its existing file.
if hasattr(self, "console") and self.stream is not None:
self.console.file = self.stream

Copilot uses AI. Check for mistakes.
def emit(self, record: LogRecord) -> None:
"""Emit a record."""
console = self._get_console()
message = MESSAGE_FORMATTER.format(record)
formatted_time = TIME_FORMATTER.formatTime(record)
if self.json:
record.message = record.getMessage().replace("\t", "").strip()

# Check if the message is empty
if not record.message:
message = message.replace("\t", "").strip()
if not message:
return

super().emit(record)

def format(self, record: LogRecord) -> str:
"""Format function that adds colors to log level."""
seperator = " " * (8 - len(record.levelname))
if self.json:
log_fmt = "{lvl='%(levelname)s', time='%(asctime)s', msg='%(message)s'}"
renderable: str | Text = f"{{lvl='{record.levelname}', "
renderable += f"time='{formatted_time}', msg='{message}'}}"
else:
log_fmt = (
f"{LOG_COLORS[record.levelname] if self.colored else ''}"
f"%(levelname)s {'%(asctime)s' if self.timestamps else ''}"
f"{LOG_COLORS['RESET'] if self.colored else ''}"
f": {seperator} %(message)s"
)
formatter = logging.Formatter(log_fmt)
return formatter.format(record)
separator = " " * (8 - len(record.levelname))
timestamp = f" {formatted_time}" if self.timestamps else ""
head = f"{record.levelname}{timestamp}"
Comment on lines +98 to +109
Copy link

Copilot AI Feb 25, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

formatted_time = TIME_FORMATTER.formatTime(record) is computed for every record even when self.json is false and self.timestamps is false, where the value is unused. Moving time formatting inside the branches that need it would avoid unnecessary strftime/time conversion work in the hot path.

Copilot uses AI. Check for mistakes.
tail = f": {separator} {message}"
if self.colored:
renderable = Text.assemble(
(head, LOG_STYLES.get(record.levelname)),
tail,
)
else:
renderable = f"{head}{tail}"

console.print(renderable, soft_wrap=True, overflow="ignore", crop=False)


def update_console_handler(
Expand All @@ -108,22 +125,16 @@ def update_console_handler(
colored: bool | None = None,
) -> None:
"""Update the logging handler."""
for handler in logging.getLogger(LOGGER_NAME).handlers:
if isinstance(handler, ConsoleHandler):
if level is not None:
handler.setLevel(level)
if timestamps is not None:
handler.timestamps = timestamps
if colored is not None:
handler.colored = colored
if level is not None:
console_handler.setLevel(level)
if timestamps is not None:
console_handler.timestamps = timestamps
if colored is not None:
console_handler.colored = colored


# Configure console logger
console_handler = ConsoleHandler(
timestamps=False,
json=False,
colored=True,
)
console_handler = ConsoleHandler(timestamps=False, json=False, colored=True)
console_handler.setLevel(logging.INFO)
FLOWER_LOGGER.addHandler(console_handler)

Expand Down
32 changes: 31 additions & 1 deletion framework/py/flwr/common/logger_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,13 @@
"""Flower Logger tests."""


import logging
import sys
from queue import Queue
from typing import Any

from .logger import mirror_output_to_queue, restore_output
from . import logger as logger_module
from .logger import ConsoleHandler, mirror_output_to_queue, restore_output


def test_mirror_output_to_queue() -> None:
Expand Down Expand Up @@ -54,3 +57,30 @@ def test_restore_output() -> None:
assert log_queue.get() == "Test message before restore"
assert log_queue.get() == "\n"
assert log_queue.empty()


def test_console_handler_rebuilds_console_for_streamed_colored_logs(
monkeypatch: Any,
) -> None:
"""Test that streamed logs keep color by recreating Console on stream change."""
inits: list[dict[str, Any]] = []

class DummyConsole:
"""Capture Console init kwargs and ignore writes."""

def __init__(self, **kwargs: Any) -> None:
inits.append(kwargs)

def print(self, *args: Any, **kwargs: Any) -> None:
pass

monkeypatch.setattr(logger_module, "Console", DummyConsole)
handler = ConsoleHandler(colored=True)

handler.emit(logging.LogRecord("flwr", logging.INFO, "", 0, "first", (), None))
handler.stream = sys.stdout
handler.emit(logging.LogRecord("flwr", logging.INFO, "", 0, "second", (), None))

assert len(inits) == 2
assert all(kwargs["force_terminal"] is True for kwargs in inits)
assert all(kwargs["no_color"] is False for kwargs in inits)
Loading