Changes from all commits (28 commits):
77488cf  commit  (Oct 8, 2025)
feb4771  commit  (Oct 8, 2025)
41ceaa4  update backend role typehints and enum  (Oct 8, 2025)
8a24e71  update where we check FORGE_DISABLE_METRICS  (Oct 8, 2025)
3f3bc51  remove protected import  (Oct 8, 2025)
d82c354  Merge branch 'timestamp_logging_diff1' into timestamp_logging_diff2  (Oct 8, 2025)
4fe2611  protect import  (Oct 8, 2025)
8759bc8  Merge branch 'timestamp_logging_diff1' into timestamp_logging_diff2  (Oct 8, 2025)
fbb4a9e  Merge branch 'main' of https://github.com/meta-pytorch/forge into tim…  (Oct 8, 2025)
d81a4ed  record_metric uses dataclass Metric  (Oct 8, 2025)
1e2255d  commit  (Oct 8, 2025)
a94c612  Merge branch 'main' of https://github.com/meta-pytorch/forge into tim…  (Oct 8, 2025)
5b477e8  commit  (Oct 9, 2025)
f2b3eed  commit  (Oct 9, 2025)
471b88a  revert  (Oct 9, 2025)
1a02784  Merge branch 'timestamp_logging_diff2_5' into timestamp_logging_diff3  (Oct 9, 2025)
fa4895f  remove unnecessary code  (Oct 9, 2025)
7bb1fe7  better logging  (Oct 9, 2025)
43d5d27  docs/names  (Oct 9, 2025)
c97eb98  Merge branch 'timestamp_logging_diff2_5' into timestamp_logging_diff3  (Oct 9, 2025)
70e9c67  Merge branch 'main' of https://github.com/meta-pytorch/forge into tim…  (Oct 9, 2025)
1186aec  update cfg back to true  (Oct 9, 2025)
a02ea75  Merge branch 'main' of https://github.com/meta-pytorch/forge into tim…  (Oct 13, 2025)
7d89f5c  Merge branch 'main' of https://github.com/meta-pytorch/forge into tim…  (Oct 14, 2025)
370c4e4  remove callstack, get meshname in provisioner  (Oct 14, 2025)
9e77930  get name from proc mesh  (Oct 14, 2025)
93b0cad  simplify + unit tests  (Oct 14, 2025)
84363b1  Merge branch 'main' of https://github.com/meta-pytorch/forge into tim…  (Oct 14, 2025)
2 changes: 1 addition & 1 deletion apps/grpo/main.py
@@ -323,7 +323,7 @@ async def main(cfg: DictConfig):
     provisioner = await init_provisioner()

     metric_logging_cfg = cfg.get("metric_logging", {"console": {"log_per_rank": False}})
-    mlogger = await get_or_create_metric_logger()
+    mlogger = await get_or_create_metric_logger(process_name="Controller")
     await mlogger.init_backends.call_one(metric_logging_cfg)

     # ---- Setup services ---- #
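The default passed to cfg.get() above also documents the shape that the metric_logging section is expected to take. A minimal controller-side sketch of the new setup path, using only calls that appear in this diff (the wrapper function name setup_controller_logging is made up for illustration):

    from forge.observability.metric_actors import get_or_create_metric_logger

    async def setup_controller_logging():
        # Default config shape, as used above when cfg has no "metric_logging" key.
        metric_logging_cfg = {"console": {"log_per_rank": False}}
        # New in this PR: the controller names itself explicitly for logging.
        mlogger = await get_or_create_metric_logger(process_name="Controller")
        await mlogger.init_backends.call_one(metric_logging_cfg)
        return mlogger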
2 changes: 1 addition & 1 deletion src/forge/controller/provisioner.py
@@ -301,7 +301,7 @@ def bootstrap(env: dict[str, str]):
        if not FORGE_DISABLE_METRICS.get_value():
            from forge.observability.metric_actors import get_or_create_metric_logger

-            _ = await get_or_create_metric_logger(procs)
+            _ = await get_or_create_metric_logger(procs, process_name=mesh_name)
        return procs

    async def host_mesh_from_proc(self, proc_mesh: ProcMesh):
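The surrounding context also shows that this registration is skipped entirely when metrics are disabled. A rough opt-out sketch, assuming FORGE_DISABLE_METRICS is backed by an environment variable of the same name (the actual mechanism lives in forge.env and is not shown in this diff):

    import os

    # Assumption: forge.env.FORGE_DISABLE_METRICS reads an environment variable with
    # the same name, so setting it before provisioning skips per-mesh fetcher setup.
    os.environ["FORGE_DISABLE_METRICS"] = "1"

    from forge.env import FORGE_DISABLE_METRICS

    if FORGE_DISABLE_METRICS.get_value():
        print("Metrics disabled; provisioner will not spawn a LocalFetcherActor per mesh.")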
6 changes: 3 additions & 3 deletions src/forge/observability/__init__.py
@@ -12,8 +12,6 @@
 from .metrics import (
     BackendRole,
     ConsoleBackend,
-    get_actor_name_with_rank,
-    get_logger_backend_class,
     LoggerBackend,
     MaxAccumulator,
     MeanAccumulator,
@@ -29,12 +27,12 @@
     WandbBackend,
 )
 from .perf_tracker import trace, Tracer
+from .utils import get_proc_name_with_rank

 __all__ = [
     # Main API functions
     "record_metric",
     "reduce_metrics_states",
-    "get_actor_name_with_rank",
     "get_logger_backend_class",
     "get_or_create_metric_logger",
     # Performance tracking
@@ -45,6 +43,8 @@
     "BackendRole",
     # Enums
     "Reduce",
+    # Utility functions
+    "get_proc_name_with_rank",
     # Actor classes
     "GlobalLoggingActor",
     "LocalFetcherActor",
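For downstream code, the visible effect is a change in the package's public surface: get_actor_name_with_rank is no longer re-exported from forge.observability, and get_proc_name_with_rank (now living in the new .utils module) is. A minimal before/after sketch; the helper's signature is not shown in this diff, so only the import is illustrated:

    # Before this PR (no longer re-exported from the package root):
    #     from forge.observability import get_actor_name_with_rank

    # After this PR:
    from forge.observability import get_proc_name_with_rank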
48 changes: 41 additions & 7 deletions src/forge/observability/metric_actors.py
@@ -8,7 +8,14 @@
 import logging
 from typing import Any, Union

-from monarch.actor import Actor, endpoint, get_or_spawn_controller, ProcMesh, this_proc
+from monarch.actor import (
+    Actor,
+    context,
+    endpoint,
+    get_or_spawn_controller,
+    ProcMesh,
+    this_proc,
+)

 from forge.env import FORGE_DISABLE_METRICS
 from forge.observability.metrics import (
@@ -27,6 +34,7 @@

 async def get_or_create_metric_logger(
     proc_mesh: ProcMesh | None = None,
+    process_name: str | None = None,
 ) -> "GlobalLoggingActor":
     """Initializes a LocalFetcherActor in the specified process mesh (or current process if None),
     if not already initialized, registers it with the GlobalLoggingActor and returns the
@@ -40,6 +48,9 @@ async def get_or_create_metric_logger(
     Args:
         proc_mesh: Optional ProcMesh to spawn LocalFetcherActor on. If None,
             uses `monarch.actor.this_proc()`.
+        process_name: Optional process name (e.g., "TrainActor", "GeneratorActor") for logging.
+            If None, will be auto-detected from the mesh_name provided during actor initialization or
+            a generic mesh name if one was not provided.

     Returns:
         GlobalLoggingActor: The global logging controller.
@@ -53,7 +64,7 @@ async def get_or_create_metric_logger(
         from forge.observability.metrics import record_metric

         # Main process setup
-        mlogger = await get_or_create_metric_logger()
+        mlogger = await get_or_create_metric_logger(process_name="Controller")

         # Initialize logging backends
         await mlogger.init_backends({
@@ -66,13 +77,14 @@

         # Training loop
         for step in range(max_steps):
-            record_metric("loss", 1.2, step, reduction_type=Reduce.MEAN)
+            record_metric("loss", 1.2, reduction_type=Reduce.MEAN)
             # ... training code with record_metric() calls ...
             await mlogger.flush(step)  # Log metrics for this step

         # Shutdown
         await mlogger.shutdown()
     """
+
     # Get or create the singleton global logger
     global _global_logger
     if _global_logger is None:
@@ -84,6 +96,11 @@
     # Determine process context
     proc = proc_mesh if proc_mesh is not None else this_proc()

+    # Auto-detect process_name from proc mesh if not provided
+    if process_name is None:
+        ctx = context()
+        process_name = ctx.actor_instance.actor_id.actor_name
Contributor comment on lines +100 to +102: "a small thing but it's not immediately clear why we do this here vs get_proc_name_with_rank in other places. (after looking at the code i think it's a global vs local thing, but imo this could be more clearly documented)"
+
     # Check current state for consistency
     proc_has_local_fetcher = hasattr(proc, "_local_fetcher")
     global_logger_has_local_fetcher = await global_logger.has_fetcher.call_one(proc)
@@ -101,7 +118,7 @@ async def get_or_create_metric_logger(
     # Setup local_fetcher_actor if needed (unless disabled by environment flag)
     if not proc_has_local_fetcher and not FORGE_DISABLE_METRICS.get_value():
         local_fetcher_actor = proc.spawn(
-            "local_fetcher_actor", LocalFetcherActor, global_logger
+            "local_fetcher_actor", LocalFetcherActor, global_logger, process_name
         )
         await global_logger.register_fetcher.call_one(local_fetcher_actor, proc)
         proc._local_fetcher = local_fetcher_actor  # pyre-ignore
@@ -117,8 +134,13 @@ class LocalFetcherActor(Actor):
     GlobalLoggingActor -> per-rank LocalFetcherActor -> per-rank MetricCollector
     """

-    def __init__(self, global_logger: Union["GlobalLoggingActor", None] = None) -> None:
+    def __init__(
+        self,
+        global_logger: Union["GlobalLoggingActor", None] = None,
+        process_name: str | None = None,
+    ) -> None:
         self.global_logger = global_logger
+        self.process_name = process_name  # Passed to MetricCollector for logging
         _is_initialized = False

     @endpoint
@@ -145,10 +167,22 @@ async def init_backends(
         self,
         metadata_per_primary_backend: dict[str, dict[str, Any]],
         config: dict[str, Any],
+        global_step: int = 0,
     ) -> None:
-        """Init local (per-rank) logger backends and MetricCollector."""
+        """Init local (per-rank) logger backends and MetricCollector.
+
+        Args:
+            metadata_per_primary_backend (dict[str, dict[str, Any]]): Metadata from primary backends for shared state.
+            config (dict[str, Any]): Backend configurations with logging modes and settings.
+            global_step (int): Initial step for metrics.
+        """
         collector = MetricCollector()
-        await collector.init_backends(metadata_per_primary_backend, config)
+        await collector.init_backends(
+            metadata_per_primary_backend,
+            config,
+            global_step,
+            process_name=self.process_name,
+        )

     @endpoint
     async def shutdown(self) -> None:
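Taken together, process_name now flows from each call site of get_or_create_metric_logger down to the per-rank MetricCollector. A short recap sketch of the three call patterns seen in this PR; the wrapper function and the placeholder values for procs and mesh_name are illustrative only:

    from forge.observability.metric_actors import get_or_create_metric_logger

    async def logger_call_patterns(procs=None, mesh_name="worker_mesh"):
        # 1) Controller process: name chosen explicitly (apps/grpo/main.py).
        mlogger = await get_or_create_metric_logger(process_name="Controller")

        # 2) Provisioned worker mesh: name taken from the mesh being bootstrapped
        #    (src/forge/controller/provisioner.py).
        _ = await get_or_create_metric_logger(procs, process_name=mesh_name)

        # 3) No name given: auto-detected via monarch's context(), i.e. the
        #    current actor's name.
        _ = await get_or_create_metric_logger()
        return mlogger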