Skip to content

Commit d0e1165

Browse files
felipemello1 (Felipe Mello)
authored and committed
Metric Logging - Enable kwargs so user can add extra args (meta-pytorch#482)
Co-authored-by: Felipe Mello <[email protected]>
1 parent 33295f1 commit d0e1165

File tree

3 files changed

+97
-62
lines changed

3 files changed

+97
-62
lines changed

src/forge/observability/metric_actors.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -225,8 +225,16 @@ def _validate_backend_config(
225225
f"{', '.join([mode.value for mode in LoggingMode])}."
226226
)
227227

228-
mode_str = config["logging_mode"]
229-
mode = LoggingMode(mode_str)
228+
# Convert string to LoggingMode enum
229+
mode_value = config["logging_mode"]
230+
if isinstance(mode_value, str):
231+
mode = LoggingMode(mode_value)
232+
elif isinstance(mode_value, LoggingMode):
233+
mode = mode_value
234+
else:
235+
raise TypeError(
236+
f"logging_mode must be str or LoggingMode enum, got {type(mode_value)}"
237+
)
230238

231239
# Validate per_rank_share_run configuration
232240
share_run = config.get("per_rank_share_run", False)
@@ -303,7 +311,7 @@ async def init_backends(self, config: dict[str, Any]) -> None:
303311
mode = backend_config["logging_mode"]
304312

305313
backend: LoggerBackend = get_logger_backend_class(backend_name)(
306-
backend_config
314+
**backend_config
307315
)
308316
await backend.init(role=BackendRole.GLOBAL, process_name="global_reduce")
309317

src/forge/observability/metrics.py

Lines changed: 75 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,11 @@
1717
from typing import Any, Dict, List
1818

1919
import pytz
20-
from monarch.actor import current_rank
2120

2221
from forge.observability.utils import get_proc_name_with_rank
2322

2423
from forge.util.logging import get_logger, log_once
24+
from monarch.actor import current_rank
2525

2626
logger = get_logger("INFO")
2727

@@ -606,7 +606,7 @@ async def init_backends(
606606

607607
# instantiate local backend
608608
backend: LoggerBackend = get_logger_backend_class(backend_name)(
609-
backend_config
609+
**backend_config
610610
)
611611
await backend.init(
612612
role=BackendRole.LOCAL,
@@ -760,10 +760,21 @@ async def shutdown(self):
760760

761761

762762
class LoggerBackend(ABC):
763-
"""Abstract logger_backend for metric logging, e.g. wandb, jsonl, etc."""
763+
"""Abstract logger_backend for metric logging, e.g. wandb, jsonl, etc.
764+
765+
Args:
766+
logging_mode: Logging behavior mode.
767+
per_rank_share_run: Whether ranks share run. Default False.
768+
**kwargs: Backend-specific arguments (e.g., project, name, tags for WandB).
769+
"""
770+
771+
def __init__(
772+
self, *, logging_mode: LoggingMode, per_rank_share_run: bool = False, **kwargs
773+
) -> None:
764774

765-
def __init__(self, logger_backend_config: dict[str, Any]) -> None:
766-
self.logger_backend_config = logger_backend_config
775+
self.logging_mode = logging_mode
776+
self.per_rank_share_run = per_rank_share_run
777+
self.backend_kwargs = kwargs
767778

768779
@abstractmethod
769780
async def init(
@@ -823,8 +834,13 @@ def get_metadata_for_secondary_ranks(self) -> dict[str, Any] | None:
823834
class ConsoleBackend(LoggerBackend):
824835
"""Simple console logging of metrics."""
825836

826-
def __init__(self, logger_backend_config: dict[str, Any]) -> None:
827-
super().__init__(logger_backend_config)
837+
def __init__(
838+
self, *, logging_mode: LoggingMode, per_rank_share_run: bool = False, **kwargs
839+
) -> None:
840+
super().__init__(
841+
logging_mode=logging_mode, per_rank_share_run=per_rank_share_run, **kwargs
842+
)
843+
self.process_name = None
828844

829845
async def init(
830846
self,
@@ -868,85 +884,101 @@ class WandbBackend(LoggerBackend):
868884
869885
For logging mode details, see `forge.observability.metrics.LoggingMode` documentation.
870886
871-
More details on wandb distributed logging here: https://docs.wandb.ai/guides/track/log/distributed-training/
887+
More details on wandb distributed logging: https://docs.wandb.ai/guides/track/log/distributed-training/
872888
873889
Configuration:
874-
logging_mode (LoggingMode): Determines logging behavior
890+
logging_mode (LoggingMode): Determines logging behavior.
875891
per_rank_share_run (bool, default False): For per-rank modes, whether to share run ID across ranks.
876-
If true, then a single wandb is created and all ranks log to it. Its particularly useful if
877-
logging with no_reduce to capture a time based stream of information. Not recommended if reducing values.
878-
project (str): WandB project name
879-
group (str, optional): WandB group name for organizing runs. Defaults to "experiment_group"
892+
If true, a single wandb run is created and all ranks log to it. Particularly useful for
893+
logging with no_reduce to capture time-based streams. Not recommended if reducing values.
894+
**kwargs: Any argument accepted by wandb.init() (e.g., project, group, name, tags, notes, etc.)
895+
896+
Example:
897+
WandbBackend(
898+
logging_mode=LoggingMode.PER_RANK_REDUCE,
899+
per_rank_share_run=False,
900+
project="my_project",
901+
group="exp_group",
902+
name="my_experiment",
903+
tags=["rl", "v2"],
904+
notes="Testing new reward"
905+
)
880906
"""
881907

882-
def __init__(self, logger_backend_config: dict[str, Any]) -> None:
883-
super().__init__(logger_backend_config)
884-
self.project = logger_backend_config["project"]
885-
self.group = logger_backend_config.get("group", "experiment_group")
886-
self.process_name = None
908+
def __init__(
    self, *, logging_mode: LoggingMode, per_rank_share_run: bool = False, **kwargs
) -> None:
    """Store the logging configuration and the extra backend keyword arguments.

    Args:
        logging_mode: Logging behavior mode. Accepts a LoggingMode member or
            its string value; it is normalized to a LoggingMode member.
        per_rank_share_run: Whether per-rank modes share one run across ranks.
        **kwargs: Remaining keyword arguments, forwarded to the base
            initializer (the class docstring describes them as wandb.init()
            arguments such as project, group, name, tags).
    """
    super().__init__(
        logging_mode=logging_mode, per_rank_share_run=per_rank_share_run, **kwargs
    )
    # The old single-dict parameter `logger_backend_config` no longer exists in
    # this signature; reading from it here raised NameError on construction.
    # Normalize from the keyword argument instead. LoggingMode(member) returns
    # the member itself, so enum inputs pass through unchanged.
    self.logging_mode = LoggingMode(logging_mode)
    self.run = None
    self._tables: dict[str, "wandb.Table"] = {}
    self.process_name = None
891919

892920
async def init(
893921
self,
894922
role: BackendRole,
895923
controller_logger_metadata: dict[str, Any] | None = None,
896924
process_name: str | None = None,
897925
) -> None:
898-
899926
if controller_logger_metadata is None:
900927
controller_logger_metadata = {}
901928

929+
# Pop name, if any, to concat to process_name.
930+
run_name = self.backend_kwargs.pop("name", None)
902931
self.process_name = process_name
903932

904-
# GLOBAL_REDUCE mode: only inits on controller
933+
# Format run name based on mode and role
905934
if self.logging_mode == LoggingMode.GLOBAL_REDUCE:
906935
if role != BackendRole.GLOBAL:
907936
logger.warning(f"Skipped init for GLOBAL_REDUCE mode and {role} role.")
908937
return
909-
await self._init_global()
938+
# use name as-is, no need to append controller process_name
939+
await self._init_global(run_name)
910940

911-
# Per-rank modes based on per_rank_share_run bool
912941
elif role == BackendRole.GLOBAL and self.per_rank_share_run:
913-
await self._init_shared_global()
942+
# use name as-is, no need to append controller process_name
943+
await self._init_shared_global(run_name)
914944

915945
elif role == BackendRole.LOCAL:
946+
# Per-rank: append process_name
947+
run_name = f"{run_name}_{process_name}" if run_name else process_name
948+
916949
if self.per_rank_share_run:
917-
await self._init_shared_local(controller_logger_metadata)
950+
shared_id = controller_logger_metadata.get("shared_run_id")
951+
if shared_id is None:
952+
raise ValueError(
953+
f"Shared ID required but not provided for {process_name} backend init"
954+
)
955+
await self._init_shared_local(run_name, shared_id, process_name)
918956
else:
919-
await self._init_per_rank()
957+
await self._init_per_rank(run_name)
920958

921-
async def _init_global(self):
959+
async def _init_global(self, run_name: str | None):
922960
import wandb
923961

924-
self.run = wandb.init(project=self.project, group=self.group)
962+
self.run = wandb.init(name=run_name, **self.backend_kwargs)
925963

926-
async def _init_per_rank(self):
964+
async def _init_per_rank(self, run_name: str):
927965
import wandb
928966

929-
self.run = wandb.init(
930-
project=self.project, group=self.group, name=self.process_name
931-
)
967+
self.run = wandb.init(name=run_name, **self.backend_kwargs)
932968

933-
async def _init_shared_global(self):
969+
async def _init_shared_global(self, run_name: str | None):
934970
import wandb
935971

936972
settings = wandb.Settings(
937973
mode="shared", x_primary=True, x_label="controller_primary"
938974
)
939-
self.run = wandb.init(project=self.project, group=self.group, settings=settings)
975+
self.run = wandb.init(name=run_name, settings=settings, **self.backend_kwargs)
940976

941-
async def _init_shared_local(self, controller_metadata: dict[str, Any]):
977+
async def _init_shared_local(
978+
self, run_name: str, shared_id: str, process_name: str
979+
):
942980
import wandb
943981

944-
shared_id = controller_metadata.get("shared_run_id")
945-
if shared_id is None:
946-
raise ValueError(
947-
f"Shared ID required but not provided for {self.process_name} backend init"
948-
)
949-
950982
# Clear any stale service tokens that might be pointing to dead processes
951983
# In multiprocessing environments, WandB service tokens can become stale and point
952984
# to dead service processes. This causes wandb.init() to hang indefinitely trying
@@ -955,14 +987,9 @@ async def _init_shared_local(self, controller_metadata: dict[str, Any]):
955987

956988
service_token.clear_service_in_env()
957989

958-
settings = wandb.Settings(
959-
mode="shared", x_primary=False, x_label=self.process_name
960-
)
990+
settings = wandb.Settings(mode="shared", x_primary=False, x_label=process_name)
961991
self.run = wandb.init(
962-
id=shared_id,
963-
project=self.project,
964-
group=self.group,
965-
settings=settings,
992+
name=run_name, id=shared_id, settings=settings, **self.backend_kwargs
966993
)
967994

968995
async def log_batch(
@@ -990,7 +1017,7 @@ def log_stream(self, metric: Metric, global_step: int, *args, **kwargs) -> None:
9901017
return
9911018

9921019
# Log with custom timestamp for precision
993-
# Users can choose x-axis as timestamp in WandB UI and display as dateimte
1020+
# Users can choose x-axis as timestamp in WandB UI and display as datetime
9941021
log_data = {
9951022
metric.key: metric.value,
9961023
"timestamp": metric.timestamp,

tests/unit_tests/observability/test_metrics.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -85,12 +85,12 @@ def test_new_enums_and_constants(self):
8585
async def test_backend_role_usage(self):
8686
"""Test that BackendRole constants are actually used instead of string literals."""
8787
# Test ConsoleBackend
88-
console_backend = ConsoleBackend({})
88+
console_backend = ConsoleBackend(logging_mode=LoggingMode.GLOBAL_REDUCE)
8989
await console_backend.init(role=BackendRole.LOCAL)
9090

9191
# Test WandbBackend role validation without WandB initialization
9292
wandb_backend = WandbBackend(
93-
{"project": "test", "logging_mode": "global_reduce"}
93+
logging_mode=LoggingMode.GLOBAL_REDUCE, project="test"
9494
)
9595

9696
# Mock all the WandB init methods to focus only on role validation
@@ -329,15 +329,15 @@ def test_record_metric_enabled_explicit(self, mock_collector_class, mock_rank):
329329
def test_wandb_backend_creation(self):
330330
"""Test WandbBackend creation and basic setup without WandB dependency."""
331331

332-
config = {
333-
"project": "test_project",
334-
"group": "test_group",
335-
"logging_mode": "global_reduce",
336-
}
337-
backend = WandbBackend(config)
332+
backend = WandbBackend(
333+
logging_mode=LoggingMode.GLOBAL_REDUCE,
334+
project="test_project",
335+
group="test_group",
336+
)
338337

339-
assert backend.project == "test_project"
340-
assert backend.group == "test_group"
338+
# Test backend kwargs storage
339+
assert backend.backend_kwargs["project"] == "test_project"
340+
assert backend.backend_kwargs["group"] == "test_group"
341341
assert backend.logging_mode == LoggingMode.GLOBAL_REDUCE
342342
assert backend.per_rank_share_run is False # default
343343

@@ -348,7 +348,7 @@ def test_wandb_backend_creation(self):
348348
@pytest.mark.asyncio
349349
async def test_console_backend(self):
350350
"""Test ConsoleBackend basic operations."""
351-
backend = ConsoleBackend({})
351+
backend = ConsoleBackend(logging_mode=LoggingMode.GLOBAL_REDUCE)
352352

353353
await backend.init(role=BackendRole.LOCAL)
354354

0 commit comments

Comments
 (0)