diff --git a/apps/grpo/main.py b/apps/grpo/main.py index 5a6576d7e..60c6f1317 100644 --- a/apps/grpo/main.py +++ b/apps/grpo/main.py @@ -304,7 +304,7 @@ async def main(cfg: DictConfig): else: provisioner = await init_provisioner() - metric_logging_cfg = cfg.get("metric_logging", {"console": {"log_per_rank": False}}) + metric_logging_cfg = cfg.get("metric_logging", {}) mlogger = await get_or_create_metric_logger(process_name="Controller") await mlogger.init_backends.call_one(metric_logging_cfg) diff --git a/apps/grpo/qwen3_1_7b.yaml b/apps/grpo/qwen3_1_7b.yaml index 14e4871cf..a9acd268d 100644 --- a/apps/grpo/qwen3_1_7b.yaml +++ b/apps/grpo/qwen3_1_7b.yaml @@ -16,11 +16,12 @@ rollout_threads: 1 # Recommended to set equal to policy.num_replicas # Observability configuration metric_logging: wandb: - project: "grpo-training" - group: "grpo_exp_${oc.env:USER}" - reduce_across_ranks: True + project: grpo-training + group: grpo_exp_${oc.env:USER} + logging_mode: global_reduce # global_reduce, per_rank_reduce, per_rank_no_reduce + per_rank_share_run: False console: - reduce_across_ranks: True + logging_mode: global_reduce # Dataset configuration dataset: diff --git a/apps/grpo/qwen3_32b.yaml b/apps/grpo/qwen3_32b.yaml index e7a0cf509..b7b17b987 100644 --- a/apps/grpo/qwen3_32b.yaml +++ b/apps/grpo/qwen3_32b.yaml @@ -19,11 +19,12 @@ rollout_threads: 32 # make this 4x the number of policy replicas seems to work w # Observability configuration metric_logging: wandb: - project: "grpo-training" - group: "grpo_exp_${oc.env:USER}" - reduce_across_ranks: True + project: grpo-training + group: grpo_exp_${oc.env:USER} + logging_mode: global_reduce # global_reduce, per_rank_reduce, per_rank_no_reduce + per_rank_share_run: False console: - reduce_across_ranks: True + logging_mode: global_reduce # Dataset configuration dataset: diff --git a/apps/grpo/qwen3_8b.yaml b/apps/grpo/qwen3_8b.yaml index 534e5b92a..617643f2e 100644 --- a/apps/grpo/qwen3_8b.yaml +++ b/apps/grpo/qwen3_8b.yaml @@ -12,11 +12,12 @@ off_by_n: 1 # Off by one by default # Observability configuration metric_logging: wandb: - project: "grpo-training" - group: "grpo_exp_${oc.env:USER}" - reduce_across_ranks: True + project: grpo-training + group: grpo_exp_${oc.env:USER} + logging_mode: global_reduce # global_reduce, per_rank_reduce, per_rank_no_reduce + per_rank_share_run: False console: - reduce_across_ranks: True + logging_mode: global_reduce # Dataset configuration dataset: diff --git a/src/forge/observability/__init__.py b/src/forge/observability/__init__.py index 8efd3dace..555aa761e 100644 --- a/src/forge/observability/__init__.py +++ b/src/forge/observability/__init__.py @@ -12,7 +12,9 @@ from .metrics import ( BackendRole, ConsoleBackend, + get_logger_backend_class, LoggerBackend, + LoggingMode, MaxAccumulator, MeanAccumulator, Metric, @@ -43,6 +45,7 @@ "BackendRole", # Enums "Reduce", + "LoggingMode", # Utility functions "get_proc_name_with_rank", # Actor classes diff --git a/src/forge/observability/metric_actors.py b/src/forge/observability/metric_actors.py index f053d6a56..a86045fb6 100644 --- a/src/forge/observability/metric_actors.py +++ b/src/forge/observability/metric_actors.py @@ -22,6 +22,7 @@ BackendRole, get_logger_backend_class, LoggerBackend, + LoggingMode, MetricCollector, reduce_metrics_states, ) @@ -68,8 +69,8 @@ async def get_or_create_metric_logger( # Initialize logging backends await mlogger.init_backends({ - "console": {"reduce_across_ranks": True}, - "wandb": {"project": "my_project", "reduce_across_ranks": False} + 
"console": {"logging_mode": "global_reduce"}, + "wandb": {"project": "my_project", "logging_mode": "per_rank_reduce"} }) # Initialize services... @@ -127,7 +128,7 @@ async def get_or_create_metric_logger( class LocalFetcherActor(Actor): - """Thin per-process actor used to trigger MetricCollector singleton + """Thin per-rank actor used to trigger MetricCollector singleton operations without direct access. It is what GlobalLoggingActor uses to broadcast inits/flushes across ranks. @@ -165,20 +166,20 @@ async def flush( @endpoint async def init_backends( self, - metadata_per_primary_backend: dict[str, dict[str, Any]], + metadata_per_controller_backend: dict[str, dict[str, Any]], config: dict[str, Any], global_step: int = 0, ) -> None: """Init local (per-rank) logger backends and MetricCollector. Args: - metadata_per_primary_backend (dict[str, dict[str, Any]]): Metadata from primary backends for shared state. + metadata_per_controller_backend (dict[str, dict[str, Any]]): Metadata from controller backends for shared state. config (dict[str, Any]): Backend configurations with logging modes and settings. global_step (int): Initial step for metrics. """ collector = MetricCollector() await collector.init_backends( - metadata_per_primary_backend, + metadata_per_controller_backend, config, global_step, process_name=self.process_name, @@ -191,76 +192,128 @@ async def shutdown(self) -> None: class GlobalLoggingActor(Actor): - """Coordinates metric logging across all ranks for every training step. + """Coordinates metric logging across all ranks for every global step. Supports multiple logging backends (e.g., WandB, TensorBoard, etc.), for per-rank and/or global reduction logging modes. - If a backend config has flag `reduce_across_ranks=False`, an instance of the backend - is initialized per-rank, otherwise it is done once globally. - This GlobalLoggingActor should be spawned once in the controller. A LocalFetcherActor is automatically spawned per-rank in `forge.controller.provisioner.py` and registered with this actor. The LocalFetcherActor is responsible for instantiating - the per-rank MetricCollector. + the per-rank MetricCollector and working as a bridge between GlobalLoggingActor and processes. In summary, the flow is: - - GlobalLoggingActor init_backends() -> LocalFetcherActor init_backends() -> per-rank MetricCollector - - GlobalLoggingActor flush() -> LocalFetcherActor flush() -> per-rank MetricCollector flush + - GlobalLoggingActor.init_backends() -> LocalFetcherActor.init_backends() -> per-rank MetricCollector.init_backends() + - GlobalLoggingActor.flush() -> LocalFetcherActor.flush() -> per-rank MetricCollector.flush """ def __init__(self): self.fetchers: dict[str, LocalFetcherActor] = {} self.config: dict[str, Any] | None = None self.global_logger_backends: dict[str, LoggerBackend] = {} - self.metadata_per_primary_backend: dict[str, dict[str, Any]] = {} + self.metadata_per_controller_backend: dict[str, dict[str, Any]] = {} + + def _validate_backend_config( + self, backend_name: str, config: dict[str, Any] + ) -> dict[str, Any]: + """Validate and normalize backend configuration.""" + if "logging_mode" not in config: + logger.debug( + f"logging_mode not provided for backend {backend_name}. Defaulting to global_reduce." 
+ ) + + mode_str = config.get("logging_mode", "global_reduce") + mode = LoggingMode(mode_str) + + # Validate per_rank_share_run configuration + share_run = config.get("per_rank_share_run", False) + if mode == LoggingMode.GLOBAL_REDUCE and share_run: + logger.warning( + f"{backend_name}: per_rank_share_run=True is ignored in {mode.value} mode. " + "Setting it to False." + ) + share_run = False + + # WandB-specific warning for suboptimal configuration + if ( + backend_name == "wandb" + and mode == LoggingMode.PER_RANK_REDUCE + and share_run + ): + logger.warning( + "WandB: Using 'per_rank_reduce' with 'per_rank_share_run=True' is not recommended. " + "This configuration can lead to confusing metrics where reduced values from multiple ranks " + "are written to the same run/step, displaying only one of them. Consider either:\n" + " 1. Set 'per_rank_share_run=False' to create separate runs per rank, OR\n" + " 2. Use 'per_rank_no_reduce' for real-time streaming to a shared run" + ) + + return { + **config, + "logging_mode": mode, + "per_rank_share_run": share_run, + } @endpoint async def init_backends(self, config: dict[str, Any]) -> None: - """ - Sets config in global actor, so other actors can get it, then eagerly initializes backend and MetricCollectors + """Sets config in global actor, initializes controller backends and eagerly initializes MetricCollectors in all registered fetchers. - A backend is always initialized in the controller (primary backend) and can be used as a logger or as a source - for metadata to be shared with per-rank backends, e.g. shared run IDs for wandb. - - The backend instantiation is controlled by the backend config flag `reduce_across_ranks`: if False, - a per-rank backend is initialized, i.e. if there are 2 ranks, each will have its own backend, - and will log independently, i.e. each rank will have its own run in wandb. - - Else, if True, the GlobalLoggingActor will fetch all local metrics collectors to get their states - and reduce them to a single value, which will be logged by the primary backend in this controller. + The backend instantiation is controlled by the logging_mode field. Controller backends + (instantiated in the controller) can provide metadata to be shared with rank backends, + e.g. shared run IDs for WandB. For details on logging modes, see `forge.observability.metrics.LoggingMode`. Args: - config (dict[str, Any]): Config for metric logging where keys are backend names, - e.g. {"console": {"reduce_across_ranks": True}, "wandb": {"reduce_across_ranks": False}} + config (dict[str, Any]): Config for metric logging where keys are backend names. + Each backend config supports: + - logging_mode (str | LoggingMode, default "global_reduce"): One of "global_reduce", + "per_rank_reduce", or "per_rank_no_reduce". Can be specified as a string or LoggingMode enum. + - per_rank_share_run (bool, default False): For per-rank modes only. Whether ranks + share a single run/logger instance. Ignored for "global_reduce" mode. + - Additional backend-specific options (e.g., "project" for WandB) + + Example: + { + "console": {"logging_mode": "global_reduce"}, + "wandb": { + "project": "my_project", + "logging_mode": "per_rank_no_reduce", + "per_rank_share_run": True + } + } + + Raises: + ValueError: If backend config is invalid or missing required fields. 
""" - self.config = config + self.config = {} + # Validate and normalize each backend config for backend_name, backend_config in config.items(): + self.config[backend_name] = self._validate_backend_config( + backend_name, backend_config + ) + + # Initialize backends based on logging mode + for backend_name, backend_config in self.config.items(): + mode = backend_config["logging_mode"] + backend = get_logger_backend_class(backend_name)(backend_config) await backend.init(role=BackendRole.GLOBAL) - # Extract metadata from primary logger to be shared with secondary loggers - # and store it - reduce_across_ranks = backend_config.get("reduce_across_ranks", True) - if not reduce_across_ranks: - primary_backend_metadata = ( - backend.get_metadata_for_secondary_ranks() or {} - ) - self.metadata_per_primary_backend[ - backend_name - ] = primary_backend_metadata + # Extract metadata from controller logger to be shared with per-rank loggers + if mode != LoggingMode.GLOBAL_REDUCE: + controller_metadata = backend.get_metadata_for_secondary_ranks() or {} + self.metadata_per_controller_backend[backend_name] = controller_metadata - # Store global logger backends - if reduce_across_ranks: + # Store global logger backends for later flush + if mode == LoggingMode.GLOBAL_REDUCE: self.global_logger_backends[backend_name] = backend - # Eager init collectors on all registered fetchers in parallel, passing primary states and config + # Eager init collectors on all registered fetchers in parallel, passing controller states and config if self.fetchers: tasks = [ fetcher.init_backends.call( - self.metadata_per_primary_backend, self.config + self.metadata_per_controller_backend, self.config ) for fetcher in self.fetchers.values() ] @@ -279,7 +332,7 @@ async def register_fetcher( if self.config: logger.debug(f"Initializing new LocalFetcherActor {name}") await fetcher.init_backends.call( - self.metadata_per_primary_backend, self.config + self.metadata_per_controller_backend, self.config ) @endpoint @@ -307,19 +360,21 @@ async def flush(self, global_step: int) -> None: config = self.config if config is None: logger.warning( - "GlobalLoggingActor flush() called before init_backends(). " - "No backends will be flushed." + "Cannot flush collected metrics. GlobalLoggingActor.flush() called before init_backends()." + " No backends will be flushed. 
Please call in your main file:\n" + "`mlogger = await get_or_create_metric_logger(process_name='Controller')`\n" + "`await mlogger.init_backends.call_one(logging_config)`\n" ) return - # if reduce_across_ranks=True, we need to reduce the states from all ranks - # and log with the primary backend + + # Check if need to do reduce and retrieve states from fetchers requires_reduce = any( - backend_config.get("reduce_across_ranks", True) + backend_config["logging_mode"] == LoggingMode.GLOBAL_REDUCE for backend_config in config.values() ) logger.debug( - f"Global flush for global_step {global_step}: {len(self.fetchers)} fetchers" + f"Global flush for global step {global_step}: {len(self.fetchers)} fetchers" ) # Broadcast flush to all fetchers @@ -332,21 +387,25 @@ async def flush(self, global_step: int) -> None: ) if requires_reduce: - # Handle exceptions and extract values from ValueMesh results - all_local_states = [] - for result in results: - if isinstance(result, BaseException): - logger.warning(f"Flush failed on a fetcher: {result}") - continue - - # result is a generator that outputs a pair [{'gpus': i/N}, {metric_key1: metric_state1, ...}}] - for gpu_info, local_metric_state in result.items(): - if isinstance(local_metric_state, dict): - all_local_states.append(local_metric_state) - else: - logger.warning( - f"Unexpected result from fetcher. {gpu_info=}, {local_metric_state=}" - ) + + def extract_values_from_valuemesh(results) -> list[dict[str, Any]]: + all_local_states = [] + for result in results: + if isinstance(result, BaseException): + logger.warning(f"Flush failed on a fetcher: {result}") + continue + + # result is a generator that outputs a pair [{'gpus': i/N}, {metric_key1: metric_state1, ...}}] + for gpu_info, local_metric_state in result.items(): + if isinstance(local_metric_state, dict): + all_local_states.append(local_metric_state) + else: + logger.warning( + f"Unexpected result from fetcher. {gpu_info=}, {local_metric_state=}" + ) + return all_local_states + + all_local_states = extract_values_from_valuemesh(results) if not all_local_states: logger.warning(f"No states to reduce for global_step {global_step}") @@ -355,12 +414,9 @@ async def flush(self, global_step: int) -> None: # Reduce metrics from states reduced_metrics = reduce_metrics_states(all_local_states) - # Log to each global logger_backend - for ( - logger_backend_name, - logger_backend, - ) in self.global_logger_backends.items(): - await logger_backend.log(reduced_metrics, global_step) + # Log to global backends + for backend_name, backend in self.global_logger_backends.items(): + await backend.log_batch(reduced_metrics, global_step) @endpoint def has_fetcher(self, name: str | ProcMesh) -> bool: diff --git a/src/forge/observability/metrics.py b/src/forge/observability/metrics.py index 4996b3a7f..51b805f64 100644 --- a/src/forge/observability/metrics.py +++ b/src/forge/observability/metrics.py @@ -32,6 +32,32 @@ class BackendRole(Enum): GLOBAL = "global" +class LoggingMode(Enum): + """Metric logging behavior for distributed training scenarios. + + Each mode serves different observability needs: + + GLOBAL_REDUCE = "global_reduce" + Best for: Metrics that are best visualized as a single value per step. 
+ Behavior: All ranks accumulate → controller reduces → single log entry + Example use: 8 ranks training, want 1 loss value per training step averaged across all + + PER_RANK_REDUCE = "per_rank_reduce" + Best for: Per-rank performance metrics, debugging individual rank behavior + Behavior: Each rank accumulates + logs its own reduced values + Example use: Monitor GPU utilization per rank, get 8 separate log entries per step + + PER_RANK_NO_REDUCE = "per_rank_no_reduce" + Best for: Real-time streaming, time-series debugging + Behavior: Raw values logged immediately on record_metric() calls. Ignores reduce type. + Example use: See what every rank is doing in real time. + """ + + GLOBAL_REDUCE = "global_reduce" + PER_RANK_REDUCE = "per_rank_reduce" + PER_RANK_NO_REDUCE = "per_rank_no_reduce" + + class Reduce(Enum): MEAN = "mean" SUM = "sum" @@ -352,22 +378,23 @@ def reset(self) -> None: class MetricCollector: """Per-rank singleton for accumulating, retrieving and flushing metrics to backends. - A logger is represented by a backend, i.e. wandb backend. If reduce_across_ranks=False, - the backend is instantiated per-rank, in the MetricCollector, otherwise it is instantiated once globally, - in the GlobalLoggingActor. + Supports multiple logging backends, each with different logging modes. + For options, check `forge.observability.metrics.LoggerBackend` and `forge.observability.metrics.LoggingMode`. - - Ensures one instance per process; actors call record_metric() which delegates here. + Behavior: + - Ensures one instance per rank; + - Using `record_metric()` delegates here; - Init via GlobalLoggingActor -> LocalFetcherActor -> per-rank MetricCollector; - GlobalLoggingActor flushes trigger reductions and log for any locally setup backend. Can optionally also - return non-reduced states for global aggregation. This can be different for each backend. - - Resets accumulators post-flush to avoid leaks across train steps; + return non-reduced states for global aggregation. + - Resets accumulators post-flush to avoid leaks across steps; """ _instances: dict[int, "MetricCollector"] = {} _singleton_rank: int def __new__(cls): - """Singleton per-rank, ensures one instance per process.""" + """Singleton per-rank, ensures one instance per rank.""" rank = current_rank().rank if rank not in cls._instances: @@ -388,26 +415,33 @@ def __init__(self) -> None: self.accumulators: dict[str, MetricAccumulator] = {} self.rank = current_rank().rank - self.logger_backends: list[LoggerBackend] = [] + self.per_rank_reduce_backends: list[LoggerBackend] = [] + self.per_rank_no_reduce_backends: list[LoggerBackend] = [] + self.global_step: int = 0 # Set on `init_backends` and updated on `flush` self._is_initialized = False self.process_name: str | None = None async def init_backends( self, - metadata_per_primary_backend: dict[str, dict[str, Any]] | None, + metadata_per_controller_backend: dict[str, dict[str, Any]] | None, config: dict[str, Any], global_step: int = 0, process_name: str | None = None, ) -> None: - """A logger is represented by a backend, i.e. wandb backend. If reduce_across_ranks=False, - the backend is instantiated per-rank, in the MetricCollector, otherwise it is only instantiated - once globally. + """Initialize per-rank logger backends and MetricCollector state. + + A logger backend is represented by a backend class (e.g. WandBBackend, ConsoleBackend). + Backends are categorized by their logging_mode. For details, see `forge.observability.metrics.LoggingMode`. 
Args: - metadata_per_primary_backend (dict[str, dict[str, Any]] | None): Metadata from primary - logger backend, e.g., {"wandb": {"run_id": "abc123"}}. - config (dict[str, Any]): Logger backend configuration, e.g. {"wandb": {"project": "my_project"}}. - global_step (int, default 0): Initial step for metrics. + metadata_per_controller_backend (Optional[Dict[str, Dict[str, Any]]]): Metadata from controller + logger backends for backends that require shared state across processes, e.g., + {"wandb": {"shared_run_id": "abc123"}}. + config (Dict[str, Any]): Backend configurations where each key is a backend name + and value contains logging_mode and backend-specific settings. + e.g., {"wandb": {"logging_mode": "per_rank_no_reduce", "project": "my_proj"}} + global_step (int, default 0): Initial step for logging. Can be used when + resuming from a checkpoint. process_name (str | None): The meaningful process name for logging. """ if self._is_initialized: @@ -418,40 +452,61 @@ async def init_backends( self.process_name = process_name self.global_step = global_step - # instantiate local backends if any + self.per_rank_reduce_backends: list[LoggerBackend] = [] + self.per_rank_no_reduce_backends: list[LoggerBackend] = [] + + # Initialize backends based on logging mode for backend_name, backend_config in config.items(): - if backend_config.get("reduce_across_ranks", True): - continue # Skip local backend instantiation and use global instead + mode = backend_config["logging_mode"] + + # sanity check + if not isinstance(mode, LoggingMode): + raise TypeError( + f"Expected LoggingMode enum for {backend_name}.logging_mode, got {type(mode)}: {mode}." + ) - # get metadata from primary backend if any - primary_metadata = {} - if metadata_per_primary_backend: - primary_metadata = metadata_per_primary_backend.get(backend_name, {}) + # Skip local instantiation. Backend will be instantiated in GlobalLoggingActor. + if mode == LoggingMode.GLOBAL_REDUCE: + logger.debug("Skipping local instantiation for GLOBAL_REDUCE") + continue + + # get metadata from controller backend if any + controller_metadata = {} + if metadata_per_controller_backend: + controller_metadata = metadata_per_controller_backend.get( + backend_name, {} + ) # instantiate local backend - logger_backend = get_logger_backend_class(backend_name)(backend_config) - await logger_backend.init( + backend = get_logger_backend_class(backend_name)(backend_config) + await backend.init( role=BackendRole.LOCAL, - primary_logger_metadata=primary_metadata, + controller_logger_metadata=controller_metadata, process_name=process_name, ) - self.logger_backends.append(logger_backend) + + # Categorize by logging mode + if mode == LoggingMode.PER_RANK_NO_REDUCE: + self.per_rank_no_reduce_backends.append(backend) + else: + self.per_rank_reduce_backends.append(backend) self._is_initialized = True def push(self, metric: Metric) -> None: """Process a metric according to configured logging modes. - Args: - metric: Metric dataclass containing key, value, reduction type, and timestamp. + Behavior depends on backend modes: + - PER_RANK_NO_REDUCE: Stream metric immediately to backends + - PER_RANK_REDUCE/GLOBAL_REDUCE: Accumulate for per step batch logging - Raises: - TypeError: If metric is not a Metric object. 
+ Args: + metric (Metric): Metric dataclass Example: collector = MetricCollector() metric = Metric("loss", 0.5, Reduce.MEAN) - collector.push(metric) + collector.push(metric) # Streams immediately if no_reduce, else accumulates """ if not self._is_initialized: log_once( @@ -470,7 +525,13 @@ def push(self, metric: Metric) -> None: # Validate metric object if not isinstance(metric, Metric): - raise TypeError(f"Expected {Metric} object, got {type(metric)}") + raise TypeError( + f"Expected {Metric} object, got {metric} of type {type(metric)}" + ) + + # For PER_RANK_NO_REDUCE backends: stream without reduce + for backend in self.per_rank_no_reduce_backends: + backend.log_stream(metric=metric, global_step=self.global_step) # Always accumulate for reduction and state return key = metric.key @@ -499,7 +560,7 @@ async def flush( level=logging.WARNING, msg="Cannot flush collected metrics. MetricCollector.flush() called before init_backends()." "\nPlease call in your main file:\n" - "`mlogger = await get_or_create_metric_logger()`\n" + "`mlogger = await get_or_create_metric_logger(process_name='Controller')`\n" "`await mlogger.init_backends.call_one(logging_config)`\n" "before calling `flush`", ) @@ -517,27 +578,33 @@ async def flush( states[key] = acc.get_state() acc.reset() - # Reduce metrics from states for logging if any per-rank backend - if self.logger_backends: - # Use reduce_metrics_states for consistency - reduced_metrics = reduce_metrics_states([states]) + # Reduce and log to PER_RANK_REDUCE backends only (NO_REDUCE backends already logged in push) + if self.per_rank_reduce_backends: + metrics_for_backends = reduce_metrics_states([states]) - # Log to local logger_backends - for logger_backend in self.logger_backends: - await logger_backend.log(reduced_metrics, global_step) + for backend in self.per_rank_reduce_backends: + await backend.log_batch(metrics_for_backends, global_step) + + # Update step counter for streaming backends + # Note: This is incremented AFTER flush completes, so metrics recorded between + # flush(N) and flush(N+1) will stream with global_step=N+1. This is intentional: + # metrics belong to the training step being computed (step N+1), not the step + # that was just flushed (step N). + self.global_step = global_step + 1 return states if return_state else {} async def shutdown(self): """Shutdown logger_backends if initialized.""" + if not self._is_initialized: logger.debug( f"Collector for rank {get_proc_name_with_rank(self.process_name)} not initialized. Skipping shutdown" ) return - for logger_backend in self.logger_backends: - await logger_backend.finish() + for backend in self.per_rank_reduce_backends + self.per_rank_no_reduce_backends: + await backend.finish() ########### @@ -555,16 +622,16 @@ def __init__(self, logger_backend_config: dict[str, Any]) -> None: async def init( self, role: BackendRole, - primary_logger_metadata: dict[str, Any] | None = None, + controller_logger_metadata: dict[str, Any] | None = None, process_name: str | None = None, ) -> None: """ Initializes backend, e.g. wandb.run.init(). Args: - role (BackendRole): BackendRole.GLOBAL (controller/primary) or BackendRole.LOCAL (per-rank/secondary). - Can be used to behave differently for primary vs secondary roles. - primary_logger_metadata (dict[str, Any] | None): From global backend for + role (BackendRole): BackendRole.GLOBAL (controller) or BackendRole.LOCAL (per-rank). + Can be used to behave differently for controller vs rank roles. 
+ controller_logger_metadata (dict[str, Any] | None): From global backend for backend that required shared info, e.g. {"shared_run_id": "abc123"}. process_name (str | None): Process name for logging. @@ -573,13 +640,23 @@ async def init( pass @abstractmethod - async def log(self, metrics: list[Metric], global_step: int) -> None: - """ - Log a batch of metrics to the backend. + async def log_batch( + self, metrics: list[Metric], global_step: int, *args, **kwargs + ) -> None: + """Log batch of accumulated metrics to backend""" + pass - Args: - metrics: list of Metric objects to log. - global_step: Step number for x-axis alignment across metrics. + def log_stream(self, metric: Metric, global_step: int, *args, **kwargs) -> None: + """Stream single metric to backend immediately. + + NOTE: This method is called synchronously. + If your backend requires async I/O operations: + - Use asyncio.create_task() for fire-and-forget logging + - Consider internal buffering to avoid blocking the caller + + Example for async backend: + def log_stream(self, metric, global_step): + asyncio.create_task(self._async_log(metric, global_step)) """ pass @@ -587,7 +664,7 @@ async def finish(self) -> None: pass def get_metadata_for_secondary_ranks(self) -> dict[str, Any] | None: - """Return sharable state after primary init (e.g., for shared modes). Called only on globals.""" + """Return sharable state after controller init (e.g., for shared modes). Called only on controller backends.""" return None @@ -600,12 +677,14 @@ def __init__(self, logger_backend_config: dict[str, Any]) -> None: async def init( self, role: BackendRole, - primary_logger_metadata: dict[str, Any] | None = None, + controller_logger_metadata: dict[str, Any] | None = None, process_name: str | None = None, ) -> None: self.prefix = get_proc_name_with_rank(proc_name=process_name) - async def log(self, metrics: list[Metric], global_step: int) -> None: + async def log_batch( + self, metrics: list[Metric], global_step: int, *args, **kwargs + ) -> None: metrics_str = "\n".join( f" {metric.key}: {metric.value}" for metric in sorted(metrics, key=lambda m: m.key) @@ -614,24 +693,27 @@ async def log(self, metrics: list[Metric], global_step: int) -> None: f"=== [{self.prefix}] - METRICS STEP {global_step} ===\n{metrics_str}\n==============================\n" ) + def log_stream(self, metric: Metric, global_step: int, *args, **kwargs) -> None: + """Stream metric to console immediately.""" + logger.info(f"{metric.key}: {metric.value}") + async def finish(self) -> None: pass class WandbBackend(LoggerBackend): """ - Weights & Biases logging backend for distributed training. + Weights & Biases logging backend. + + For logging mode details, see `forge.observability.metrics.LoggingMode` documentation. - Supports 3 types of modes as described in https://docs.wandb.ai/guides/track/log/distributed-training/: - Track a single process: reduce_across_ranks=True - Track each process separately: reduce_across_ranks=False, share_run_id=False - Track all processes to a single run: reduce_across_ranks=False, share_run_id=True + More details on wandb distributed logging here: https://docs.wandb.ai/guides/track/log/distributed-training/ Configuration: - reduce_across_ranks (bool, default True): If True, log reduced metrics only from controller (global mode). - If False, enables per-rank logging; then use share_run_id to pick mode. - share_run_id (bool, default False): Only used if reduce_across_ranks=False. - True -> shared run across ranks; False -> separate runs per rank. 
+ logging_mode (LoggingMode): Determines logging behavior + per_rank_share_run (bool, default False): For per-rank modes, whether to share run ID across ranks. + If True, a single wandb run is created and all ranks log to it. It's particularly useful when + logging with no_reduce to capture a time-based stream of information. Not recommended when reducing values. project (str): WandB project name group (str, optional): WandB group name for organizing runs. Defaults to "experiment_group" """ @@ -642,39 +724,35 @@ def __init__(self, logger_backend_config: dict[str, Any]) -> None: self.group = logger_backend_config.get("group", "experiment_group") self.name = None self.run = None - self.reduce_across_ranks = logger_backend_config.get( - "reduce_across_ranks", True - ) - self.share_run_id = logger_backend_config.get("share_run_id", False) + self.logging_mode = LoggingMode(logger_backend_config["logging_mode"]) + self.per_rank_share_run = logger_backend_config.get("per_rank_share_run", False) async def init( self, role: BackendRole, - primary_logger_metadata: dict[str, Any] | None = None, + controller_logger_metadata: dict[str, Any] | None = None, process_name: str | None = None, ) -> None: - if primary_logger_metadata is None: - primary_logger_metadata = {} + if controller_logger_metadata is None: + controller_logger_metadata = {} self.name = get_proc_name_with_rank(proc_name=process_name) - # Default global mode: only inits on controller - if self.reduce_across_ranks: + # GLOBAL_REDUCE mode: only inits on controller + if self.logging_mode == LoggingMode.GLOBAL_REDUCE: if role != BackendRole.GLOBAL: - logger.debug( - f"Skipped init for global mode (reduce_across_ranks=True) and {role} role." - ) + logger.warning(f"Skipped init for GLOBAL_REDUCE mode and {role} role.") return await self._init_global() - # Per-rank modes based on share_run_id bool - elif role == BackendRole.GLOBAL and self.share_run_id: + # Per-rank modes based on per_rank_share_run bool - elif role == BackendRole.GLOBAL and self.per_rank_share_run: + elif role == BackendRole.GLOBAL and self.per_rank_share_run: await self._init_shared_global() elif role == BackendRole.LOCAL: - if self.share_run_id: - await self._init_shared_local(primary_logger_metadata) + if self.per_rank_share_run: + await self._init_shared_local(controller_logger_metadata) else: await self._init_per_rank() @@ -696,10 +774,10 @@ async def _init_shared_global(self): ) self.run = wandb.init(project=self.project, group=self.group, settings=settings) - async def _init_shared_local(self, primary_metadata: dict[str, Any]): + async def _init_shared_local(self, controller_metadata: dict[str, Any]): import wandb - shared_id = primary_metadata.get("shared_run_id") + shared_id = controller_metadata.get("shared_run_id") if shared_id is None: raise ValueError( f"Shared ID required but not provided for {self.name} backend init" ) @@ -721,22 +799,38 @@ async def _init_shared_local(self, primary_metadata: dict[str, Any]): settings=settings, ) - async def log(self, metrics: list[Metric], global_step: int) -> None: - if self.run: - # Convert metrics to WandB log format - log_data = {"global_step": global_step} - for metric in metrics: - log_data[metric.key] = metric.value - - self.run.log(log_data) - logger.info( - f"WandbBackend: Logged {len(metrics)} metrics at global_step {global_step}" - ) - else: + async def log_batch( + self, metrics: list[Metric], global_step: int, *args, **kwargs + ) -> None: + if not self.run: logger.debug(f"WandbBackend: No run started, skipping log for {self.name}") + return + + # Convert metrics to WandB log format + 
log_data = {"step": global_step} + for metric in metrics: + log_data[metric.key] = metric.value + + self.run.log(log_data) + logger.info( + f"WandbBackend: Logged {len(metrics)} metrics at step {global_step}" + ) + + def log_stream(self, metric: Metric, global_step: int, *args, **kwargs) -> None: + """Stream single metric to WandB with both step and timestamp.""" + if not self.run: + return + + # Log with both step and timestamp - users can choose x-axis in WandB UI + log_data = { + metric.key: metric.value, + "global_step": global_step, + "_timestamp": metric.timestamp, + } + self.run.log(log_data) def get_metadata_for_secondary_ranks(self) -> dict[str, Any]: - if self.run and not self.reduce_across_ranks and self.share_run_id: + if self.run and self.per_rank_share_run: return {"shared_run_id": self.run.id} return {} diff --git a/tests/sandbox/toy_rl/toy_metrics/main.py b/tests/sandbox/toy_rl/toy_metrics/main.py index eae50c2db..bcbdb6755 100644 --- a/tests/sandbox/toy_rl/toy_metrics/main.py +++ b/tests/sandbox/toy_rl/toy_metrics/main.py @@ -82,15 +82,13 @@ async def main(): group = f"grpo_exp_{int(time.time())}" # Config format: {backend_name: backend_config_dict} - # Each backend can specify reduce_across_ranks to control distributed logging behavior config = { - "console": {"reduce_across_ranks": True}, + "console": {"logging_mode": "global_reduce"}, "wandb": { - "project": "my_project", + "project": "toy_metrics", "group": group, - "reduce_across_ranks": False, - # Only useful if NOT reduce_across_ranks. - "share_run_id": False, # Share run ID across ranks -- Not recommended. + "logging_mode": "per_rank_reduce", # global_reduce, per_rank_reduce, per_rank_no_reduce + "per_rank_share_run": True, }, } diff --git a/tests/sandbox/vllm/main.py b/tests/sandbox/vllm/main.py index 7e0b22890..29fd50e1f 100644 --- a/tests/sandbox/vllm/main.py +++ b/tests/sandbox/vllm/main.py @@ -32,7 +32,9 @@ async def run(cfg: DictConfig): await init_provisioner( ProvisionerConfig(launcher_config=LauncherConfig(**cfg.provisioner)) ) - metric_logging_cfg = cfg.get("metric_logging", {"console": {"log_per_rank": False}}) + metric_logging_cfg = cfg.get( + "metric_logging", {"console": {"logging_mode": "global_reduce"}} + ) mlogger = await get_or_create_metric_logger(process_name="Controller") await mlogger.init_backends.call_one(metric_logging_cfg) diff --git a/tests/unit_tests/observability/test_metric_actors.py b/tests/unit_tests/observability/test_metric_actors.py index 501e13afe..2bb96e81e 100644 --- a/tests/unit_tests/observability/test_metric_actors.py +++ b/tests/unit_tests/observability/test_metric_actors.py @@ -6,6 +6,8 @@ """Optimized unit tests for metric actors functionality.""" +from unittest.mock import patch + import pytest from forge.observability.metric_actors import ( @@ -13,6 +15,8 @@ GlobalLoggingActor, LocalFetcherActor, ) + +from forge.observability.metrics import LoggingMode from monarch.actor import this_host @@ -62,7 +66,7 @@ async def test_global_logger_basic_ops(self, global_logger): async def test_backend_init(self, local_fetcher): """Test backend initialization and shutdown.""" metadata = {"wandb": {"shared_run_id": "test123"}} - config = {"console": {"logging_mode": "per_rank_reduce"}} + config = {"console": {"logging_mode": LoggingMode.PER_RANK_REDUCE}} await local_fetcher.init_backends.call_one(metadata, config, global_step=5) await local_fetcher.shutdown.call_one() @@ -113,20 +117,33 @@ async def test_valid_backend_configs(self, global_logger): config = {"console": {"logging_mode": 
mode}} await global_logger.init_backends.call_one(config) - @pytest.mark.timeout(3) - @pytest.mark.asyncio - async def test_invalid_backend_configs(self, global_logger): - """Test invalid backend configurations are handled gracefully.""" - # Empty config should work - await global_logger.init_backends.call_one({}) - - # Config with only project should work - config_with_project = {"console": {"project": "test_project"}} - await global_logger.init_backends.call_one(config_with_project) - - # Config with reduce_across_ranks should work (Diff 3 doesn't validate logging_mode yet) - config_with_reduce = {"console": {"reduce_across_ranks": True}} - await global_logger.init_backends.call_one(config_with_reduce) + def test_invalid_backend_configs(self): + """Test invalid backend configurations and warnings using direct validation.""" + actor = GlobalLoggingActor() + + # Test 1: Invalid logging_mode should raise ValueError + with pytest.raises(ValueError, match="is not a valid LoggingMode"): + actor._validate_backend_config("console", {"logging_mode": "invalid_mode"}) + + # Test 2: WandB PER_RANK_REDUCE + per_rank_share_run=True should warn + with patch("forge.observability.metric_actors.logger.warning") as mock_warn: + config = { + "logging_mode": "per_rank_reduce", + "per_rank_share_run": True, + "project": "test_project", + } + + result = actor._validate_backend_config("wandb", config) + + # Should have logged warning about suboptimal config + mock_warn.assert_called_once() + warning_msg = str(mock_warn.call_args) + assert "not recommended" in warning_msg + + # Should still return valid config with LoggingMode enum + assert result["logging_mode"] == LoggingMode.PER_RANK_REDUCE + assert result["per_rank_share_run"] is True + assert result["project"] == "test_project" class TestErrorHandling: diff --git a/tests/unit_tests/observability/test_metrics.py b/tests/unit_tests/observability/test_metrics.py index d0f104459..a4fd73c96 100644 --- a/tests/unit_tests/observability/test_metrics.py +++ b/tests/unit_tests/observability/test_metrics.py @@ -16,6 +16,7 @@ BackendRole, ConsoleBackend, get_logger_backend_class, + LoggingMode, MaxAccumulator, MeanAccumulator, Metric, @@ -88,7 +89,9 @@ async def test_backend_role_usage(self): await console_backend.init(role=BackendRole.LOCAL) # Test WandbBackend role validation without WandB initialization - wandb_backend = WandbBackend({"project": "test"}) + wandb_backend = WandbBackend( + {"project": "test", "logging_mode": "global_reduce"} + ) # Mock all the WandB init methods to focus only on role validation with patch.object(wandb_backend, "_init_global"), patch.object( @@ -298,14 +301,14 @@ def test_wandb_backend_creation(self): config = { "project": "test_project", "group": "test_group", - "reduce_across_ranks": True, + "logging_mode": "global_reduce", } backend = WandbBackend(config) assert backend.project == "test_project" assert backend.group == "test_group" - assert backend.reduce_across_ranks is True - assert backend.share_run_id is False # default + assert backend.logging_mode == LoggingMode.GLOBAL_REDUCE + assert backend.per_rank_share_run is False # default # Test metadata method metadata = backend.get_metadata_for_secondary_ranks() @@ -318,10 +321,10 @@ async def test_console_backend(self): await backend.init(role=BackendRole.LOCAL) - # Test log - should not raise + # Test log_batch - should not raise # Create a test metric test_metric = Metric("test", 1.0, Reduce.MEAN) - await backend.log([test_metric], global_step=1) + await 
backend.log_batch([test_metric], global_step=1) await backend.finish() # Should not raise
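
For reviewers, a minimal controller-side usage sketch of the new logging_mode config surface (not part of the diff). It assumes the entry-point pattern from apps/grpo/main.py and the backend names used in the YAML configs above; the run_with_metrics helper and the flush call pattern are illustrative assumptions rather than confirmed API from this PR.

import asyncio

from forge.observability.metric_actors import get_or_create_metric_logger
from forge.observability.metrics import LoggingMode


async def run_with_metrics(num_steps: int = 3) -> None:
    # Backend keys mirror the YAML configs above; logging_mode accepts either the
    # string values ("global_reduce", "per_rank_reduce", "per_rank_no_reduce") or
    # the LoggingMode enum, per GlobalLoggingActor.init_backends().
    config = {
        "console": {"logging_mode": "global_reduce"},
        "wandb": {
            "project": "grpo-training",
            "logging_mode": LoggingMode.PER_RANK_NO_REDUCE,
            "per_rank_share_run": True,  # all ranks stream into one shared wandb run
        },
    }
    mlogger = await get_or_create_metric_logger(process_name="Controller")
    await mlogger.init_backends.call_one(config)

    for step in range(num_steps):
        # ... training/rollout work that calls record_metric() goes here ...
        # Assumption: flush uses the same endpoint call pattern as init_backends.
        # It reduces accumulator states for global_reduce backends and logs one
        # entry per global step; per_rank_no_reduce metrics were already streamed.
        await mlogger.flush.call_one(step)


if __name__ == "__main__":
    asyncio.run(run_with_metrics())

With per_rank_no_reduce and per_rank_share_run=True, each record_metric() call is streamed into the shared WandB run via log_stream(), while the console backend still logs a single reduced batch per global step via log_batch().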