Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/forge/env_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

# Force all timing methods in forge.observability.perf_tracker.py to use
# CPU timer if False or GPU timer if True. If unset, defaults to the value assigned in the function.
METRIC_TIMER_USES_CUDA = "METRIC_TIMER_USES_CUDA"
METRIC_TIMER_USES_GPU = "METRIC_TIMER_USES_GPU"

# Makes forge.observability.metrics.record_metric a no-op
# and disables spawning LocalFetcherActor in get_or_create_metric_logger
Expand Down
8 changes: 5 additions & 3 deletions src/forge/observability/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,14 @@
LocalFetcherActor,
)
from .metrics import (
BackendRole,
ConsoleBackend,
# Utility functions
get_actor_name_with_rank,
get_logger_backend_class,
# Backend classes
LoggerBackend,
MaxAccumulator,
MeanAccumulator,
# Accumulator classes
Metric,
MetricAccumulator,
MetricCollector,
MinAccumulator,
Expand All @@ -41,6 +40,9 @@
# Performance tracking
"Tracer",
"trace",
# Data classes
"Metric",
"BackendRole",
# Enums
"Reduce",
# Actor classes
Expand Down
42 changes: 23 additions & 19 deletions src/forge/observability/metric_actors.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from forge.env_constants import FORGE_DISABLE_METRICS
from forge.observability.metrics import (
BackendRole,
get_logger_backend_class,
LoggerBackend,
MetricCollector,
Expand Down Expand Up @@ -106,7 +107,7 @@ async def get_or_create_metric_logger(
"local_fetcher_actor", LocalFetcherActor, global_logger
)
await global_logger.register_fetcher.call_one(local_fetcher_actor, proc)
proc._local_fetcher = local_fetcher_actor
proc._local_fetcher = local_fetcher_actor # pyre-ignore

return global_logger

Expand All @@ -125,36 +126,35 @@ def __init__(self, global_logger: Optional["GlobalLoggingActor"] = None) -> None

@endpoint
async def flush(
self, step: int, return_state: bool = False
self, global_step: int, return_state: bool = False
) -> Dict[str, Dict[str, Any]]:
"""Log to local logger backends (if any), reset accumulators and return metric states dict if return_state=True.
This should only ever be called by the global logger.

Args:
step (int): train step used by backends to align all metrics on the same x-axis
global_step (int): step used by backends to align all metrics on the same x-axis
return_state (bool): Used by GlobalLoggingActor for reduction across all ranks.
If False, returns empty dict, else returns the state of all metrics collected.
Returns:
Dict[str, Dict[str, Any]]: Dict of {metric_key: metric_state},
e.g., {"loss": {"reduction_type": "mean", "sum": 1.2, "count": 3}}.
"""
collector = MetricCollector()
result = await collector.flush(step, return_state=return_state)
result = await collector.flush(global_step, return_state=return_state)
return result

@endpoint
async def init_backends(
self,
metadata_per_primary_backend: Dict[str, Dict[str, Any]],
config: Dict[str, Any],
):
) -> None:
"""Init local (per-rank) logger backends and MetricCollector."""
collector = MetricCollector()
await collector.init_backends(metadata_per_primary_backend, config)

@endpoint
async def shutdown(self):

async def shutdown(self) -> None:
collector = MetricCollector()
await collector.shutdown()

Expand Down Expand Up @@ -185,7 +185,7 @@ def __init__(self):
self.metadata_per_primary_backend: Dict[str, Dict[str, Any]] = {}

@endpoint
async def init_backends(self, config: Dict[str, Any]):
async def init_backends(self, config: Dict[str, Any]) -> None:
"""
Sets config in global actor, so other actors can get it, then eagerly initializes backend and MetricCollectors
in all registered fetchers.
Expand All @@ -208,7 +208,7 @@ async def init_backends(self, config: Dict[str, Any]):

for backend_name, backend_config in config.items():
backend = get_logger_backend_class(backend_name)(backend_config)
await backend.init(role="global")
await backend.init(role=BackendRole.GLOBAL)

# Extract metadata from primary logger to be shared with secondary loggers
# and store it
Expand Down Expand Up @@ -236,7 +236,9 @@ async def init_backends(self, config: Dict[str, Any]):
await asyncio.gather(*tasks, return_exceptions=True)

@endpoint
async def register_fetcher(self, fetcher: LocalFetcherActor, name: str | ProcMesh):
async def register_fetcher(
self, fetcher: LocalFetcherActor, name: str | ProcMesh
) -> None:
"""Registers a fetcher with the global actor. Each key represents a process mesh.
    If there are 2 processes, each with 2 replicas with N GPUs, we would
    have 4 keys, i.e. 2 process meshes, each with 2 replicas."""
Expand All @@ -250,7 +252,7 @@ async def register_fetcher(self, fetcher: LocalFetcherActor, name: str | ProcMes
)

@endpoint
async def deregister_fetcher(self, name: str | ProcMesh):
async def deregister_fetcher(self, name: str | ProcMesh) -> None:
if name not in self.fetchers:
logger.warning(
f"Fetcher {name} not registered in GlobalLoggingActor. Cannot deregister."
Expand All @@ -260,13 +262,13 @@ async def deregister_fetcher(self, name: str | ProcMesh):
del self.fetchers[name]

@endpoint
async def flush(self, step: int):
async def flush(self, global_step: int) -> None:
"""
Triggers parallel flush/reset on all registered fetchers. Per-rank MetricCollectors
log to local backends and return states if needed for cross-rank reduction.

Args:
step (int): Global step for logging.
global_step (int): step for logging.
"""
if not self.fetchers:
return
Expand All @@ -285,12 +287,14 @@ async def flush(self, step: int):
for backend_config in config.values()
)

logger.debug(f"Global flush for step {step}: {len(self.fetchers)} fetchers")
logger.debug(
f"Global flush for global_step {global_step}: {len(self.fetchers)} fetchers"
)

# Broadcast flush to all fetchers
results = await asyncio.gather(
*[
f.flush.call(step, return_state=requires_reduce)
f.flush.call(global_step, return_state=requires_reduce)
for f in self.fetchers.values()
],
return_exceptions=True,
Expand All @@ -314,18 +318,18 @@ async def flush(self, step: int):
)

if not all_local_states:
logger.warning(f"No states to reduce for step {step}")
logger.warning(f"No states to reduce for global_step {global_step}")
return

# Reduce
# Reduce metrics from states
reduced_metrics = reduce_metrics_states(all_local_states)

# Log to each global logger_backend
for (
logger_backend_name,
logger_backend,
) in self.global_logger_backends.items():
await logger_backend.log(reduced_metrics, step)
await logger_backend.log(reduced_metrics, global_step)

@endpoint
def has_fetcher(self, name: str | ProcMesh) -> bool:
Expand All @@ -337,7 +341,7 @@ def get_fetcher_count(self) -> int:
return len(self.fetchers)

@endpoint
async def shutdown(self):
async def shutdown(self) -> None:
# Finish per-rank logger_backends via fetchers
if self.fetchers:
tasks = [fetcher.shutdown.call() for fetcher in self.fetchers.values()]
Expand Down
Loading
Loading