Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/forge/observability/metric_actors.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,8 +315,10 @@ async def flush(self, step: int):
config = self.config
if config is None:
logger.warning(
"GlobalLoggingActor flush() called before init_backends(). "
"No backends will be flushed."
"Cannot flush collected metrics. GlobalLoggingActor.flush() called before init_backends()."
" No backends will be flushed. Please call in your main file:\n"
"`mlogger = await get_or_create_metric_logger(actor_name='Controller')`\n"
"`await mlogger.init_backends.call_one(logging_config)`\n"
)
return

Expand Down
34 changes: 24 additions & 10 deletions src/forge/observability/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from typing import Any, Dict, List, Optional

import pytz

from monarch.actor import current_rank

from forge.observability.utils import get_actor_name_with_rank
Expand Down Expand Up @@ -499,13 +498,21 @@ def push(self, metric: Metric) -> None:
collector.push(metric) # Streams immediately if no_reduce, else accumulates
"""
if not self._is_initialized:
raise ValueError(
"MetricCollector was not initialized. This happens when you try to use `record_metric` "
"before you have initialized any logging backends. Please call in your main file:\n"
"`mlogger = await get_or_create_metric_logger(actor_name='Controller')`\n"
"`await mlogger.init_backends.call_one(logging_config)`\n"
"or, to disable metric logging globally, set env variable `FORGE_DISABLE_METRICS=True`"
from forge.util.logging import log_once

log_once(
logger,
level=logging.WARNING,
msg=(
"Skipping metric collection. Metric logging backends (e.g. wandb) were not initialized."
" This happens when you try to use `record_metric` before calling `init_backends`."
" To disable this warning, please call in your main file:\n"
"`mlogger = await get_or_create_metric_logger(actor_name='Controller')`\n"
"`await mlogger.init_backends.call_one(logging_config)`\n"
"or set env variable `FORGE_DISABLE_METRICS=True`"
),
)
return

# Validate metric object
if not isinstance(metric, Metric):
Expand Down Expand Up @@ -536,10 +543,17 @@ async def flush(
Dict[str, Dict[str, Dict[str, Any]]]: Dict of {metric_key: metric_state},
e.g., {"loss": {"reduction_type": "mean", "sum": 1.2, "count": 3}}.
"""

if not self._is_initialized:
logger.debug(
f"Collector not yet initialized for {get_actor_name_with_rank()}. Call init_backends first."
from forge.util.logging import log_once

log_once(
logger,
level=logging.WARNING,
msg="Cannot flush collected metrics. MetricCollector.flush() called before init_backends()."
"\nPlease call in your main file:\n"
"`mlogger = await get_or_create_metric_logger(actor_name='Controller')`\n"
"`await mlogger.init_backends.call_one(logging_config)`\n"
"before calling `flush`",
)
return {}

Expand Down
11 changes: 7 additions & 4 deletions tests/unit_tests/observability/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,13 +226,16 @@ def test_singleton_per_rank(self, mock_rank):
collector3 = MetricCollector()
assert collector1 is not collector3

def test_uninitialized_push_raises_error(self, mock_rank):
"""Test MetricCollector.push() raises error when uninitialized."""
def test_uninitialized_push_logs_warning(self, mock_rank, caplog):
"""Test MetricCollector.push() logs warning when uninitialized."""
collector = MetricCollector()
metric = Metric("test", 1.0, Reduce.MEAN)

with pytest.raises(ValueError, match="MetricCollector was not initialized"):
collector.push(metric)
# just log warning and return
collector.push(metric)
assert any(
"Metric logging backends" in record.message for record in caplog.records
)

def test_invalid_metric_type_raises_error(self, mock_rank):
"""Test MetricCollector.push() raises error for invalid metric type."""
Expand Down