Changes from all commits (39 commits)
77488cf: commit (Oct 8, 2025)
feb4771: commit (Oct 8, 2025)
41ceaa4: update backend role typehints and enum (Oct 8, 2025)
8a24e71: update where we check FORGE_DISABLE_METRICS (Oct 8, 2025)
3f3bc51: remove protected import (Oct 8, 2025)
d82c354: Merge branch 'timestamp_logging_diff1' into timestamp_logging_diff2 (Oct 8, 2025)
4fe2611: protect import (Oct 8, 2025)
8759bc8: Merge branch 'timestamp_logging_diff1' into timestamp_logging_diff2 (Oct 8, 2025)
fbb4a9e: Merge branch 'main' of https://github.com/meta-pytorch/forge into tim… (Oct 8, 2025)
d81a4ed: record_metric uses dataclass Metric (Oct 8, 2025)
1e2255d: commit (Oct 8, 2025)
a94c612: Merge branch 'main' of https://github.com/meta-pytorch/forge into tim… (Oct 8, 2025)
5b477e8: commit (Oct 9, 2025)
f2b3eed: commit (Oct 9, 2025)
471b88a: revert (Oct 9, 2025)
1a02784: Merge branch 'timestamp_logging_diff2_5' into timestamp_logging_diff3 (Oct 9, 2025)
fa4895f: remove unnecessary code (Oct 9, 2025)
7bb1fe7: better logging (Oct 9, 2025)
43d5d27: docs/names (Oct 9, 2025)
c97eb98: Merge branch 'timestamp_logging_diff2_5' into timestamp_logging_diff3 (Oct 9, 2025)
75355a2: commit (Oct 9, 2025)
70e9c67: Merge branch 'main' of https://github.com/meta-pytorch/forge into tim… (Oct 9, 2025)
12f77c9: Merge branch 'timestamp_logging_diff3' into timestamp_logging_diff4 (Oct 9, 2025)
1186aec: update cfg back to true (Oct 9, 2025)
a02ea75: Merge branch 'main' of https://github.com/meta-pytorch/forge into tim… (Oct 13, 2025)
aa00898: Merge branch 'timestamp_logging_diff3' into timestamp_logging_diff4 (Oct 13, 2025)
7d89f5c: Merge branch 'main' of https://github.com/meta-pytorch/forge into tim… (Oct 14, 2025)
370c4e4: remove callstack, get meshname in provisioner (Oct 14, 2025)
9e77930: get name from proc mesh (Oct 14, 2025)
93b0cad: simplify + unit tests (Oct 14, 2025)
84363b1: Merge branch 'main' of https://github.com/meta-pytorch/forge into tim… (Oct 14, 2025)
e3c7a99: Merge branch 'timestamp_logging_diff3' into timestamp_logging_diff4 (Oct 14, 2025)
77e426b: Merge branch 'main' of https://github.com/meta-pytorch/forge into tim… (Oct 15, 2025)
e901ad5: address comments (Oct 15, 2025)
e42059b: Merge branch 'timestamp_logging_diff3' into timestamp_logging_diff4 (Oct 15, 2025)
f52408e: Merge branch 'main' of https://github.com/meta-pytorch/forge into tim… (Oct 15, 2025)
6fc11bb: fix merge (Oct 15, 2025)
72660f5: simplify comments (Oct 15, 2025)
69f9f8c: renaming of var + better docs (Oct 15, 2025)
apps/grpo/main.py (2 changes: 1 addition & 1 deletion)
@@ -304,7 +304,7 @@ async def main(cfg: DictConfig):
else:
provisioner = await init_provisioner()

metric_logging_cfg = cfg.get("metric_logging", {"console": {"log_per_rank": False}})
metric_logging_cfg = cfg.get("metric_logging", {})
mlogger = await get_or_create_metric_logger(process_name="Controller")
await mlogger.init_backends.call_one(metric_logging_cfg)
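For completeness, the per-step flush that pairs with this setup is not shown in this hunk; a hedged guess at the call style, based on the flush() endpoint changed further down in metric_actors.py:

    # assumed usage inside the training loop (call style is an assumption, not taken from this diff)
    await mlogger.flush.call_one(global_step=step)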

apps/grpo/qwen3_1_7b.yaml (9 changes: 5 additions & 4 deletions)
@@ -16,11 +16,12 @@ rollout_threads: 1 # Recommended to set equal to policy.num_replicas
# Observability configuration
metric_logging:
wandb:
project: "grpo-training"
group: "grpo_exp_${oc.env:USER}"
reduce_across_ranks: True
project: grpo-training
group: grpo_exp_${oc.env:USER}
logging_mode: global_reduce # global_reduce, per_rank_reduce, per_rank_no_reduce
per_rank_share_run: False
console:
reduce_across_ranks: True
logging_mode: global_reduce
Comment on lines +21 to +24

Contributor: Why do we need to duplicate logging_mode across different configs like this? feels like clunky UX to me

@felipemello1 (Contributor, Author), Oct 15, 2025: This is per backend. You could have scuba logging in streaming mode, console logging global_reduce, and wandb logging per rank. If you have a single backend, you define it only once.
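To make the reply concrete, a sketch of a mixed per-backend config in the dict form that init_backends accepts (only console and wandb appear in this diff; other backend names such as scuba are hypothetical and depend on what is registered):

    metric_logging_cfg = {
        "console": {"logging_mode": "global_reduce"},  # controller logs one reduced value per step
        "wandb": {
            "project": "grpo-training",
            "group": "grpo_exp_${oc.env:USER}",
            "logging_mode": "per_rank_no_reduce",      # stream raw metrics from every rank
            "per_rank_share_run": True,                # all ranks write into one shared run
        },
    }
    # With a single backend, logging_mode is defined only once, as the author notes.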


# Dataset configuration
dataset:
apps/grpo/qwen3_32b.yaml (9 changes: 5 additions & 4 deletions)
@@ -19,11 +19,12 @@ rollout_threads: 32 # make this 4x the number of policy replicas seems to work w
# Observability configuration
metric_logging:
wandb:
project: "grpo-training"
group: "grpo_exp_${oc.env:USER}"
reduce_across_ranks: True
project: grpo-training
group: grpo_exp_${oc.env:USER}
logging_mode: global_reduce # global_reduce, per_rank_reduce, per_rank_no_reduce
per_rank_share_run: False
console:
reduce_across_ranks: True
logging_mode: global_reduce

# Dataset configuration
dataset:
apps/grpo/qwen3_8b.yaml (9 changes: 5 additions & 4 deletions)
@@ -12,11 +12,12 @@ off_by_n: 1 # Off by one by default
# Observability configuration
metric_logging:
wandb:
project: "grpo-training"
group: "grpo_exp_${oc.env:USER}"
reduce_across_ranks: True
project: grpo-training
group: grpo_exp_${oc.env:USER}
logging_mode: global_reduce # global_reduce, per_rank_reduce, per_rank_no_reduce
per_rank_share_run: False
console:
reduce_across_ranks: True
logging_mode: global_reduce

# Dataset configuration
dataset:
src/forge/observability/__init__.py (3 changes: 3 additions & 0 deletions)
@@ -12,7 +12,9 @@
from .metrics import (
BackendRole,
ConsoleBackend,
get_logger_backend_class,
LoggerBackend,
LoggingMode,
MaxAccumulator,
MeanAccumulator,
Metric,
@@ -43,6 +45,7 @@
"BackendRole",
# Enums
"Reduce",
"LoggingMode",
# Utility functions
"get_proc_name_with_rank",
# Actor classes
src/forge/observability/metric_actors.py (196 changes: 126 additions & 70 deletions)
@@ -22,6 +22,7 @@
BackendRole,
get_logger_backend_class,
LoggerBackend,
LoggingMode,
MetricCollector,
reduce_metrics_states,
)
@@ -68,8 +69,8 @@ async def get_or_create_metric_logger(

# Initialize logging backends
await mlogger.init_backends({
"console": {"reduce_across_ranks": True},
"wandb": {"project": "my_project", "reduce_across_ranks": False}
"console": {"logging_mode": "global_reduce"},
"wandb": {"project": "my_project", "logging_mode": "per_rank_reduce"}
})

# Initialize services...
@@ -127,7 +128,7 @@ async def get_or_create_metric_logger(


class LocalFetcherActor(Actor):
"""Thin per-process actor used to trigger MetricCollector singleton
"""Thin per-rank actor used to trigger MetricCollector singleton
operations without direct access. It is what GlobalLoggingActor
uses to broadcast inits/flushes across ranks.

@@ -165,20 +166,20 @@ async def flush(
@endpoint
async def init_backends(
self,
metadata_per_primary_backend: dict[str, dict[str, Any]],
metadata_per_controller_backend: dict[str, dict[str, Any]],
config: dict[str, Any],
global_step: int = 0,
) -> None:
"""Init local (per-rank) logger backends and MetricCollector.

Args:
metadata_per_primary_backend (dict[str, dict[str, Any]]): Metadata from primary backends for shared state.
metadata_per_controller_backend (dict[str, dict[str, Any]]): Metadata from controller backends for shared state.
config (dict[str, Any]): Backend configurations with logging modes and settings.
global_step (int): Initial step for metrics.
"""
collector = MetricCollector()
await collector.init_backends(
metadata_per_primary_backend,
metadata_per_controller_backend,
config,
global_step,
process_name=self.process_name,
@@ -191,76 +192,128 @@ async def shutdown(self) -> None:


class GlobalLoggingActor(Actor):
"""Coordinates metric logging across all ranks for every training step.
"""Coordinates metric logging across all ranks for every global step.

Supports multiple logging backends (e.g., WandB, TensorBoard, etc.),
for per-rank and/or global reduction logging modes.

If a backend config has flag `reduce_across_ranks=False`, an instance of the backend
is initialized per-rank, otherwise it is done once globally.

This GlobalLoggingActor should be spawned once in the controller. A LocalFetcherActor
is automatically spawned per-rank in `forge.controller.provisioner.py` and registered
with this actor. The LocalFetcherActor is responsible for instantiating
the per-rank MetricCollector.
the per-rank MetricCollector and acts as the bridge between the GlobalLoggingActor and the per-rank processes.

In summary, the flow is:
- GlobalLoggingActor init_backends() -> LocalFetcherActor init_backends() -> per-rank MetricCollector
- GlobalLoggingActor flush() -> LocalFetcherActor flush() -> per-rank MetricCollector flush
- GlobalLoggingActor.init_backends() -> LocalFetcherActor.init_backends() -> per-rank MetricCollector.init_backends()
- GlobalLoggingActor.flush() -> LocalFetcherActor.flush() -> per-rank MetricCollector.flush()
"""

def __init__(self):
self.fetchers: dict[str, LocalFetcherActor] = {}
self.config: dict[str, Any] | None = None
self.global_logger_backends: dict[str, LoggerBackend] = {}
self.metadata_per_primary_backend: dict[str, dict[str, Any]] = {}
self.metadata_per_controller_backend: dict[str, dict[str, Any]] = {}

def _validate_backend_config(
self, backend_name: str, config: dict[str, Any]
) -> dict[str, Any]:
"""Validate and normalize backend configuration."""
if "logging_mode" not in config:
logger.debug(
f"logging_mode not provided for backend {backend_name}. Defaulting to global_reduce."
)

mode_str = config.get("logging_mode", "global_reduce")
mode = LoggingMode(mode_str)

# Validate per_rank_share_run configuration
share_run = config.get("per_rank_share_run", False)
if mode == LoggingMode.GLOBAL_REDUCE and share_run:
logger.warning(
f"{backend_name}: per_rank_share_run=True is ignored in {mode.value} mode. "
"Setting it to False."
)
share_run = False

# WandB-specific warning for suboptimal configuration
if (
backend_name == "wandb"
and mode == LoggingMode.PER_RANK_REDUCE
and share_run
):
logger.warning(
"WandB: Using 'per_rank_reduce' with 'per_rank_share_run=True' is not recommended. "
"This configuration can lead to confusing metrics where reduced values from multiple ranks "
"are written to the same run/step, displaying only one of them. Consider either:\n"
" 1. Set 'per_rank_share_run=False' to create separate runs per rank, OR\n"
" 2. Use 'per_rank_no_reduce' for real-time streaming to a shared run"
)

return {
**config,
"logging_mode": mode,
"per_rank_share_run": share_run,
}
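As a quick illustration (not part of the diff) of what this normalization does, assuming LoggingMode is importable from forge.observability as added above:

    from forge.observability import LoggingMode

    raw = {"project": "grpo-training", "per_rank_share_run": True}  # note: no logging_mode given
    # _validate_backend_config("wandb", raw) would return roughly:
    normalized = {
        "project": "grpo-training",                  # backend-specific keys pass through via **config
        "logging_mode": LoggingMode.GLOBAL_REDUCE,   # defaulted when missing (debug-logged)
        "per_rank_share_run": False,                 # ignored under global_reduce (warning emitted)
    }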

@endpoint
async def init_backends(self, config: dict[str, Any]) -> None:
"""
Sets config in global actor, so other actors can get it, then eagerly initializes backend and MetricCollectors
"""Sets config in global actor, initializes controller backends and eagerly initializes MetricCollectors
in all registered fetchers.

A backend is always initialized in the controller (primary backend) and can be used as a logger or as a source
for metadata to be shared with per-rank backends, e.g. shared run IDs for wandb.

The backend instantiation is controlled by the backend config flag `reduce_across_ranks`: if False,
a per-rank backend is initialized, i.e. if there are 2 ranks, each will have its own backend,
and will log independently, i.e. each rank will have its own run in wandb.

Else, if True, the GlobalLoggingActor will fetch all local metrics collectors to get their states
and reduce them to a single value, which will be logged by the primary backend in this controller.
The backend instantiation is controlled by the logging_mode field. Controller backends
(instantiated in the controller) can provide metadata to be shared with rank backends,
e.g. shared run IDs for WandB. For details on logging modes, see `forge.observability.metrics.LoggingMode`.

Args:
config (dict[str, Any]): Config for metric logging where keys are backend names,
e.g. {"console": {"reduce_across_ranks": True}, "wandb": {"reduce_across_ranks": False}}
config (dict[str, Any]): Config for metric logging where keys are backend names.
Each backend config supports:
- logging_mode (str | LoggingMode, default "global_reduce"): One of "global_reduce",
"per_rank_reduce", or "per_rank_no_reduce". Can be specified as a string or LoggingMode enum.
- per_rank_share_run (bool, default False): For per-rank modes only. Whether ranks
share a single run/logger instance. Ignored for "global_reduce" mode.
- Additional backend-specific options (e.g., "project" for WandB)

Example:
{
"console": {"logging_mode": "global_reduce"},
"wandb": {
"project": "my_project",
"logging_mode": "per_rank_no_reduce",
"per_rank_share_run": True
}
}

Raises:
ValueError: If backend config is invalid or missing required fields.
"""
self.config = config
self.config = {}

# Validate and normalize each backend config
for backend_name, backend_config in config.items():
self.config[backend_name] = self._validate_backend_config(
backend_name, backend_config
)

# Initialize backends based on logging mode
for backend_name, backend_config in self.config.items():
mode = backend_config["logging_mode"]

backend = get_logger_backend_class(backend_name)(backend_config)
await backend.init(role=BackendRole.GLOBAL)

# Extract metadata from primary logger to be shared with secondary loggers
# and store it
reduce_across_ranks = backend_config.get("reduce_across_ranks", True)
if not reduce_across_ranks:
primary_backend_metadata = (
backend.get_metadata_for_secondary_ranks() or {}
)
self.metadata_per_primary_backend[
backend_name
] = primary_backend_metadata
# Extract metadata from controller logger to be shared with per-rank loggers
if mode != LoggingMode.GLOBAL_REDUCE:
controller_metadata = backend.get_metadata_for_secondary_ranks() or {}
self.metadata_per_controller_backend[backend_name] = controller_metadata

# Store global logger backends
if reduce_across_ranks:
# Store global logger backends for later flush
if mode == LoggingMode.GLOBAL_REDUCE:
self.global_logger_backends[backend_name] = backend

# Eager init collectors on all registered fetchers in parallel, passing primary states and config
# Eager init collectors on all registered fetchers in parallel, passing controller states and config
if self.fetchers:
tasks = [
fetcher.init_backends.call(
self.metadata_per_primary_backend, self.config
self.metadata_per_controller_backend, self.config
)
for fetcher in self.fetchers.values()
]
@@ -279,7 +332,7 @@ async def register_fetcher(
if self.config:
logger.debug(f"Initializing new LocalFetcherActor {name}")
await fetcher.init_backends.call(
self.metadata_per_primary_backend, self.config
self.metadata_per_controller_backend, self.config
)

@endpoint
@@ -307,19 +360,21 @@ async def flush(self, global_step: int) -> None:
config = self.config
if config is None:
logger.warning(
"GlobalLoggingActor flush() called before init_backends(). "
"No backends will be flushed."
"Cannot flush collected metrics. GlobalLoggingActor.flush() called before init_backends()."
" No backends will be flushed. Please call in your main file:\n"
"`mlogger = await get_or_create_metric_logger(process_name='Controller')`\n"
"`await mlogger.init_backends.call_one(logging_config)`\n"
)
return
# if reduce_across_ranks=True, we need to reduce the states from all ranks
# and log with the primary backend

# Check if need to do reduce and retrieve states from fetchers
requires_reduce = any(
backend_config.get("reduce_across_ranks", True)
backend_config["logging_mode"] == LoggingMode.GLOBAL_REDUCE
for backend_config in config.values()
)

logger.debug(
f"Global flush for global_step {global_step}: {len(self.fetchers)} fetchers"
f"Global flush for global step {global_step}: {len(self.fetchers)} fetchers"
)

# Broadcast flush to all fetchers
@@ -332,21 +387,25 @@ async def flush(self, global_step: int) -> None:
)

if requires_reduce:
# Handle exceptions and extract values from ValueMesh results
all_local_states = []
for result in results:
if isinstance(result, BaseException):
logger.warning(f"Flush failed on a fetcher: {result}")
continue

# result is a generator that outputs a pair [{'gpus': i/N}, {metric_key1: metric_state1, ...}}]
for gpu_info, local_metric_state in result.items():
if isinstance(local_metric_state, dict):
all_local_states.append(local_metric_state)
else:
logger.warning(
f"Unexpected result from fetcher. {gpu_info=}, {local_metric_state=}"
)

def extract_values_from_valuemesh(results) -> list[dict[str, Any]]:
all_local_states = []
for result in results:
if isinstance(result, BaseException):
logger.warning(f"Flush failed on a fetcher: {result}")
continue

# result is a generator that outputs a pair [{'gpus': i/N}, {metric_key1: metric_state1, ...}}]
for gpu_info, local_metric_state in result.items():
if isinstance(local_metric_state, dict):
all_local_states.append(local_metric_state)
else:
logger.warning(
f"Unexpected result from fetcher. {gpu_info=}, {local_metric_state=}"
)
return all_local_states

all_local_states = extract_values_from_valuemesh(results)

if not all_local_states:
logger.warning(f"No states to reduce for global_step {global_step}")
@@ -355,12 +414,9 @@ async def flush(self, global_step: int) -> None:
# Reduce metrics from states
reduced_metrics = reduce_metrics_states(all_local_states)

# Log to each global logger_backend
for (
logger_backend_name,
logger_backend,
) in self.global_logger_backends.items():
await logger_backend.log(reduced_metrics, global_step)
# Log to global backends
for backend_name, backend in self.global_logger_backends.items():
await backend.log_batch(reduced_metrics, global_step)

@endpoint
def has_fetcher(self, name: str | ProcMesh) -> bool:
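A hedged recap of the flush path introduced above (comment sketch only; the broadcast call arguments are folded out of this diff, so they are paraphrased rather than quoted):

    # GlobalLoggingActor.flush(global_step), as changed in this file:
    #   1. warn and return early if init_backends() was never called
    #   2. requires_reduce = any backend configured with logging_mode == LoggingMode.GLOBAL_REDUCE
    #   3. broadcast flush to every registered LocalFetcherActor, which flushes its rank's MetricCollector
    #   4. if requires_reduce: gather the per-rank states out of the ValueMesh results,
    #      reduced_metrics = reduce_metrics_states(all_local_states),
    #      then await backend.log_batch(reduced_metrics, global_step) for each global backend
    # Backends in per_rank_* modes log from their own ranks during step 3 instead.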