diff --git a/src/forge/observability/metric_actors.py b/src/forge/observability/metric_actors.py index e1f0a65df..0ba2b8c73 100644 --- a/src/forge/observability/metric_actors.py +++ b/src/forge/observability/metric_actors.py @@ -224,8 +224,16 @@ def _validate_backend_config( f"{', '.join([mode.value for mode in LoggingMode])}." ) - mode_str = config["logging_mode"] - mode = LoggingMode(mode_str) + # Convert string to LoggingMode enum + mode_value = config["logging_mode"] + if isinstance(mode_value, str): + mode = LoggingMode(mode_value) + elif isinstance(mode_value, LoggingMode): + mode = mode_value + else: + raise TypeError( + f"logging_mode must be str or LoggingMode enum, got {type(mode_value)}" + ) # Validate per_rank_share_run configuration share_run = config.get("per_rank_share_run", False) @@ -302,7 +310,7 @@ async def init_backends(self, config: dict[str, Any]) -> None: mode = backend_config["logging_mode"] backend: LoggerBackend = get_logger_backend_class(backend_name)( - backend_config + **backend_config ) await backend.init(role=BackendRole.GLOBAL, process_name="global_reduce") diff --git a/src/forge/observability/metrics.py b/src/forge/observability/metrics.py index 55a3c31a2..c5ca7cf1f 100644 --- a/src/forge/observability/metrics.py +++ b/src/forge/observability/metrics.py @@ -503,7 +503,7 @@ async def init_backends( # instantiate local backend backend: LoggerBackend = get_logger_backend_class(backend_name)( - backend_config + **backend_config ) await backend.init( role=BackendRole.LOCAL, @@ -643,10 +643,21 @@ async def shutdown(self): class LoggerBackend(ABC): - """Abstract logger_backend for metric logging, e.g. wandb, jsonl, etc.""" + """Abstract logger_backend for metric logging, e.g. wandb, jsonl, etc. - def __init__(self, logger_backend_config: dict[str, Any]) -> None: - self.logger_backend_config = logger_backend_config + Args: + logging_mode: Logging behavior mode. + per_rank_share_run: Whether ranks share run. Default False. + **kwargs: Backend-specific arguments (e.g., project, name, tags for WandB). + """ + + def __init__( + self, *, logging_mode: LoggingMode, per_rank_share_run: bool = False, **kwargs + ) -> None: + + self.logging_mode = logging_mode + self.per_rank_share_run = per_rank_share_run + self.backend_kwargs = kwargs @abstractmethod async def init( @@ -706,8 +717,13 @@ def get_metadata_for_secondary_ranks(self) -> dict[str, Any] | None: class ConsoleBackend(LoggerBackend): """Simple console logging of metrics.""" - def __init__(self, logger_backend_config: dict[str, Any]) -> None: - super().__init__(logger_backend_config) + def __init__( + self, *, logging_mode: LoggingMode, per_rank_share_run: bool = False, **kwargs + ) -> None: + super().__init__( + logging_mode=logging_mode, per_rank_share_run=per_rank_share_run, **kwargs + ) + self.process_name = None async def init( self, @@ -741,25 +757,35 @@ class WandbBackend(LoggerBackend): For logging mode details, see `forge.observability.metrics.LoggingMode` documentation. - More details on wandb distributed logging here: https://docs.wandb.ai/guides/track/log/distributed-training/ + More details on wandb distributed logging: https://docs.wandb.ai/guides/track/log/distributed-training/ Configuration: - logging_mode (LoggingMode): Determines logging behavior + logging_mode (LoggingMode): Determines logging behavior. per_rank_share_run (bool, default False): For per-rank modes, whether to share run ID across ranks. - If true, then a single wandb is created and all ranks log to it. Its particularly useful if - logging with no_reduce to capture a time based stream of information. Not recommended if reducing values. - project (str): WandB project name - group (str, optional): WandB group name for organizing runs. Defaults to "experiment_group" + If true, a single wandb run is created and all ranks log to it. Particularly useful for + logging with no_reduce to capture time-based streams. Not recommended if reducing values. + **kwargs: Any argument accepted by wandb.init() (e.g., project, group, name, tags, notes, etc.) + + Example: + WandbBackend( + logging_mode=LoggingMode.PER_RANK_REDUCE, + per_rank_share_run=False, + project="my_project", + group="exp_group", + name="my_experiment", + tags=["rl", "v2"], + notes="Testing new reward" + ) """ - def __init__(self, logger_backend_config: dict[str, Any]) -> None: - super().__init__(logger_backend_config) - self.project = logger_backend_config["project"] - self.group = logger_backend_config.get("group", "experiment_group") - self.process_name = None + def __init__( + self, *, logging_mode: LoggingMode, per_rank_share_run: bool = False, **kwargs + ) -> None: + super().__init__( + logging_mode=logging_mode, per_rank_share_run=per_rank_share_run, **kwargs + ) self.run = None - self.logging_mode = LoggingMode(logger_backend_config["logging_mode"]) - self.per_rank_share_run = logger_backend_config.get("per_rank_share_run", False) + self.process_name = None async def init( self, @@ -767,58 +793,62 @@ async def init( controller_logger_metadata: dict[str, Any] | None = None, process_name: str | None = None, ) -> None: - if controller_logger_metadata is None: controller_logger_metadata = {} + # Pop name, if any, to concat to process_name. + run_name = self.backend_kwargs.pop("name", None) self.process_name = process_name - # GLOBAL_REDUCE mode: only inits on controller + # Format run name based on mode and role if self.logging_mode == LoggingMode.GLOBAL_REDUCE: if role != BackendRole.GLOBAL: logger.warning(f"Skipped init for GLOBAL_REDUCE mode and {role} role.") return - await self._init_global() + # use name as-is, no need to append controller process_name + await self._init_global(run_name) - # Per-rank modes based on per_rank_share_run bool elif role == BackendRole.GLOBAL and self.per_rank_share_run: - await self._init_shared_global() + # use name as-is, no need to append controller process_name + await self._init_shared_global(run_name) elif role == BackendRole.LOCAL: + # Per-rank: append process_name + run_name = f"{run_name}_{process_name}" if run_name else process_name + if self.per_rank_share_run: - await self._init_shared_local(controller_logger_metadata) + shared_id = controller_logger_metadata.get("shared_run_id") + if shared_id is None: + raise ValueError( + f"Shared ID required but not provided for {process_name} backend init" + ) + await self._init_shared_local(run_name, shared_id, process_name) else: - await self._init_per_rank() + await self._init_per_rank(run_name) - async def _init_global(self): + async def _init_global(self, run_name: str | None): import wandb - self.run = wandb.init(project=self.project, group=self.group) + self.run = wandb.init(name=run_name, **self.backend_kwargs) - async def _init_per_rank(self): + async def _init_per_rank(self, run_name: str): import wandb - self.run = wandb.init( - project=self.project, group=self.group, name=self.process_name - ) + self.run = wandb.init(name=run_name, **self.backend_kwargs) - async def _init_shared_global(self): + async def _init_shared_global(self, run_name: str | None): import wandb settings = wandb.Settings( mode="shared", x_primary=True, x_label="controller_primary" ) - self.run = wandb.init(project=self.project, group=self.group, settings=settings) + self.run = wandb.init(name=run_name, settings=settings, **self.backend_kwargs) - async def _init_shared_local(self, controller_metadata: dict[str, Any]): + async def _init_shared_local( + self, run_name: str, shared_id: str, process_name: str + ): import wandb - shared_id = controller_metadata.get("shared_run_id") - if shared_id is None: - raise ValueError( - f"Shared ID required but not provided for {self.process_name} backend init" - ) - # Clear any stale service tokens that might be pointing to dead processes # In multiprocessing environments, WandB service tokens can become stale and point # to dead service processes. This causes wandb.init() to hang indefinitely trying @@ -827,14 +857,9 @@ async def _init_shared_local(self, controller_metadata: dict[str, Any]): service_token.clear_service_in_env() - settings = wandb.Settings( - mode="shared", x_primary=False, x_label=self.process_name - ) + settings = wandb.Settings(mode="shared", x_primary=False, x_label=process_name) self.run = wandb.init( - id=shared_id, - project=self.project, - group=self.group, - settings=settings, + name=run_name, id=shared_id, settings=settings, **self.backend_kwargs ) async def log_batch( @@ -862,7 +887,7 @@ def log_stream(self, metric: Metric, global_step: int, *args, **kwargs) -> None: return # Log with custom timestamp for precision - # Users can choose x-axis as timestamp in WandB UI and display as dateimte + # Users can choose x-axis as timestamp in WandB UI and display as datetime log_data = { metric.key: metric.value, "timestamp": metric.timestamp, diff --git a/tests/unit_tests/observability/test_metrics.py b/tests/unit_tests/observability/test_metrics.py index cda3679a5..070e8a4f5 100644 --- a/tests/unit_tests/observability/test_metrics.py +++ b/tests/unit_tests/observability/test_metrics.py @@ -85,12 +85,12 @@ def test_new_enums_and_constants(self): async def test_backend_role_usage(self): """Test that BackendRole constants are actually used instead of string literals.""" # Test ConsoleBackend - console_backend = ConsoleBackend({}) + console_backend = ConsoleBackend(logging_mode=LoggingMode.GLOBAL_REDUCE) await console_backend.init(role=BackendRole.LOCAL) # Test WandbBackend role validation without WandB initialization wandb_backend = WandbBackend( - {"project": "test", "logging_mode": "global_reduce"} + logging_mode=LoggingMode.GLOBAL_REDUCE, project="test" ) # Mock all the WandB init methods to focus only on role validation @@ -298,15 +298,15 @@ def test_record_metric_enabled_explicit(self, mock_collector_class, mock_rank): def test_wandb_backend_creation(self): """Test WandbBackend creation and basic setup without WandB dependency.""" - config = { - "project": "test_project", - "group": "test_group", - "logging_mode": "global_reduce", - } - backend = WandbBackend(config) + backend = WandbBackend( + logging_mode=LoggingMode.GLOBAL_REDUCE, + project="test_project", + group="test_group", + ) - assert backend.project == "test_project" - assert backend.group == "test_group" + # Test backend kwargs storage + assert backend.backend_kwargs["project"] == "test_project" + assert backend.backend_kwargs["group"] == "test_group" assert backend.logging_mode == LoggingMode.GLOBAL_REDUCE assert backend.per_rank_share_run is False # default @@ -317,7 +317,7 @@ def test_wandb_backend_creation(self): @pytest.mark.asyncio async def test_console_backend(self): """Test ConsoleBackend basic operations.""" - backend = ConsoleBackend({}) + backend = ConsoleBackend(logging_mode=LoggingMode.GLOBAL_REDUCE) await backend.init(role=BackendRole.LOCAL)