-
Notifications
You must be signed in to change notification settings - Fork 15
Metric Logging updates 4/N - better actor name #351
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 21 commits
77488cf
feb4771
41ceaa4
8a24e71
3f3bc51
d82c354
4fe2611
8759bc8
fbb4a9e
d81a4ed
1e2255d
a94c612
5b477e8
f2b3eed
471b88a
1a02784
fa4895f
7bb1fe7
43d5d27
c97eb98
70e9c67
1186aec
a02ea75
7d89f5c
370c4e4
9e77930
93b0cad
84363b1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,6 +19,7 @@ | |
MetricCollector, | ||
reduce_metrics_states, | ||
) | ||
from forge.observability.utils import detect_actor_name_from_call_stack | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
@@ -27,6 +28,7 @@ | |
|
||
async def get_or_create_metric_logger( | ||
proc_mesh: ProcMesh | None = None, | ||
process_name: str | None = None, | ||
) -> "GlobalLoggingActor": | ||
"""Initializes a LocalFetcherActor in the specified process mesh (or current process if None), | ||
if not already initialized, registers it with the GlobalLoggingActor and returns the | ||
|
@@ -40,6 +42,8 @@ async def get_or_create_metric_logger( | |
Args: | ||
proc_mesh: Optional ProcMesh to spawn LocalFetcherActor on. If None, | ||
uses `monarch.actor.this_proc()`. | ||
process_name: Optional process name (e.g., "TrainActor", "GeneratorActor") for logging. | ||
If None, will auto-detect from call stack or default to "UnknownActor" if not found. | ||
|
||
Returns: | ||
GlobalLoggingActor: The global logging controller. | ||
|
@@ -53,7 +57,7 @@ async def get_or_create_metric_logger( | |
from forge.observability.metrics import record_metric | ||
|
||
# Main process setup | ||
mlogger = await get_or_create_metric_logger() | ||
mlogger = await get_or_create_metric_logger(process_name="Controller") | ||
|
||
# Initialize logging backends | ||
await mlogger.init_backends({ | ||
|
@@ -66,13 +70,17 @@ async def get_or_create_metric_logger( | |
|
||
# Training loop | ||
for step in range(max_steps): | ||
record_metric("loss", 1.2, step, reduction_type=Reduce.MEAN) | ||
record_metric("loss", 1.2, reduction_type=Reduce.MEAN) | ||
# ... training code with record_metric() calls ... | ||
await mlogger.flush(step) # Log metrics for this step | ||
|
||
# Shutdown | ||
await mlogger.shutdown() | ||
""" | ||
|
||
if process_name is None: | ||
process_name = detect_actor_name_from_call_stack() | ||
|
||
|
||
# Get or create the singleton global logger | ||
global _global_logger | ||
if _global_logger is None: | ||
|
@@ -104,7 +112,7 @@ async def get_or_create_metric_logger( | |
and os.getenv(FORGE_DISABLE_METRICS, "false").lower() != "true" | ||
): | ||
local_fetcher_actor = proc.spawn( | ||
"local_fetcher_actor", LocalFetcherActor, global_logger | ||
"local_fetcher_actor", LocalFetcherActor, global_logger, process_name | ||
) | ||
await global_logger.register_fetcher.call_one(local_fetcher_actor, proc) | ||
proc._local_fetcher = local_fetcher_actor # pyre-ignore | ||
|
@@ -120,8 +128,13 @@ class LocalFetcherActor(Actor): | |
GlobalLoggingActor -> per-rank LocalFetcherActor -> per-rank MetricCollector | ||
""" | ||
|
||
def __init__(self, global_logger: Union["GlobalLoggingActor", None] = None) -> None: | ||
def __init__( | ||
self, | ||
global_logger: Union["GlobalLoggingActor", None] = None, | ||
process_name: str | None = None, | ||
) -> None: | ||
self.global_logger = global_logger | ||
self.process_name = process_name # Passed to MetricCollector for logging | ||
_is_initialized = False | ||
|
||
@endpoint | ||
|
@@ -136,7 +149,7 @@ async def flush( | |
return_state (bool): Used by GlobalLoggingActor for reduction across all ranks. | ||
If False, returns empty dict, else returns the state of all metrics collected. | ||
Returns: | ||
dict[str, dict[str, Any]]: Dict of {metric_key: metric_state}, | ||
dict[str, dict[str, Any]]: dict of {metric_key: metric_state}, | ||
e.g., {"loss": {"reduction_type": "mean", "sum": 1.2, "count": 3}}. | ||
""" | ||
collector = MetricCollector() | ||
|
@@ -148,10 +161,22 @@ async def init_backends( | |
self, | ||
metadata_per_primary_backend: dict[str, dict[str, Any]], | ||
config: dict[str, Any], | ||
global_step: int = 0, | ||
) -> None: | ||
"""Init local (per-rank) logger backends and MetricCollector.""" | ||
"""Init local (per-rank) logger backends and MetricCollector. | ||
|
||
Args: | ||
metadata_per_primary_backend (dict[str, dict[str, Any]]): Metadata from primary backends for shared state. | ||
config (dict[str, Any]): Backend configurations with logging modes and settings. | ||
global_step (int): Initial step for metrics. | ||
""" | ||
collector = MetricCollector() | ||
await collector.init_backends(metadata_per_primary_backend, config) | ||
await collector.init_backends( | ||
metadata_per_primary_backend, | ||
config, | ||
global_step, | ||
process_name=self.process_name, | ||
) | ||
|
||
@endpoint | ||
async def shutdown(self) -> None: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,8 +13,9 @@ | |
from typing import Any | ||
|
||
import pytz | ||
from monarch.actor import context, current_rank | ||
from monarch.actor import current_rank | ||
|
||
from forge.observability.utils import get_actor_name_with_rank | ||
from forge.util.logging import log_once | ||
|
||
logger = logging.getLogger(__name__) | ||
|
@@ -55,6 +56,12 @@ class Metric: | |
"""Container for metric data including key, value, reduction type, and timestamp. | ||
|
||
Timestamp is automatically set to current EST time if not provided. | ||
|
||
Args: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. nit: give an actual docstring here (otherwise I can just read this info 5 lines below) |
||
key: str | ||
value: Any | ||
reduction: Reduce | ||
timestamp: Optional[float] = None | ||
""" | ||
|
||
key: str | ||
|
@@ -68,55 +75,6 @@ def __post_init__(self): | |
self.timestamp = datetime.now(pytz.UTC).timestamp() | ||
|
||
|
||
def get_actor_name_with_rank() -> str: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. moved to observability/utils.py |
||
""" | ||
Extracts actor information from Monarch context to form a logging name. | ||
|
||
Returns: | ||
str: Format "ActorName_replicaId_rLocalRank" (e.g., "TrainActor_abcd_r0"). | ||
Falls back to "UnknownActor" if context unavailable. | ||
""" | ||
# Add more defensive checks | ||
ctx = context() | ||
if ctx is None or ctx.actor_instance is None: | ||
logger.warning("Context unavailable, using fallback actor name for logging.") | ||
return "UnknownActor" | ||
|
||
actor_instance = ctx.actor_instance | ||
rank = current_rank() | ||
|
||
actor_id_full = str(actor_instance.actor_id) | ||
|
||
# Parse the actor_id | ||
parts = actor_id_full.split(".") | ||
rank_name = "UnknownActor" # fallback | ||
if len(parts) >= 2: | ||
world_part = parts[0] # e.g., "_1rjutFUXQrEJ[0]" | ||
actor_part = parts[1] # e.g., "TestActorConfigured[0]" | ||
|
||
# Extract world ID and proc rank | ||
world_id = world_part.split("[")[0] if "[" in world_part else world_part | ||
|
||
# Extract clean actor name (remove "Configured" suffix if present) | ||
if "[" in actor_part: | ||
actor_name = actor_part.split("[")[0] # e.g., "TestActorConfigured" | ||
if actor_name.endswith("Configured"): | ||
actor_name = actor_name[:-10] # Remove "Configured" | ||
else: | ||
actor_name = actor_part | ||
|
||
# Use last 4 characters of world_id as replica identifier | ||
# This is deterministic, readable, and works for any number of replicas | ||
replica_id = world_id[-4:] if len(world_id) >= 4 else world_id | ||
|
||
# Use current_rank().rank as the local rank within the replica | ||
local_rank = rank.rank | ||
|
||
rank_name = f"{actor_name}_{replica_id}_r{local_rank}" | ||
|
||
return rank_name | ||
|
||
|
||
def record_metric(key: str, value: Any, reduction: Reduce = Reduce.MEAN) -> None: | ||
"""Thin wrapper to send metrics to per-rank local MetricCollectors. | ||
|
||
|
@@ -150,11 +108,11 @@ def reduce_metrics_states(states: list[dict[str, dict[str, Any]]]) -> list[Metri | |
states is more precise than merging locally reduced metrics. | ||
|
||
Args: | ||
states (list[dict[str, dict[str, Any]]]): List of state of one or more metrics, | ||
states (list[dict[str, dict[str, Any]]]): list of state of one or more metrics, | ||
normally retrieved using `forge.observability.metrics.MetricAccumulator.get_state()`. | ||
|
||
Returns: | ||
list[Metric]: List of reduced metrics | ||
list[Metric]: list of reduced metrics | ||
|
||
Example: | ||
states = [ | ||
|
@@ -443,6 +401,8 @@ async def init_backends( | |
self, | ||
metadata_per_primary_backend: dict[str, dict[str, Any]] | None, | ||
config: dict[str, Any], | ||
global_step: int = 0, | ||
process_name: str | None = None, | ||
) -> None: | ||
"""A logger is represented by a backend, i.e. wandb backend. If reduce_across_ranks=False, | ||
the backend is instantiated per-rank, in the MetricCollector, otherwise it is only instantiated | ||
|
@@ -452,10 +412,15 @@ async def init_backends( | |
metadata_per_primary_backend (dict[str, dict[str, Any]] | None): Metadata from primary | ||
logger backend, e.g., {"wandb": {"run_id": "abc123"}}. | ||
config (dict[str, Any]): Logger backend configuration, e.g. {"wandb": {"project": "my_project"}}. | ||
global_step (int, default 0): Initial step for metrics. | ||
process_name (str | None): The meaningful process name for logging. | ||
""" | ||
if self._is_initialized: | ||
logger.debug(f"Rank {self.rank}: MetricCollector already initialized") | ||
logger.debug( | ||
f"Rank {get_actor_name_with_rank()}: MetricCollector already initialized" | ||
) | ||
return | ||
self.global_step = global_step | ||
|
||
# instantiate local backends if any | ||
for backend_name, backend_config in config.items(): | ||
|
@@ -470,7 +435,9 @@ async def init_backends( | |
# instantiate local backend | ||
logger_backend = get_logger_backend_class(backend_name)(backend_config) | ||
await logger_backend.init( | ||
role=BackendRole.LOCAL, primary_logger_metadata=primary_metadata | ||
role=BackendRole.LOCAL, | ||
primary_logger_metadata=primary_metadata, | ||
process_name=process_name, | ||
) | ||
self.logger_backends.append(logger_backend) | ||
|
||
|
@@ -498,7 +465,7 @@ def push(self, metric: Metric) -> None: | |
"Skipping metric collection. Metric logging backends (e.g. wandb) were not initialized." | ||
" This happens when you try to use `record_metric` before calling `init_backends`." | ||
" To disable this warning, please call in your main file:\n" | ||
"`mlogger = await get_or_create_metric_logger()`\n" | ||
"`mlogger = await get_or_create_metric_logger(process_name='Controller')`\n" | ||
"`await mlogger.init_backends.call_one(logging_config)`\n" | ||
"or set env variable `FORGE_DISABLE_METRICS=True`" | ||
), | ||
|
@@ -527,7 +494,7 @@ async def flush( | |
return_state (bool): Used by GlobalLoggingActor for reduction across all ranks. | ||
If False, returns empty dict, else returns the state of all metrics collected. | ||
Returns: | ||
dict[str, dict[str, dict[str, Any]]]: Dict of {metric_key: metric_state}, | ||
dict[str, dict[str, dict[str, Any]]]: dict of {metric_key: metric_state}, | ||
e.g., {"loss": {"reduction_type": "mean", "sum": 1.2, "count": 3}}. | ||
""" | ||
if not self._is_initialized: | ||
|
@@ -569,7 +536,7 @@ async def shutdown(self): | |
"""Shutdown logger_backends if initialized.""" | ||
if not self._is_initialized: | ||
logger.debug( | ||
f"Collector for {get_actor_name_with_rank()} not initialized. Skipping shutdown" | ||
f"Collector for rank {get_actor_name_with_rank()} not initialized. Skipping shutdown" | ||
) | ||
return | ||
|
||
|
@@ -593,6 +560,7 @@ async def init( | |
self, | ||
role: BackendRole, | ||
primary_logger_metadata: dict[str, Any] | None = None, | ||
process_name: str | None = None, | ||
) -> None: | ||
""" | ||
Initializes backend, e.g. wandb.run.init(). | ||
|
@@ -602,6 +570,7 @@ async def init( | |
Can be used to behave differently for primary vs secondary roles. | ||
primary_logger_metadata (dict[str, Any] | None): From global backend for | ||
backend that required shared info, e.g. {"shared_run_id": "abc123"}. | ||
process_name (str | None): Process name for logging. | ||
|
||
Raises: ValueError if missing metadata for shared local init. | ||
""" | ||
|
@@ -613,7 +582,7 @@ async def log(self, metrics: list[Metric], global_step: int) -> None: | |
Log a batch of metrics to the backend. | ||
|
||
Args: | ||
metrics: List of Metric objects to log. | ||
metrics: list of Metric objects to log. | ||
global_step: Step number for x-axis alignment across metrics. | ||
""" | ||
pass | ||
|
@@ -636,12 +605,9 @@ async def init( | |
self, | ||
role: BackendRole, | ||
primary_logger_metadata: dict[str, Any] | None = None, | ||
process_name: str | None = None, | ||
) -> None: | ||
self.prefix = ( | ||
get_actor_name_with_rank() | ||
if self.logger_backend_config.get("reduce_across_ranks", True) | ||
else "Controller" | ||
) | ||
self.prefix = get_actor_name_with_rank(actor_name=process_name) | ||
|
||
async def log(self, metrics: list[Metric], global_step: int) -> None: | ||
metrics_str = "\n".join( | ||
|
@@ -689,16 +655,13 @@ async def init( | |
self, | ||
role: BackendRole, | ||
primary_logger_metadata: dict[str, Any] | None = None, | ||
process_name: str | None = None, | ||
) -> None: | ||
|
||
if primary_logger_metadata is None: | ||
primary_logger_metadata = {} | ||
|
||
self.name = ( | ||
get_actor_name_with_rank() | ||
if role == BackendRole.LOCAL | ||
else "global_controller" | ||
) | ||
self.name = get_actor_name_with_rank(actor_name=process_name) | ||
|
||
# Default global mode: only inits on controller | ||
if self.reduce_across_ranks: | ||
|
Uh oh!
There was an error while loading. Please reload this page.