Skip to content

Commit 8d7cb10

Browse files
felipemello1 and Felipe Mello authored
Metric Logging - Enable kwargs so user can add extra args (#482)
Co-authored-by: Felipe Mello <[email protected]>
1 parent 1a92113 commit 8d7cb10

File tree

3 files changed

+96
-63
lines changed

3 files changed

+96
-63
lines changed

src/forge/observability/metric_actors.py

Lines changed: 11 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -224,8 +224,16 @@ def _validate_backend_config(
224224
f"{', '.join([mode.value for mode in LoggingMode])}."
225225
)
226226

227-
mode_str = config["logging_mode"]
228-
mode = LoggingMode(mode_str)
227+
# Convert string to LoggingMode enum
228+
mode_value = config["logging_mode"]
229+
if isinstance(mode_value, str):
230+
mode = LoggingMode(mode_value)
231+
elif isinstance(mode_value, LoggingMode):
232+
mode = mode_value
233+
else:
234+
raise TypeError(
235+
f"logging_mode must be str or LoggingMode enum, got {type(mode_value)}"
236+
)
229237

230238
# Validate per_rank_share_run configuration
231239
share_run = config.get("per_rank_share_run", False)
@@ -302,7 +310,7 @@ async def init_backends(self, config: dict[str, Any]) -> None:
302310
mode = backend_config["logging_mode"]
303311

304312
backend: LoggerBackend = get_logger_backend_class(backend_name)(
305-
backend_config
313+
**backend_config
306314
)
307315
await backend.init(role=BackendRole.GLOBAL, process_name="global_reduce")
308316

src/forge/observability/metrics.py

Lines changed: 74 additions & 49 deletions
Original file line number | Diff line number | Diff line change
@@ -503,7 +503,7 @@ async def init_backends(
503503

504504
# instantiate local backend
505505
backend: LoggerBackend = get_logger_backend_class(backend_name)(
506-
backend_config
506+
**backend_config
507507
)
508508
await backend.init(
509509
role=BackendRole.LOCAL,
@@ -643,10 +643,21 @@ async def shutdown(self):
643643

644644

645645
class LoggerBackend(ABC):
646-
"""Abstract logger_backend for metric logging, e.g. wandb, jsonl, etc."""
646+
"""Abstract logger_backend for metric logging, e.g. wandb, jsonl, etc.
647647
648-
def __init__(self, logger_backend_config: dict[str, Any]) -> None:
649-
self.logger_backend_config = logger_backend_config
648+
Args:
649+
logging_mode: Logging behavior mode.
650+
per_rank_share_run: Whether ranks share run. Default False.
651+
**kwargs: Backend-specific arguments (e.g., project, name, tags for WandB).
652+
"""
653+
654+
def __init__(
655+
self, *, logging_mode: LoggingMode, per_rank_share_run: bool = False, **kwargs
656+
) -> None:
657+
658+
self.logging_mode = logging_mode
659+
self.per_rank_share_run = per_rank_share_run
660+
self.backend_kwargs = kwargs
650661

651662
@abstractmethod
652663
async def init(
@@ -706,8 +717,13 @@ def get_metadata_for_secondary_ranks(self) -> dict[str, Any] | None:
706717
class ConsoleBackend(LoggerBackend):
707718
"""Simple console logging of metrics."""
708719

709-
def __init__(self, logger_backend_config: dict[str, Any]) -> None:
710-
super().__init__(logger_backend_config)
720+
def __init__(
721+
self, *, logging_mode: LoggingMode, per_rank_share_run: bool = False, **kwargs
722+
) -> None:
723+
super().__init__(
724+
logging_mode=logging_mode, per_rank_share_run=per_rank_share_run, **kwargs
725+
)
726+
self.process_name = None
711727

712728
async def init(
713729
self,
@@ -741,84 +757,98 @@ class WandbBackend(LoggerBackend):
741757
742758
For logging mode details, see `forge.observability.metrics.LoggingMode` documentation.
743759
744-
More details on wandb distributed logging here: https://docs.wandb.ai/guides/track/log/distributed-training/
760+
More details on wandb distributed logging: https://docs.wandb.ai/guides/track/log/distributed-training/
745761
746762
Configuration:
747-
logging_mode (LoggingMode): Determines logging behavior
763+
logging_mode (LoggingMode): Determines logging behavior.
748764
per_rank_share_run (bool, default False): For per-rank modes, whether to share run ID across ranks.
749-
If true, then a single wandb is created and all ranks log to it. Its particularly useful if
750-
logging with no_reduce to capture a time based stream of information. Not recommended if reducing values.
751-
project (str): WandB project name
752-
group (str, optional): WandB group name for organizing runs. Defaults to "experiment_group"
765+
If true, a single wandb run is created and all ranks log to it. Particularly useful for
766+
logging with no_reduce to capture time-based streams. Not recommended if reducing values.
767+
**kwargs: Any argument accepted by wandb.init() (e.g., project, group, name, tags, notes, etc.)
768+
769+
Example:
770+
WandbBackend(
771+
logging_mode=LoggingMode.PER_RANK_REDUCE,
772+
per_rank_share_run=False,
773+
project="my_project",
774+
group="exp_group",
775+
name="my_experiment",
776+
tags=["rl", "v2"],
777+
notes="Testing new reward"
778+
)
753779
"""
754780

755-
def __init__(self, logger_backend_config: dict[str, Any]) -> None:
756-
super().__init__(logger_backend_config)
757-
self.project = logger_backend_config["project"]
758-
self.group = logger_backend_config.get("group", "experiment_group")
759-
self.process_name = None
781+
def __init__(
782+
self, *, logging_mode: LoggingMode, per_rank_share_run: bool = False, **kwargs
783+
) -> None:
784+
super().__init__(
785+
logging_mode=logging_mode, per_rank_share_run=per_rank_share_run, **kwargs
786+
)
760787
self.run = None
761-
self.logging_mode = LoggingMode(logger_backend_config["logging_mode"])
762-
self.per_rank_share_run = logger_backend_config.get("per_rank_share_run", False)
788+
self.process_name = None
763789

764790
async def init(
765791
self,
766792
role: BackendRole,
767793
controller_logger_metadata: dict[str, Any] | None = None,
768794
process_name: str | None = None,
769795
) -> None:
770-
771796
if controller_logger_metadata is None:
772797
controller_logger_metadata = {}
773798

799+
# Pop name, if any, to concat to process_name.
800+
run_name = self.backend_kwargs.pop("name", None)
774801
self.process_name = process_name
775802

776-
# GLOBAL_REDUCE mode: only inits on controller
803+
# Format run name based on mode and role
777804
if self.logging_mode == LoggingMode.GLOBAL_REDUCE:
778805
if role != BackendRole.GLOBAL:
779806
logger.warning(f"Skipped init for GLOBAL_REDUCE mode and {role} role.")
780807
return
781-
await self._init_global()
808+
# use name as-is, no need to append controller process_name
809+
await self._init_global(run_name)
782810

783-
# Per-rank modes based on per_rank_share_run bool
784811
elif role == BackendRole.GLOBAL and self.per_rank_share_run:
785-
await self._init_shared_global()
812+
# use name as-is, no need to append controller process_name
813+
await self._init_shared_global(run_name)
786814

787815
elif role == BackendRole.LOCAL:
816+
# Per-rank: append process_name
817+
run_name = f"{run_name}_{process_name}" if run_name else process_name
818+
788819
if self.per_rank_share_run:
789-
await self._init_shared_local(controller_logger_metadata)
820+
shared_id = controller_logger_metadata.get("shared_run_id")
821+
if shared_id is None:
822+
raise ValueError(
823+
f"Shared ID required but not provided for {process_name} backend init"
824+
)
825+
await self._init_shared_local(run_name, shared_id, process_name)
790826
else:
791-
await self._init_per_rank()
827+
await self._init_per_rank(run_name)
792828

793-
async def _init_global(self):
829+
async def _init_global(self, run_name: str | None):
794830
import wandb
795831

796-
self.run = wandb.init(project=self.project, group=self.group)
832+
self.run = wandb.init(name=run_name, **self.backend_kwargs)
797833

798-
async def _init_per_rank(self):
834+
async def _init_per_rank(self, run_name: str):
799835
import wandb
800836

801-
self.run = wandb.init(
802-
project=self.project, group=self.group, name=self.process_name
803-
)
837+
self.run = wandb.init(name=run_name, **self.backend_kwargs)
804838

805-
async def _init_shared_global(self):
839+
async def _init_shared_global(self, run_name: str | None):
806840
import wandb
807841

808842
settings = wandb.Settings(
809843
mode="shared", x_primary=True, x_label="controller_primary"
810844
)
811-
self.run = wandb.init(project=self.project, group=self.group, settings=settings)
845+
self.run = wandb.init(name=run_name, settings=settings, **self.backend_kwargs)
812846

813-
async def _init_shared_local(self, controller_metadata: dict[str, Any]):
847+
async def _init_shared_local(
848+
self, run_name: str, shared_id: str, process_name: str
849+
):
814850
import wandb
815851

816-
shared_id = controller_metadata.get("shared_run_id")
817-
if shared_id is None:
818-
raise ValueError(
819-
f"Shared ID required but not provided for {self.process_name} backend init"
820-
)
821-
822852
# Clear any stale service tokens that might be pointing to dead processes
823853
# In multiprocessing environments, WandB service tokens can become stale and point
824854
# to dead service processes. This causes wandb.init() to hang indefinitely trying
@@ -827,14 +857,9 @@ async def _init_shared_local(self, controller_metadata: dict[str, Any]):
827857

828858
service_token.clear_service_in_env()
829859

830-
settings = wandb.Settings(
831-
mode="shared", x_primary=False, x_label=self.process_name
832-
)
860+
settings = wandb.Settings(mode="shared", x_primary=False, x_label=process_name)
833861
self.run = wandb.init(
834-
id=shared_id,
835-
project=self.project,
836-
group=self.group,
837-
settings=settings,
862+
name=run_name, id=shared_id, settings=settings, **self.backend_kwargs
838863
)
839864

840865
async def log_batch(
@@ -862,7 +887,7 @@ def log_stream(self, metric: Metric, global_step: int, *args, **kwargs) -> None:
862887
return
863888

864889
# Log with custom timestamp for precision
865-
# Users can choose x-axis as timestamp in WandB UI and display as dateimte
890+
# Users can choose x-axis as timestamp in WandB UI and display as datetime
866891
log_data = {
867892
metric.key: metric.value,
868893
"timestamp": metric.timestamp,

tests/unit_tests/observability/test_metrics.py

Lines changed: 11 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -85,12 +85,12 @@ def test_new_enums_and_constants(self):
8585
async def test_backend_role_usage(self):
8686
"""Test that BackendRole constants are actually used instead of string literals."""
8787
# Test ConsoleBackend
88-
console_backend = ConsoleBackend({})
88+
console_backend = ConsoleBackend(logging_mode=LoggingMode.GLOBAL_REDUCE)
8989
await console_backend.init(role=BackendRole.LOCAL)
9090

9191
# Test WandbBackend role validation without WandB initialization
9292
wandb_backend = WandbBackend(
93-
{"project": "test", "logging_mode": "global_reduce"}
93+
logging_mode=LoggingMode.GLOBAL_REDUCE, project="test"
9494
)
9595

9696
# Mock all the WandB init methods to focus only on role validation
@@ -298,15 +298,15 @@ def test_record_metric_enabled_explicit(self, mock_collector_class, mock_rank):
298298
def test_wandb_backend_creation(self):
299299
"""Test WandbBackend creation and basic setup without WandB dependency."""
300300

301-
config = {
302-
"project": "test_project",
303-
"group": "test_group",
304-
"logging_mode": "global_reduce",
305-
}
306-
backend = WandbBackend(config)
301+
backend = WandbBackend(
302+
logging_mode=LoggingMode.GLOBAL_REDUCE,
303+
project="test_project",
304+
group="test_group",
305+
)
307306

308-
assert backend.project == "test_project"
309-
assert backend.group == "test_group"
307+
# Test backend kwargs storage
308+
assert backend.backend_kwargs["project"] == "test_project"
309+
assert backend.backend_kwargs["group"] == "test_group"
310310
assert backend.logging_mode == LoggingMode.GLOBAL_REDUCE
311311
assert backend.per_rank_share_run is False # default
312312

@@ -317,7 +317,7 @@ def test_wandb_backend_creation(self):
317317
@pytest.mark.asyncio
318318
async def test_console_backend(self):
319319
"""Test ConsoleBackend basic operations."""
320-
backend = ConsoleBackend({})
320+
backend = ConsoleBackend(logging_mode=LoggingMode.GLOBAL_REDUCE)
321321

322322
await backend.init(role=BackendRole.LOCAL)
323323

0 commit comments

Comments
 (0)