Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions src/forge/observability/metric_actors.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,8 +224,16 @@ def _validate_backend_config(
f"{', '.join([mode.value for mode in LoggingMode])}."
)

mode_str = config["logging_mode"]
mode = LoggingMode(mode_str)
# Convert string to LoggingMode enum
mode_value = config["logging_mode"]
if isinstance(mode_value, str):
mode = LoggingMode(mode_value)
elif isinstance(mode_value, LoggingMode):
mode = mode_value
else:
raise TypeError(
f"logging_mode must be str or LoggingMode enum, got {type(mode_value)}"
)

# Validate per_rank_share_run configuration
share_run = config.get("per_rank_share_run", False)
Expand Down Expand Up @@ -302,7 +310,7 @@ async def init_backends(self, config: dict[str, Any]) -> None:
mode = backend_config["logging_mode"]

backend: LoggerBackend = get_logger_backend_class(backend_name)(
backend_config
**backend_config
)
await backend.init(role=BackendRole.GLOBAL, process_name="global_reduce")

Expand Down
123 changes: 74 additions & 49 deletions src/forge/observability/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,7 @@ async def init_backends(

# instantiate local backend
backend: LoggerBackend = get_logger_backend_class(backend_name)(
backend_config
**backend_config
)
await backend.init(
role=BackendRole.LOCAL,
Expand Down Expand Up @@ -643,10 +643,21 @@ async def shutdown(self):


class LoggerBackend(ABC):
"""Abstract logger_backend for metric logging, e.g. wandb, jsonl, etc."""
"""Abstract logger_backend for metric logging, e.g. wandb, jsonl, etc.

def __init__(self, logger_backend_config: dict[str, Any]) -> None:
self.logger_backend_config = logger_backend_config
Args:
logging_mode: Logging behavior mode.
per_rank_share_run: Whether ranks share run. Default False.
**kwargs: Backend-specific arguments (e.g., project, name, tags for WandB).
"""

def __init__(
self, *, logging_mode: LoggingMode, per_rank_share_run: bool = False, **kwargs
) -> None:

self.logging_mode = logging_mode
self.per_rank_share_run = per_rank_share_run
self.backend_kwargs = kwargs

@abstractmethod
async def init(
Expand Down Expand Up @@ -706,8 +717,13 @@ def get_metadata_for_secondary_ranks(self) -> dict[str, Any] | None:
class ConsoleBackend(LoggerBackend):
"""Simple console logging of metrics."""

def __init__(self, logger_backend_config: dict[str, Any]) -> None:
super().__init__(logger_backend_config)
def __init__(
self, *, logging_mode: LoggingMode, per_rank_share_run: bool = False, **kwargs
) -> None:
super().__init__(
logging_mode=logging_mode, per_rank_share_run=per_rank_share_run, **kwargs
)
self.process_name = None

async def init(
self,
Expand Down Expand Up @@ -741,84 +757,98 @@ class WandbBackend(LoggerBackend):

For logging mode details, see `forge.observability.metrics.LoggingMode` documentation.

More details on wandb distributed logging here: https://docs.wandb.ai/guides/track/log/distributed-training/
More details on wandb distributed logging: https://docs.wandb.ai/guides/track/log/distributed-training/

Configuration:
logging_mode (LoggingMode): Determines logging behavior
logging_mode (LoggingMode): Determines logging behavior.
per_rank_share_run (bool, default False): For per-rank modes, whether to share run ID across ranks.
If true, then a single wandb is created and all ranks log to it. Its particularly useful if
logging with no_reduce to capture a time based stream of information. Not recommended if reducing values.
project (str): WandB project name
group (str, optional): WandB group name for organizing runs. Defaults to "experiment_group"
If true, a single wandb run is created and all ranks log to it. Particularly useful for
logging with no_reduce to capture time-based streams. Not recommended if reducing values.
**kwargs: Any argument accepted by wandb.init() (e.g., project, group, name, tags, notes, etc.)

Example:
WandbBackend(
logging_mode=LoggingMode.PER_RANK_REDUCE,
per_rank_share_run=False,
project="my_project",
group="exp_group",
name="my_experiment",
tags=["rl", "v2"],
notes="Testing new reward"
)
"""

def __init__(self, logger_backend_config: dict[str, Any]) -> None:
super().__init__(logger_backend_config)
self.project = logger_backend_config["project"]
self.group = logger_backend_config.get("group", "experiment_group")
self.process_name = None
def __init__(
self, *, logging_mode: LoggingMode, per_rank_share_run: bool = False, **kwargs
) -> None:
super().__init__(
logging_mode=logging_mode, per_rank_share_run=per_rank_share_run, **kwargs
)
self.run = None
self.logging_mode = LoggingMode(logger_backend_config["logging_mode"])
self.per_rank_share_run = logger_backend_config.get("per_rank_share_run", False)
self.process_name = None

async def init(
self,
role: BackendRole,
controller_logger_metadata: dict[str, Any] | None = None,
process_name: str | None = None,
) -> None:

if controller_logger_metadata is None:
controller_logger_metadata = {}

# Pop name, if any, to concat to process_name.
run_name = self.backend_kwargs.pop("name", None)
self.process_name = process_name

# GLOBAL_REDUCE mode: only inits on controller
# Format run name based on mode and role
if self.logging_mode == LoggingMode.GLOBAL_REDUCE:
if role != BackendRole.GLOBAL:
logger.warning(f"Skipped init for GLOBAL_REDUCE mode and {role} role.")
return
await self._init_global()
# use name as-is, no need to append controller process_name
await self._init_global(run_name)

# Per-rank modes based on per_rank_share_run bool
elif role == BackendRole.GLOBAL and self.per_rank_share_run:
await self._init_shared_global()
# use name as-is, no need to append controller process_name
await self._init_shared_global(run_name)

elif role == BackendRole.LOCAL:
# Per-rank: append process_name
run_name = f"{run_name}_{process_name}" if run_name else process_name

if self.per_rank_share_run:
await self._init_shared_local(controller_logger_metadata)
shared_id = controller_logger_metadata.get("shared_run_id")
if shared_id is None:
raise ValueError(
f"Shared ID required but not provided for {process_name} backend init"
)
await self._init_shared_local(run_name, shared_id, process_name)
else:
await self._init_per_rank()
await self._init_per_rank(run_name)

async def _init_global(self):
async def _init_global(self, run_name: str | None):
import wandb

self.run = wandb.init(project=self.project, group=self.group)
self.run = wandb.init(name=run_name, **self.backend_kwargs)

async def _init_per_rank(self):
async def _init_per_rank(self, run_name: str):
import wandb

self.run = wandb.init(
project=self.project, group=self.group, name=self.process_name
)
self.run = wandb.init(name=run_name, **self.backend_kwargs)

async def _init_shared_global(self):
async def _init_shared_global(self, run_name: str | None):
import wandb

settings = wandb.Settings(
mode="shared", x_primary=True, x_label="controller_primary"
)
self.run = wandb.init(project=self.project, group=self.group, settings=settings)
self.run = wandb.init(name=run_name, settings=settings, **self.backend_kwargs)

async def _init_shared_local(self, controller_metadata: dict[str, Any]):
async def _init_shared_local(
self, run_name: str, shared_id: str, process_name: str
):
import wandb

shared_id = controller_metadata.get("shared_run_id")
if shared_id is None:
raise ValueError(
f"Shared ID required but not provided for {self.process_name} backend init"
)

# Clear any stale service tokens that might be pointing to dead processes
# In multiprocessing environments, WandB service tokens can become stale and point
# to dead service processes. This causes wandb.init() to hang indefinitely trying
Expand All @@ -827,14 +857,9 @@ async def _init_shared_local(self, controller_metadata: dict[str, Any]):

service_token.clear_service_in_env()

settings = wandb.Settings(
mode="shared", x_primary=False, x_label=self.process_name
)
settings = wandb.Settings(mode="shared", x_primary=False, x_label=process_name)
self.run = wandb.init(
id=shared_id,
project=self.project,
group=self.group,
settings=settings,
name=run_name, id=shared_id, settings=settings, **self.backend_kwargs
)

async def log_batch(
Expand Down Expand Up @@ -862,7 +887,7 @@ def log_stream(self, metric: Metric, global_step: int, *args, **kwargs) -> None:
return

# Log with custom timestamp for precision
# Users can choose x-axis as timestamp in WandB UI and display as dateimte
# Users can choose x-axis as timestamp in WandB UI and display as datetime
log_data = {
metric.key: metric.value,
"timestamp": metric.timestamp,
Expand Down
22 changes: 11 additions & 11 deletions tests/unit_tests/observability/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,12 +85,12 @@ def test_new_enums_and_constants(self):
async def test_backend_role_usage(self):
"""Test that BackendRole constants are actually used instead of string literals."""
# Test ConsoleBackend
console_backend = ConsoleBackend({})
console_backend = ConsoleBackend(logging_mode=LoggingMode.GLOBAL_REDUCE)
await console_backend.init(role=BackendRole.LOCAL)

# Test WandbBackend role validation without WandB initialization
wandb_backend = WandbBackend(
{"project": "test", "logging_mode": "global_reduce"}
logging_mode=LoggingMode.GLOBAL_REDUCE, project="test"
)

# Mock all the WandB init methods to focus only on role validation
Expand Down Expand Up @@ -298,15 +298,15 @@ def test_record_metric_enabled_explicit(self, mock_collector_class, mock_rank):
def test_wandb_backend_creation(self):
"""Test WandbBackend creation and basic setup without WandB dependency."""

config = {
"project": "test_project",
"group": "test_group",
"logging_mode": "global_reduce",
}
backend = WandbBackend(config)
backend = WandbBackend(
logging_mode=LoggingMode.GLOBAL_REDUCE,
project="test_project",
group="test_group",
)

assert backend.project == "test_project"
assert backend.group == "test_group"
# Test backend kwargs storage
assert backend.backend_kwargs["project"] == "test_project"
assert backend.backend_kwargs["group"] == "test_group"
assert backend.logging_mode == LoggingMode.GLOBAL_REDUCE
assert backend.per_rank_share_run is False # default

Expand All @@ -317,7 +317,7 @@ def test_wandb_backend_creation(self):
@pytest.mark.asyncio
async def test_console_backend(self):
"""Test ConsoleBackend basic operations."""
backend = ConsoleBackend({})
backend = ConsoleBackend(logging_mode=LoggingMode.GLOBAL_REDUCE)

await backend.init(role=BackendRole.LOCAL)

Expand Down
Loading