Draft
Changes from 16 commits (26 commits total)

Commits
77488cf - commit (Oct 8, 2025)
feb4771 - commit (Oct 8, 2025)
41ceaa4 - update backend role typehints and enum (Oct 8, 2025)
8a24e71 - update where we check FORGE_DISABLE_METRICS (Oct 8, 2025)
3f3bc51 - remove protected import (Oct 8, 2025)
d82c354 - Merge branch 'timestamp_logging_diff1' into timestamp_logging_diff2 (Oct 8, 2025)
4fe2611 - protect import (Oct 8, 2025)
8759bc8 - Merge branch 'timestamp_logging_diff1' into timestamp_logging_diff2 (Oct 8, 2025)
fbb4a9e - Merge branch 'main' of https://github.com/meta-pytorch/forge into tim… (Oct 8, 2025)
d81a4ed - record_metric uses dataclass Metric (Oct 8, 2025)
1e2255d - commit (Oct 8, 2025)
a94c612 - Merge branch 'main' of https://github.com/meta-pytorch/forge into tim… (Oct 8, 2025)
5b477e8 - commit (Oct 9, 2025)
f2b3eed - commit (Oct 9, 2025)
471b88a - revert (Oct 9, 2025)
1a02784 - Merge branch 'timestamp_logging_diff2_5' into timestamp_logging_diff3 (Oct 9, 2025)
fa4895f - remove unnecessary code (Oct 9, 2025)
7bb1fe7 - better logging (Oct 9, 2025)
43d5d27 - docs/names (Oct 9, 2025)
c97eb98 - Merge branch 'timestamp_logging_diff2_5' into timestamp_logging_diff3 (Oct 9, 2025)
70e9c67 - Merge branch 'main' of https://github.com/meta-pytorch/forge into tim… (Oct 9, 2025)
1186aec - update cfg back to true (Oct 9, 2025)
a02ea75 - Merge branch 'main' of https://github.com/meta-pytorch/forge into tim… (Oct 13, 2025)
7d89f5c - Merge branch 'main' of https://github.com/meta-pytorch/forge into tim… (Oct 14, 2025)
370c4e4 - remove callstack, get meshname in provisioner (Oct 14, 2025)
9e77930 - get name from proc mesh (Oct 14, 2025)

2 changes: 1 addition & 1 deletion apps/grpo/main.py
@@ -319,7 +319,7 @@ async def main(cfg: DictConfig):
ProvisionerConfig(launcher_config=LauncherConfig(**cfg.provisioner))
)
metric_logging_cfg = cfg.get("metric_logging", {"console": {"log_per_rank": False}})
mlogger = await get_or_create_metric_logger()
mlogger = await get_or_create_metric_logger(process_name="Controller")
await mlogger.init_backends.call_one(metric_logging_cfg)
await ts.initialize(strategy=ts.ControllerStorageVolumes())

4 changes: 2 additions & 2 deletions apps/grpo/qwen3_1_7b.yaml
@@ -18,9 +18,9 @@ metric_logging:
wandb:
project: "grpo-training"
group: "grpo_exp_${oc.env:USER}"
reduce_across_ranks: True
reduce_across_ranks: False
console:
reduce_across_ranks: True
reduce_across_ranks: False

# Dataset configuration
dataset:
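The flip from reduce_across_ranks: True to False is what makes the per-process naming in the rest of this PR matter: per the MetricCollector docstring further down, False means each rank instantiates its own wandb/console backend instead of logging once globally. Below is a minimal sketch of how this YAML reaches the logger, mirroring the apps/grpo/main.py hunk above; the resolved group name and the exact dict shape are assumptions, not part of this diff.

from forge.observability.metric_actors import get_or_create_metric_logger

# Assumed resolved form of the metric_logging section above.
metric_logging_cfg = {
    "wandb": {
        "project": "grpo-training",
        "group": "grpo_exp_<user>",    # ${oc.env:USER} resolved by OmegaConf
        "reduce_across_ranks": False,  # one wandb run per rank
    },
    "console": {"reduce_across_ranks": False},
}

async def setup_metric_logging():
    # Same call pattern as main.py: the controller names itself explicitly.
    mlogger = await get_or_create_metric_logger(process_name="Controller")
    await mlogger.init_backends.call_one(metric_logging_cfg)
    return mlogger
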
4 changes: 3 additions & 1 deletion src/forge/controller/provisioner.py
@@ -264,8 +264,10 @@ def bootstrap(env: dict[str, str]):
# Spawn local fetcher actor on each process and register with global logger
if os.getenv(FORGE_DISABLE_METRICS, "false").lower() != "true":
from forge.observability.metric_actors import get_or_create_metric_logger
from forge.observability.utils import detect_actor_name_from_call_stack

_ = await get_or_create_metric_logger(procs)
process_name = detect_actor_name_from_call_stack()
_ = await get_or_create_metric_logger(procs, process_name=process_name)
return procs

async def host_mesh_from_proc(self, proc_mesh: ProcMesh):
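The helper itself lives in forge/observability/utils.py, which is not shown in this diff, so the sketch below is speculative: only the function name and the "UnknownActor" fallback come from this PR's docstrings, and the later commits ("remove callstack, get meshname in provisioner", "get name from proc mesh") suggest the call-stack approach was eventually replaced anyway.

import inspect

def detect_actor_name_from_call_stack() -> str:
    # Speculative sketch: walk the stack looking for a `self` whose class name
    # ends in "Actor" and use that class name; otherwise fall back.
    for frame_info in inspect.stack():
        candidate = frame_info.frame.f_locals.get("self")
        if candidate is not None and type(candidate).__name__.endswith("Actor"):
            return type(candidate).__name__
    return "UnknownActor"
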
7 changes: 4 additions & 3 deletions src/forge/observability/__init__.py
@@ -12,8 +12,6 @@
from .metrics import (
BackendRole,
ConsoleBackend,
get_actor_name_with_rank,
get_logger_backend_class,
LoggerBackend,
MaxAccumulator,
MeanAccumulator,
@@ -29,12 +27,12 @@
WandbBackend,
)
from .perf_tracker import trace, Tracer
from .utils import detect_actor_name_from_call_stack, get_actor_name_with_rank

__all__ = [
# Main API functions
"record_metric",
"reduce_metrics_states",
"get_actor_name_with_rank",
"get_logger_backend_class",
"get_or_create_metric_logger",
# Performance tracking
@@ -45,6 +43,9 @@
"BackendRole",
# Enums
"Reduce",
# Utility functions
"detect_actor_name_from_call_stack",
"get_actor_name_with_rank",
# Actor classes
"GlobalLoggingActor",
"LocalFetcherActor",
39 changes: 32 additions & 7 deletions src/forge/observability/metric_actors.py
@@ -19,6 +19,7 @@
MetricCollector,
reduce_metrics_states,
)
from forge.observability.utils import detect_actor_name_from_call_stack

logger = logging.getLogger(__name__)

@@ -27,6 +28,7 @@

async def get_or_create_metric_logger(
proc_mesh: ProcMesh | None = None,
process_name: str | None = None,
) -> "GlobalLoggingActor":
"""Initializes a LocalFetcherActor in the specified process mesh (or current process if None),
if not already initialized, registers it with the GlobalLoggingActor and returns the
@@ -40,6 +42,8 @@ async def get_or_create_metric_logger(
Args:
proc_mesh: Optional ProcMesh to spawn LocalFetcherActor on. If None,
uses `monarch.actor.this_proc()`.
process_name: Optional process name (e.g., "TrainActor", "GeneratorActor") for logging.
If None, will auto-detect from call stack or default to "UnknownActor" if not found.

Returns:
GlobalLoggingActor: The global logging controller.
@@ -53,7 +57,7 @@ async def get_or_create_metric_logger(
from forge.observability.metrics import record_metric

# Main process setup
mlogger = await get_or_create_metric_logger()
mlogger = await get_or_create_metric_logger(process_name="Controller")

# Initialize logging backends
await mlogger.init_backends({
@@ -66,13 +70,17 @@

# Training loop
for step in range(max_steps):
record_metric("loss", 1.2, step, reduction_type=Reduce.MEAN)
record_metric("loss", 1.2, reduction_type=Reduce.MEAN)
# ... training code with record_metric() calls ...
await mlogger.flush(step) # Log metrics for this step

# Shutdown
await mlogger.shutdown()
"""

if process_name is None:
process_name = detect_actor_name_from_call_stack()
Review comment (Contributor, Author): get name here and pass it to

local_fetcher_actor = proc.spawn(
    "local_fetcher_actor", LocalFetcherActor, global_logger, process_name
)

this function is called in provisioner.py, and that's how we get the process_name for every wandb run

# Get or create the singleton global logger
global _global_logger
if _global_logger is None:
@@ -104,7 +112,7 @@ async def get_or_create_metric_logger(
and os.getenv(FORGE_DISABLE_METRICS, "false").lower() != "true"
):
local_fetcher_actor = proc.spawn(
"local_fetcher_actor", LocalFetcherActor, global_logger
"local_fetcher_actor", LocalFetcherActor, global_logger, process_name
)
await global_logger.register_fetcher.call_one(local_fetcher_actor, proc)
proc._local_fetcher = local_fetcher_actor # pyre-ignore
@@ -120,8 +128,13 @@ class LocalFetcherActor(Actor):
GlobalLoggingActor -> per-rank LocalFetcherActor -> per-rank MetricCollector
"""

def __init__(self, global_logger: Union["GlobalLoggingActor", None] = None) -> None:
def __init__(
self,
global_logger: Union["GlobalLoggingActor", None] = None,
process_name: str | None = None,
) -> None:
self.global_logger = global_logger
self.process_name = process_name # Passed to MetricCollector for logging
_is_initialized = False

@endpoint
@@ -136,7 +149,7 @@ async def flush(
return_state (bool): Used by GlobalLoggingActor for reduction across all ranks.
If False, returns empty dict, else returns the state of all metrics collected.
Returns:
dict[str, dict[str, Any]]: Dict of {metric_key: metric_state},
dict[str, dict[str, Any]]: dict of {metric_key: metric_state},
e.g., {"loss": {"reduction_type": "mean", "sum": 1.2, "count": 3}}.
"""
collector = MetricCollector()
@@ -148,10 +161,22 @@ async def init_backends(
self,
metadata_per_primary_backend: dict[str, dict[str, Any]],
config: dict[str, Any],
global_step: int = 0,
) -> None:
"""Init local (per-rank) logger backends and MetricCollector."""
"""Init local (per-rank) logger backends and MetricCollector.

Args:
metadata_per_primary_backend (dict[str, dict[str, Any]]): Metadata from primary backends for shared state.
config (dict[str, Any]): Backend configurations with logging modes and settings.
global_step (int): Initial step for metrics.
"""
collector = MetricCollector()
await collector.init_backends(metadata_per_primary_backend, config)
await collector.init_backends(
metadata_per_primary_backend,
config,
global_step,
process_name=self.process_name,
)

@endpoint
async def shutdown(self) -> None:
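Taken with the review comment above, the flow added here is: get_or_create_metric_logger resolves a process_name (explicitly from the controller, or via the provisioner for each spawned proc mesh), passes it to the LocalFetcherActor it spawns, and the fetcher forwards it through MetricCollector.init_backends to each backend's init(). A small illustrative consumer follows; the TrainActor class and loss value are hypothetical, only record_metric and Reduce come from this package. Note that the docstring example above passes reduction_type= while the record_metric signature in metrics.py takes reduction=, so the sketch passes the reduce mode positionally.

from monarch.actor import Actor, endpoint

from forge.observability.metrics import record_metric, Reduce

class TrainActor(Actor):  # hypothetical example actor
    @endpoint
    async def train_step(self, step: int) -> None:
        loss = 1.2  # placeholder value
        # Collected by this rank's MetricCollector; with the changes in this PR
        # the console/wandb backends label it with a prefix derived from the
        # detected process name, e.g. "TrainActor_abcd_r0".
        record_metric("loss", loss, Reduce.MEAN)
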
99 changes: 31 additions & 68 deletions src/forge/observability/metrics.py
@@ -13,8 +13,9 @@
from typing import Any

import pytz
from monarch.actor import context, current_rank
from monarch.actor import current_rank

from forge.observability.utils import get_actor_name_with_rank
from forge.util.logging import log_once

logger = logging.getLogger(__name__)
@@ -55,6 +56,12 @@ class Metric:
"""Container for metric data including key, value, reduction type, and timestamp.

Timestamp is automatically set to current EST time if not provided.

Args:
key: str
value: Any
reduction: Reduce
timestamp: Optional[float] = None
"""

key: str
@@ -68,55 +75,6 @@ def __post_init__(self):
self.timestamp = datetime.now(pytz.UTC).timestamp()


def get_actor_name_with_rank() -> str:
Review comment (Contributor, Author): moved to observability/utils.py

"""
Extracts actor information from Monarch context to form a logging name.

Returns:
str: Format "ActorName_replicaId_rLocalRank" (e.g., "TrainActor_abcd_r0").
Falls back to "UnknownActor" if context unavailable.
"""
# Add more defensive checks
ctx = context()
if ctx is None or ctx.actor_instance is None:
logger.warning("Context unavailable, using fallback actor name for logging.")
return "UnknownActor"

actor_instance = ctx.actor_instance
rank = current_rank()

actor_id_full = str(actor_instance.actor_id)

# Parse the actor_id
parts = actor_id_full.split(".")
rank_name = "UnknownActor" # fallback
if len(parts) >= 2:
world_part = parts[0] # e.g., "_1rjutFUXQrEJ[0]"
actor_part = parts[1] # e.g., "TestActorConfigured[0]"

# Extract world ID and proc rank
world_id = world_part.split("[")[0] if "[" in world_part else world_part

# Extract clean actor name (remove "Configured" suffix if present)
if "[" in actor_part:
actor_name = actor_part.split("[")[0] # e.g., "TestActorConfigured"
if actor_name.endswith("Configured"):
actor_name = actor_name[:-10] # Remove "Configured"
else:
actor_name = actor_part

# Use last 4 characters of world_id as replica identifier
# This is deterministic, readable, and works for any number of replicas
replica_id = world_id[-4:] if len(world_id) >= 4 else world_id

# Use current_rank().rank as the local rank within the replica
local_rank = rank.rank

rank_name = f"{actor_name}_{replica_id}_r{local_rank}"

return rank_name


def record_metric(key: str, value: Any, reduction: Reduce = Reduce.MEAN) -> None:
"""Thin wrapper to send metrics to per-rank local MetricCollectors.

@@ -150,11 +108,11 @@ def reduce_metrics_states(states: list[dict[str, dict[str, Any]]]) -> list[Metri
states is more precise than merging locally reduced metrics.

Args:
states (list[dict[str, dict[str, Any]]]): List of state of one or more metrics,
states (list[dict[str, dict[str, Any]]]): list of state of one or more metrics,
normally retrieved using `forge.observability.metrics.MetricAccumulator.get_state()`.

Returns:
list[Metric]: List of reduced metrics
list[Metric]: list of reduced metrics

Example:
states = [
@@ -443,6 +401,8 @@ async def init_backends(
self,
metadata_per_primary_backend: dict[str, dict[str, Any]] | None,
config: dict[str, Any],
global_step: int = 0,
process_name: str | None = None,
) -> None:
"""A logger is represented by a backend, i.e. wandb backend. If reduce_across_ranks=False,
the backend is instantiated per-rank, in the MetricCollector, otherwise it is only instantiated
@@ -452,10 +412,15 @@
metadata_per_primary_backend (dict[str, dict[str, Any]] | None): Metadata from primary
logger backend, e.g., {"wandb": {"run_id": "abc123"}}.
config (dict[str, Any]): Logger backend configuration, e.g. {"wandb": {"project": "my_project"}}.
global_step (int, default 0): Initial step for metrics.
process_name (str | None): The meaningful process name for logging.
"""
if self._is_initialized:
logger.debug(f"Rank {self.rank}: MetricCollector already initialized")
logger.debug(
f"Rank {get_actor_name_with_rank()}: MetricCollector already initialized"
)
return
self.global_step = global_step

# instantiate local backends if any
for backend_name, backend_config in config.items():
@@ -470,7 +435,9 @@
# instantiate local backend
logger_backend = get_logger_backend_class(backend_name)(backend_config)
await logger_backend.init(
role=BackendRole.LOCAL, primary_logger_metadata=primary_metadata
role=BackendRole.LOCAL,
primary_logger_metadata=primary_metadata,
process_name=process_name,
)
self.logger_backends.append(logger_backend)

@@ -498,7 +465,7 @@ def push(self, metric: Metric) -> None:
"Skipping metric collection. Metric logging backends (e.g. wandb) were not initialized."
" This happens when you try to use `record_metric` before calling `init_backends`."
" To disable this warning, please call in your main file:\n"
"`mlogger = await get_or_create_metric_logger()`\n"
"`mlogger = await get_or_create_metric_logger(process_name='Controller')`\n"
"`await mlogger.init_backends.call_one(logging_config)`\n"
"or set env variable `FORGE_DISABLE_METRICS=True`"
),
@@ -527,7 +494,7 @@ async def flush(
return_state (bool): Used by GlobalLoggingActor for reduction across all ranks.
If False, returns empty dict, else returns the state of all metrics collected.
Returns:
dict[str, dict[str, dict[str, Any]]]: Dict of {metric_key: metric_state},
dict[str, dict[str, dict[str, Any]]]: dict of {metric_key: metric_state},
e.g., {"loss": {"reduction_type": "mean", "sum": 1.2, "count": 3}}.
"""
if not self._is_initialized:
@@ -569,7 +536,7 @@ async def shutdown(self):
"""Shutdown logger_backends if initialized."""
if not self._is_initialized:
logger.debug(
f"Collector for {get_actor_name_with_rank()} not initialized. Skipping shutdown"
f"Collector for rank {self.rank} not initialized. Skipping shutdown"
)
return

@@ -593,6 +560,7 @@ async def init(
self,
role: BackendRole,
primary_logger_metadata: dict[str, Any] | None = None,
process_name: str | None = None,
) -> None:
"""
Initializes backend, e.g. wandb.run.init().
Expand All @@ -602,6 +570,7 @@ async def init(
Can be used to behave differently for primary vs secondary roles.
primary_logger_metadata (dict[str, Any] | None): From global backend for
backend that required shared info, e.g. {"shared_run_id": "abc123"}.
process_name (str | None): Process name for logging.

Raises: ValueError if missing metadata for shared local init.
"""
@@ -613,7 +582,7 @@ async def log(self, metrics: list[Metric], global_step: int) -> None:
Log a batch of metrics to the backend.

Args:
metrics: List of Metric objects to log.
metrics: list of Metric objects to log.
global_step: Step number for x-axis alignment across metrics.
"""
pass
@@ -636,12 +605,9 @@ async def init(
self,
role: BackendRole,
primary_logger_metadata: dict[str, Any] | None = None,
process_name: str | None = None,
) -> None:
self.prefix = (
get_actor_name_with_rank()
if self.logger_backend_config.get("reduce_across_ranks", True)
else "Controller"
)
self.prefix = get_actor_name_with_rank(actor_name=process_name)

async def log(self, metrics: list[Metric], global_step: int) -> None:
metrics_str = "\n".join(
@@ -689,16 +655,13 @@ async def init(
self,
role: BackendRole,
primary_logger_metadata: dict[str, Any] | None = None,
process_name: str | None = None,
) -> None:

if primary_logger_metadata is None:
primary_logger_metadata = {}

self.name = (
get_actor_name_with_rank()
if role == BackendRole.LOCAL
else "global_controller"
)
self.name = get_actor_name_with_rank(actor_name=process_name)

# Default global mode: only inits on controller
if self.reduce_across_ranks:
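get_actor_name_with_rank now lives in forge/observability/utils.py and, judging by the new call sites (get_actor_name_with_rank(actor_name=process_name) in both the console and wandb backends), it takes the detected process name instead of parsing the actor class out of the Monarch actor_id. The moved implementation is not shown in this diff; below is a rough sketch that keeps the "Name_replicaId_rLocalRank" shape of the removed version above, with the exact signature and fallback behaviour as assumptions.

from monarch.actor import context, current_rank

def get_actor_name_with_rank(actor_name: str | None = None) -> str:
    # Sketch only: prefer the explicitly passed / detected name, keep the
    # replica-id and local-rank suffix from the removed implementation.
    name = actor_name or "UnknownActor"
    ctx = context()
    if ctx is None or ctx.actor_instance is None:
        return name
    world_id = str(ctx.actor_instance.actor_id).split(".")[0].split("[")[0]
    replica_id = world_id[-4:] if len(world_id) >= 4 else world_id
    return f"{name}_{replica_id}_r{current_rank().rank}"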